From c706ff20a328e3cb0c7687d79091a9a0df25d93d Mon Sep 17 00:00:00 2001
From: ShenLiang <2282912238@qq.com>
Date: Mon, 6 Apr 2020 22:36:01 +0800
Subject: [PATCH] fix conflict, test=develop (#23298)

---
 paddle/fluid/operators/rank_attention.cu.h    | 153 ++++++++++++
 paddle/fluid/operators/rank_attention_op.cc   | 151 ++++++++++++
 paddle/fluid/operators/rank_attention_op.cu   | 215 +++++++++++++++++
 paddle/fluid/operators/rank_attention_op.h    |  32 +++
 python/paddle/fluid/contrib/layers/nn.py      |  64 ++++-
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../fluid/tests/unittests/test_layers.py      |  16 ++
 .../tests/unittests/test_rank_attention_op.py | 221 ++++++++++++++++++
 8 files changed, 852 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/rank_attention.cu.h
 create mode 100644 paddle/fluid/operators/rank_attention_op.cc
 create mode 100644 paddle/fluid/operators/rank_attention_op.cu
 create mode 100644 paddle/fluid/operators/rank_attention_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_rank_attention_op.py
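Note (reading aid, not part of the patch): the following minimal numpy sketch mirrors np_rank_attention in the new unit test below and shows what the op computes for one instance. Row i of rank_offset is laid out as [own_rank, peer_rank_1, peer_index_1, ..., peer_rank_max_rank, peer_index_max_rank], with 1-based ranks and -1 meaning "absent".

    import numpy as np

    # Illustrative only -- derived from gen_param_help/np_rank_attention in
    # test_rank_attention_op.py below; not code from the patch itself.
    def rank_attention_one_instance(input, rank_offset, rank_param, i, max_rank):
        # input:      [ins_num, input_col]   features of all instances
        # rank_param: [max_rank * max_rank * input_col, para_col], one
        #             [input_col, para_col] block per (own_rank, peer_rank) pair
        input_col = input.shape[1]
        out = np.zeros(rank_param.shape[1])
        lower = rank_offset[i, 0] - 1                # own rank, 0-based; < 0 = absent
        for k in range(max_rank):
            faster = rank_offset[i, 2 * k + 1] - 1   # peer rank in slot k, 0-based
            index = rank_offset[i, 2 * k + 2]        # peer's row in `input`
            if lower < 0 or faster < 0:
                continue
            block = lower * max_rank + faster        # which parameter block to use
            W = rank_param[block * input_col:(block + 1) * input_col, :]
            out += input[index] @ W
        return out                                   # row i of Out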
diff --git a/paddle/fluid/operators/rank_attention.cu.h b/paddle/fluid/operators/rank_attention.cu.h
new file mode 100644
index 00000000000..9de3de241dc
--- /dev/null
+++ b/paddle/fluid/operators/rank_attention.cu.h
@@ -0,0 +1,153 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/dim.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+static inline int GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template <typename T>
+__global__ void expand_input_by_rank_kernel(
+    const T* input, int input_row, int input_col, T* output, int output_row,
+    int output_col, const int* rank_offset, int rank_offset_row,
+    int rank_offset_col, T* ins_rank, int max_rank) {
+  CUDA_KERNEL_LOOP(idx, output_row * output_col) {
+    int output_col_idx = idx % output_col;
+    int output_row_idx = idx / output_col;
+    int k = output_col_idx / input_col;
+
+    int faster = rank_offset[output_row_idx * rank_offset_col + 2 * k + 1] - 1;
+    if (output_col_idx == 0) {
+      ins_rank[output_row_idx] = rank_offset[output_row_idx * rank_offset_col];
+    }
+
+    if (rank_offset[output_row_idx * rank_offset_col] - 1 < 0 || faster < 0) {
+      continue;
+    }
+
+    int rank_input_col_idx = output_col_idx % input_col;
+    int index = rank_offset[output_row_idx * rank_offset_col + 2 * k + 2];
+    output[idx] = input[rank_input_col_idx + index * input_col];
+  }
+}
+
+template <typename T>
+void expand_rank_attention_input(cudaStream_t stream, const T* input,
+                                 int input_row, int input_col, T* output,
+                                 int output_row, int output_col,
+                                 const int* rank_offset, int rank_offset_row,
+                                 int rank_offset_col, T* ins_rank,
+                                 int max_rank) {
+  expand_input_by_rank_kernel<<<GET_BLOCKS(output_row * output_col),
+                                CUDA_NUM_THREADS, 0, stream>>>(
+      input, input_row, input_col, output, output_row, output_col, rank_offset,
+      rank_offset_row, rank_offset_col, ins_rank, max_rank);
+}
+
+template <typename T>
+__global__ void expand_rank_attention_param_kernel(
+    const T* input, int input_row, int input_col, const int* rank_offset,
+    int rank_offset_row, int rank_offset_col, const T* param, int param_row,
+    int param_col, T* output_param, int output_param_row, int output_param_col,
+    int max_rank) {
+  CUDA_KERNEL_LOOP(idx, output_param_row * output_param_col) {
+    int output_col_idx = idx % output_param_col;
+    int output_row_idx = idx / output_param_col;
+
+    int block_matrix_row = max_rank * input_col;
+    int ins_idx = output_row_idx / block_matrix_row;
+    int start_offset = output_row_idx % block_matrix_row;
+
+    int k = start_offset / input_col;
+    int k_offset = start_offset % input_col;
+
+    int lower = rank_offset[ins_idx * rank_offset_col] - 1;
+    int faster = rank_offset[2 * k + 1 + rank_offset_col * ins_idx] - 1;
+
+    if (lower < 0 || faster < 0) {
+      continue;
+    }
+    int start = lower * max_rank + faster;
+    int ori_idx =
+        start * param_col * input_col + k_offset * param_col + output_col_idx;
+    output_param[idx] = param[ori_idx];
+  }
+}
+
+template <typename T>
+void expand_rank_attention_param(cudaStream_t stream, const T* input,
+                                 int input_row, int input_col,
+                                 const int* rank_offset, int rank_offset_row,
+                                 int rank_offset_col, const T* param,
+                                 int param_row, int param_col, T* output_param,
+                                 int output_param_row, int output_param_col,
+                                 int max_rank) {
+  expand_rank_attention_param_kernel<<<GET_BLOCKS(output_param_row *
+                                                  output_param_col),
+                                       CUDA_NUM_THREADS, 0, stream>>>(
+      input, input_row, input_col, rank_offset, rank_offset_row,
+      rank_offset_col, param, param_row, param_col, output_param,
+      output_param_row, output_param_col, max_rank);
+}
+
+template <typename T>
+__global__ void merge_param_gradient_kernel(
+    T* expanded_grad, int expanded_grad_row, int expanded_grad_col,
+    T* param_grad, int param_grad_row, int param_grad_col, const T* ins_rank,
+    int ins_num, int max_rank, int input_col) {
+  CUDA_KERNEL_LOOP(tid, param_grad_row * param_grad_col) {
+    int param_col_idx = tid % param_grad_col;
+    int param_row_idx = tid / param_grad_col;
+
+    int block_matrix_row = max_rank * input_col;
+    int rank_idx = param_row_idx / block_matrix_row;
+    int rank_offset = param_row_idx % block_matrix_row;
+
+    T tmp = 0;
+    for (int i = 0; i < ins_num; ++i) {
+      if (ins_rank[i] == rank_idx + 1) {
+        int row = i * block_matrix_row + rank_offset;
+        tmp += expanded_grad[row * expanded_grad_col + param_col_idx];
+      }
+    }
+    param_grad[tid] = tmp;
+  }
+}
+
+template <typename T>
+void merge_rank_attention_param_grad(cudaStream_t stream, T* expanded_grad,
+                                     int expanded_grad_row,
+                                     int expanded_grad_col, T* param_grad,
+                                     int param_grad_row, int param_grad_col,
+                                     const T* ins_rank, int ins_num,
+                                     int max_rank, int input_col) {
+  merge_param_gradient_kernel<<<GET_BLOCKS(param_grad_row * param_grad_col),
+                                CUDA_NUM_THREADS, 0, stream>>>(
+      expanded_grad, expanded_grad_row, expanded_grad_col, param_grad,
+      param_grad_row, param_grad_col, ins_rank, ins_num, max_rank, input_col);
+}
+
+}  // namespace operators
+}  // namespace paddle
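Note (reading aid, not part of the patch): all three kernels above use the same grid-stride launch, GET_BLOCKS(n) blocks of CUDA_NUM_THREADS threads with CUDA_KERNEL_LOOP striding over the flattened output. A small Python sketch of the index bookkeeping in expand_input_by_rank_kernel:

    CUDA_NUM_THREADS = 1024

    def get_blocks(n):
        # same ceil-division as GET_BLOCKS above
        return (n + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS

    def split_flat_index(idx, output_col, input_col):
        # expand_input_by_rank_kernel's view of a flat thread index: which
        # instance (row), which expanded column, and which peer slot k
        output_row_idx = idx // output_col
        output_col_idx = idx % output_col
        k = output_col_idx // input_col
        return output_row_idx, output_col_idx, k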
diff --git a/paddle/fluid/operators/rank_attention_op.cc b/paddle/fluid/operators/rank_attention_op.cc
new file mode 100644
index 00000000000..76a04014e4e
--- /dev/null
+++ b/paddle/fluid/operators/rank_attention_op.cc
@@ -0,0 +1,151 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/rank_attention_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+class RankAttentionOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of RankAttentionOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("RankOffset"), true,
+        platform::errors::InvalidArgument(
+            "Input(RankOffset) of RankAttentionOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("RankParam"), true,
+        platform::errors::InvalidArgument(
+            "Input(RankParam) of RankAttentionOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of RankAttentionOp should not be null."));
+    auto max_rank = ctx->Attrs().Get<int>("MaxRank");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto ins_num = x_dims[0];
+    auto param_dims = ctx->GetInputDim("RankParam");
+    auto para_col = param_dims[1];
+    auto rank_offset_dims = ctx->GetInputDim("RankOffset");
+
+    PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
+                      platform::errors::InvalidArgument(
+                          "Input(RankOffset) has wrong columns."));
+
+    ctx->SetOutputDim("Out", {ins_num, para_col});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class RankAttentionGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument("Input(X) should not be null"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("RankParam"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(RankParam) should not be null"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("RankOffset"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(RankOffset) should not be null"));
+
+    ctx->SetOutputDim(framework::GradVarName("RankParam"),
+                      ctx->GetInputDim("RankParam"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
+class RankAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input tensor of rank_attention_Op operator.");
+    AddInput("RankOffset",
+             "(Tensor) Input tensor of rank_attention_Op operator.");
+    AddInput("RankParam",
+             "(Tensor) Input tensor of rank_attention_Op operator.");
+    AddOutput("Out", "Output tensor of rank_attention_Op operator.");
+    AddAttr<int>("MaxRank", "(int, default 3) max rank of rank_attention_Op")
+        .SetDefault(3);
+    AddComment(R"DOC(
+RankAttention Operator.
+This Op computes rank attention between the input and rank_param, where
+rank_param describes the organization of the data. Note: it currently
+supports GPU only.
+This Op lives in contrib, which means it is not exposed as part of the
+public API.
+)DOC");
+  }
+};
+
+template <typename T>
+class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("rank_attention_grad");
+
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("RankOffset", this->Input("RankOffset"));
+    op->SetInput("RankParam", this->Input("RankParam"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+
+    op->SetOutput(framework::GradVarName("RankParam"),
+                  this->InputGrad("RankParam"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(
+    RankAttentionGradOpNoNeedBufferVarsInference, "RankParam");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(rank_attention, ops::RankAttentionOp,
+                  ops::RankAttentionOpMaker,
+                  ops::RankAttentionGradOpMaker<paddle::framework::OpDesc>,
+                  ops::RankAttentionGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(rank_attention_grad, ops::RankAttentionGradOp,
+                  ops::RankAttentionGradOpNoNeedBufferVarsInference);
+
+REGISTER_OP_CPU_KERNEL(
+    rank_attention,
+    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, double>);
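Note (reading aid, not part of the patch): combining the InferShape checks above with the extra runtime checks in the CUDA kernel, the op's shape contract can be summarized as the following illustrative sketch:

    def rank_attention_out_shape(x_shape, rank_offset_shape, rank_param_shape,
                                 max_rank):
        # Summarizes RankAttentionOp::InferShape plus the runtime checks in
        # the CUDA kernel; illustrative only.
        ins_num, x_fea_dim = x_shape
        assert rank_offset_shape[0] == ins_num
        assert rank_offset_shape[1] == 2 * max_rank + 1
        assert rank_param_shape[0] == max_rank * max_rank * x_fea_dim
        return [ins_num, rank_param_shape[1]]  # shape of Out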
diff --git a/paddle/fluid/operators/rank_attention_op.cu b/paddle/fluid/operators/rank_attention_op.cu
new file mode 100644
index 00000000000..08e2a9ccca4
--- /dev/null
+++ b/paddle/fluid/operators/rank_attention_op.cu
@@ -0,0 +1,215 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/rank_attention.cu.h"
+#include "paddle/fluid/operators/rank_attention_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class RankAttentionCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *X = ctx.Input<Tensor>("X");
+    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
+    auto *param = ctx.Input<Tensor>("RankParam");
+    int max_rank = ctx.Attr<int>("MaxRank");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+    // check dims
+    auto x_dims = X->dims();
+    auto ins_num = x_dims[0];
+    auto x_fea_dim = x_dims[1];
+    auto para_dims = param->dims();
+    auto para_row = para_dims[0];
+    auto para_col = para_dims[1];
+    auto rank_offset_dims = rank_offset->dims();
+    PADDLE_ENFORCE_EQ(rank_offset_dims[0], ins_num,
+                      platform::errors::InvalidArgument(
+                          "Input(RankOffset) has wrong rows."));
+    PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
+                      platform::errors::InvalidArgument(
+                          "Input(RankOffset) has wrong columns."));
+    PADDLE_ENFORCE_EQ(max_rank * max_rank * x_fea_dim, para_row,
+                      platform::errors::InvalidArgument(
+                          "Input(RankParam) has wrong rows."));
+
+    int block_matrix_row = max_rank * x_fea_dim;
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto stream = ctx.cuda_device_context().stream();
+    int device_id = platform::GetCurrentDeviceId();
+
+    T *param_help_data;
+    auto param_help_size = ins_num * block_matrix_row * para_col * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_help_data),
+                                 param_help_size, device_id);
+    platform::GpuMemsetAsync(param_help_data, 0, param_help_size, stream);
+
+    T *input_help_data;
+    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
+                                 input_help_size, device_id);
+    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
+
+    T *ins_rank_data;
+    auto ins_rank_size = ins_num * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
+                                 ins_rank_size, device_id);
+    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
+
+    Out->mutable_data<T>(ctx.GetPlace());
+
+    // initialize
+    auto out_eigen = framework::EigenVector<T>::Flatten(*Out);
+    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+    out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
+
+    // get data ptr
+    T *out_data = Out->data<T>();
+
+    expand_rank_attention_input(
+        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
+        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
+        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
+
+    expand_rank_attention_param(
+        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
+        rank_offset->data<int>(), rank_offset_dims[0], rank_offset_dims[1],
+        param->data<T>(), para_row, para_col, param_help_data,
+        ins_num * block_matrix_row, para_col, max_rank);
+
+    CBLAS_TRANSPOSE transA = CblasNoTrans;
+    CBLAS_TRANSPOSE transB = CblasNoTrans;
+
+    T alpha = 1;
+    T beta = 0;
+    int64_t strideA = block_matrix_row;
+    int64_t strideB = block_matrix_row * para_col;
+
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+    blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha,
+                     input_help_data, param_help_data, beta, out_data, ins_num,
+                     strideA, strideB);
+
+    platform::RecordedCudaFree(param_help_data, param_help_size, device_id);
+    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
+    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *X = ctx.Input<Tensor>("X");
+    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
+    auto *param = ctx.Input<Tensor>("RankParam");
+    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    auto *drank_para = ctx.Output<Tensor>(framework::GradVarName("RankParam"));
+
+    // get dim
+    auto x_dims = X->dims();
+    auto ins_num = x_dims[0];
+    auto x_fea_dim = x_dims[1];
+    auto para_dims = param->dims();
+    auto para_row = para_dims[0];
+    auto para_col = para_dims[1];
+    auto rank_offset_dims = rank_offset->dims();
+    auto max_rank = (rank_offset_dims[1] - 1) / 2;
+    int block_matrix_row = max_rank * x_fea_dim;
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+
+    // initialize out grad
+    drank_para->mutable_data<T>(ctx.GetPlace());
+    auto drank_para_eigen = framework::EigenVector<T>::Flatten(*drank_para);
+    drank_para_eigen.device(place) =
+        drank_para_eigen.constant(static_cast<T>(0));
+
+    auto stream = ctx.cuda_device_context().stream();
+    int device_id = platform::GetCurrentDeviceId();
+
+    T *param_grad_data;
+    auto param_grad_size = ins_num * block_matrix_row * para_col * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_grad_data),
+                                 param_grad_size, device_id);
+    platform::GpuMemsetAsync(param_grad_data, 0, param_grad_size, stream);
+
+    T *input_help_data;
+    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
+                                 input_help_size, device_id);
+    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
+
+    T *ins_rank_data;
+    auto ins_rank_size = ins_num * sizeof(T);
+    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
+                                 ins_rank_size, device_id);
+    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
+
+    // expand input
+    expand_rank_attention_input(
+        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
+        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
+        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
+
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
+    T alpha = 1;
+    T beta = 0;
+
+    // get param_grad
+    CBLAS_TRANSPOSE transA = CblasTrans;
+    CBLAS_TRANSPOSE transB = CblasNoTrans;
+    int64_t strideA = block_matrix_row;
+    int64_t strideB = para_col;
+
+    blas.BatchedGEMM(transA, transB, block_matrix_row, para_col, 1, alpha,
+                     input_help_data, dout->data<T>(), beta, param_grad_data,
+                     ins_num, strideA, strideB);
+
+    // merge param_grad to get drank_para
+    merge_rank_attention_param_grad(
+        ctx.cuda_device_context().stream(), param_grad_data,
+        ins_num * block_matrix_row, para_col, drank_para->data<T>(), para_row,
+        para_col, ins_rank_data, ins_num, max_rank, x_fea_dim);
+
+    platform::RecordedCudaFree(param_grad_data, param_grad_size, device_id);
+    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
+    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using GPUCtx = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(rank_attention,
+                        ops::RankAttentionCUDAKernel<GPUCtx, float>,
+                        ops::RankAttentionCUDAKernel<GPUCtx, double>);
+
+REGISTER_OP_CUDA_KERNEL(rank_attention_grad,
+                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, float>,
+                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, double>);
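Note (reading aid, not part of the patch): the forward BatchedGEMM above runs ins_num independent [1 x block_matrix_row] by [block_matrix_row x para_col] products (strideA = block_matrix_row, strideB = block_matrix_row * para_col). In numpy terms, illustratively:

    import numpy as np

    def forward_gemm(input_help, param_help, ins_num, block_matrix_row, para_col):
        # input_help: flat buffer of ins_num row vectors; param_help: flat
        # buffer of ins_num expanded parameter matrices, as produced by the
        # kernels in rank_attention.cu.h
        A = input_help.reshape(ins_num, 1, block_matrix_row)
        B = param_help.reshape(ins_num, block_matrix_row, para_col)
        return np.matmul(A, B).reshape(ins_num, para_col)  # == Out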
diff --git a/paddle/fluid/operators/rank_attention_op.h b/paddle/fluid/operators/rank_attention_op.h
new file mode 100644
index 00000000000..796546e61f1
--- /dev/null
+++ b/paddle/fluid/operators/rank_attention_op.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class RankAttentionKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::Unimplemented(
+                          "Rank Attention only supports GPU now."));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index a8cebfa579c..e6509bc4a1a 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -33,7 +33,7 @@ __all__ = [
     'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
     'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
     'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
-    'partial_sum', 'tdm_child'
+    'partial_sum', 'tdm_child', 'rank_attention'
 ]
 
 
@@ -1017,3 +1017,65 @@ def tdm_child(x, node_nums, child_nums, param_attr=None, dtype='int32'):
                 'dtype': c_dtype},
         stop_gradient=True)
     return (child, leaf_mask)
+
+
+def rank_attention(input,
+                   rank_offset,
+                   rank_param_shape,
+                   rank_param_attr,
+                   max_rank=3):
+    """
+    **Rank Attention layer**
+    This Op computes rank attention between input and rank_param, where
+    rank_param describes the organization of the data. Note: it currently
+    supports GPU only.
+    This Op lives in contrib, which means it is not exposed as part of the
+    public API.
+    Args:
+        input: Tensor with data type float32, float64.
+        rank_offset: Tensor with data type int32.
+        rank_param_shape: The shape of rank_param.
+        rank_param_attr: Attribute initializer of rank_param.
+        max_rank: The max rank of input's ranks.
+    Returns:
+        Variable: A Tensor with the same data type as input.
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+
+            input = fluid.data(name="input", shape=[None, 2], dtype="float32")
+            rank_offset = fluid.data(name="rank_offset", shape=[None, 7], dtype="int32")
+            out = fluid.contrib.layers.rank_attention(
+                input=input,
+                rank_offset=rank_offset,
+                rank_param_shape=[18, 3],
+                rank_param_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="ubm_rank_param.w_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                max_rank=3)
+    """
+    helper = LayerHelper('rank_attention', **locals())
+    dtype = helper.input_dtype(input_param_name='input')
+    input_shape = input.shape
+    assert input_shape[1] * max_rank * max_rank == rank_param_shape[0]
+
+    rank_param = helper.create_parameter(
+        attr=rank_param_attr, shape=rank_param_shape, dtype=dtype)
+    rank_param.stop_gradient = False
+
+    output = helper.create_variable_for_type_inference(dtype)
+    ins_rank = helper.create_variable_for_type_inference(
+        dtype=dtype, stop_gradient=True)
+
+    helper.append_op(
+        type="rank_attention",
+        inputs={
+            "X": input,
+            "RankOffset": rank_offset,
+            "RankParam": rank_param
+        },
+        outputs={"Out": output},
+        attrs={"MaxRank": max_rank})
+
+    return output
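Note (reading aid, not part of the patch): the assert in rank_attention ties the three shape arguments together. Plugging in the docstring example: the input has 2 features and max_rank is 3, so rank_param needs 2 * 3 * 3 = 18 rows and rank_offset 2 * 3 + 1 = 7 columns, matching the shapes used above. A small illustrative helper:

    def rank_attention_shapes(input_col, para_col, max_rank):
        # Convenience sketch (hypothetical helper, not part of the patch):
        # derive the shapes rank_attention expects from the input width,
        # output width, and max_rank.
        rank_param_shape = [input_col * max_rank * max_rank, para_col]
        rank_offset_cols = 2 * max_rank + 1
        return rank_param_shape, rank_offset_cols

    # e.g. rank_attention_shapes(2, 3, 3) -> ([18, 3], 7), as in the example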
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 7c191cd9504..b2983234b84 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -52,6 +52,7 @@ endif()
 
 if (NOT ${WITH_GPU})
   LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
+  LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
   LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
 elseif(${CUDNN_VERSION} VERSION_LESS 7100)
   LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index a2f8bc56404..666e7c86bb2 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -3030,6 +3030,22 @@ class TestBook(LayerTest):
                 [x, y], start_index=0, length=2)
             return (sum)
 
+    def test_rank_attention(self):
+        with self.static_graph():
+            input = fluid.data(name="input", shape=[None, 2], dtype="float32")
+            rank_offset = fluid.data(
+                name="rank_offset", shape=[None, 7], dtype="int32")
+            out = fluid.contrib.layers.rank_attention(
+                input=input,
+                rank_offset=rank_offset,
+                rank_param_shape=[18, 3],
+                rank_param_attr=fluid.ParamAttr(
+                    learning_rate=1.0,
+                    name="ubm_rank_param.w_0",
+                    initializer=fluid.initializer.Xavier(uniform=False)),
+                max_rank=3)
+            return (out)
+
     def test_roi_pool(self):
         # TODO(minqiyang): dygraph do not support lod now
         with self.static_graph():
diff --git a/python/paddle/fluid/tests/unittests/test_rank_attention_op.py b/python/paddle/fluid/tests/unittests/test_rank_attention_op.py
new file mode 100644
index 00000000000..f9b5afb22d5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_rank_attention_op.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import random
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+from op_test import OpTest, skip_check_grad_ci
+import paddle.fluid.core as core
+
+
+def gen_input_help(input, rank_offset, max_rank):
+    input_row, input_col = input.shape
+    input_help = np.zeros((input_row * max_rank * input_col, ))
+    ins_rank = np.zeros((input_row, 1))
+    ins_rank.fill(-1)
+
+    output_col = max_rank * input_col
+    output_row = input_row
+
+    for idx in range(output_col * output_row):
+        output_col_idx = idx % output_col
+        output_row_idx = int(idx / output_col)
+        k = int(output_col_idx / input_col)
+        faster = rank_offset[output_row_idx, 2 * k + 1] - 1
+
+        if output_col_idx == 0:
+            ins_rank[output_row_idx] = rank_offset[output_row_idx, 0]
+
+        if rank_offset[output_row_idx, 0] - 1 < 0 or faster < 0:
+            continue
+
+        rank_input_col_idx = output_col_idx % input_col
+        index = rank_offset[output_row_idx, 2 * k + 2]
+        input_help[idx] = input[index, rank_input_col_idx]
+    input_help = input_help.reshape([input_row, max_rank * input_col])
+
+    return input_help, ins_rank
+
+
+def gen_param_help(input, rank_offset, param, max_rank):
+    input_row, input_col = input.shape
+    rank_offset_row, rank_offset_col = rank_offset.shape
+    param_row, param_col = param.shape
+
+    block_matrix_row = input_col * max_rank
+
+    output_param_row = block_matrix_row * input_row
+    output_param_col = param_col
+
+    output_param = np.zeros((output_param_row * output_param_col, ))
+
+    for idx in range(output_param_row * output_param_col):
+        output_col_idx = idx % output_param_col
+        output_row_idx = int(idx / output_param_col)
+        ins_idx = int(output_row_idx / block_matrix_row)
+        start_offset = output_row_idx % block_matrix_row
+        k = int(start_offset / input_col)
+        k_offset = start_offset % input_col
+
+        lower = rank_offset[ins_idx, 0] - 1
+        faster = rank_offset[ins_idx, 2 * k + 1] - 1
+        if lower < 0 or faster < 0:
+            continue
+        start = lower * max_rank + faster
+        ori_idx = start * param_col * input_col + k_offset * param_col + output_col_idx
+        output_param[idx] = param[int(ori_idx / param_col), ori_idx % param_col]
+
+    output_param = output_param.reshape([output_param_row, output_param_col])
+    return output_param
+
+
+def np_rank_attention(input, rank_offset, rank_para, max_rank):
+    input_row, input_col = input.shape
+    rank_offset_row, rank_offset_col = rank_offset.shape
+    rank_para_row, rank_para_col = rank_para.shape
+
+    assert (input_row == rank_offset_row)
+    assert (max_rank == ((rank_offset_col - 1) / 2))
+    assert (rank_para_row == max_rank * max_rank * input_col)
+
+    input_help, ins_rank = gen_input_help(input, rank_offset, max_rank)
+    param_help = gen_param_help(input, rank_offset, rank_para, max_rank)
+    block_matrix_row = input_col * max_rank
+
+    res = np.zeros((input_row, rank_para_col))
+    for ins in range(input_row):
+        res[ins, :] = \
+            np.dot(input_help[ins, :],
+                   param_help[int(block_matrix_row * ins):int(block_matrix_row * (ins + 1)), :])
+    return res, input_help, param_help, ins_rank
+
+
+def gen_rank_offset(pv_nums, max_rank):
+    all_ins_num = 0
+    pv_rank_msg = []
+    for _ in range(pv_nums):
+        ins_pv = np.random.randint(1, max_rank + 2)  # 1~4
+        rank_list = list(range(1, ins_pv + 1))
+        random.shuffle(rank_list)
+        all_ins_num = all_ins_num + ins_pv
+        pv_rank_msg.append(rank_list)
+
+    rank_offset = np.zeros((all_ins_num, max_rank * 2 + 1)).astype("int32")
+    rank_offset.fill(-1)
+    index = 0
+    for pv_number in range(len(pv_rank_msg)):
+        pv_ins = pv_rank_msg[pv_number]
+        ad_num = len(pv_ins)
+        index_start = index
+
+        for j in range(ad_num):
+            rank = -1
+            if pv_ins[j] <= max_rank:
+                rank = pv_ins[j]
+            rank_offset[index, 0] = rank
+
+            if rank > 0:
+                for k in range(ad_num):
+                    fast_rank = -1
+                    if pv_ins[k] <= max_rank:
+                        fast_rank = pv_ins[k]
+                    if fast_rank > 0:
+                        m = fast_rank - 1
+                        rank_offset[index, 2 * m + 1] = pv_ins[k]
+                        rank_offset[index, 2 * m + 2] = index_start + k
+            index = index + 1
+    return all_ins_num, rank_offset
+
+
+class TestRankAttentionOpComplex(OpTest):
+    def config(self):
+        self.pv_num = 100
+        self.x_feat = 10
+        self.y_feat = 15
+        self.max_rank = 3
+        self.dtype = "float64"
+
+    def setUp(self):
+        self.op_type = "rank_attention"
+        self.config()
+        ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
+        input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
+        rank_para_shape = [
+            self.max_rank * self.max_rank * self.x_feat, self.y_feat
+        ]
+        rank_para = np.random.random(rank_para_shape).astype(self.dtype)
+        np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
+            input, np.array(rank_offset), rank_para, self.max_rank)
+        self.inputs = {
+            "X": input,
+            "RankOffset": np.array(rank_offset).astype("int32"),
+            "RankParam": rank_para
+        }
+        self.attrs = {'MaxRank': self.max_rank}
+        self.outputs = {"Out": np_out}
+
+    def test_check_output_gpu(self):
+        if core.is_compiled_with_cuda():
+            self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad_gpu(self):
+        if core.is_compiled_with_cuda():
+            self.check_grad_with_place(core.CUDAPlace(0), ["RankParam"], "Out")
+
+
+class TestRankAttentionOpCpu(OpTest):
+    def config(self):
+        self.pv_num = 100
+        self.x_feat = 10
+        self.y_feat = 15
+        self.max_rank = 3
+        self.dtype = "float64"
+
+    def setUp(self):
+        self.op_type = "rank_attention"
+        self.config()
+        ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
+        input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
+        rank_para_shape = [
+            self.max_rank * self.max_rank * self.x_feat, self.y_feat
+        ]
+        rank_para = np.random.random(rank_para_shape).astype(self.dtype)
+        np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
+            input, np.array(rank_offset), rank_para, self.max_rank)
+        self.inputs = {
+            "X": input,
+            "RankOffset": np.array(rank_offset).astype("int32"),
+            "RankParam": rank_para
+        }
+        self.attrs = {'MaxRank': self.max_rank}
+        self.outputs = {"Out": np_out}
+
+    def test_check_output_cpu(self):
+        try:
+            self.check_output_with_place(place=core.CPUPlace())
+        except:
+            print("rank_attention does not support CPU, skipping test")
+
+    def test_check_grad_cpu(self):
+        try:
+            self.check_grad_with_place(core.CPUPlace(), ["RankParam"], "Out")
+        except:
+            print("rank_attention does not support CPU, skipping test")
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
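Note (reading aid, not part of the patch): the numpy reference helpers above can be exercised without a GPU. A small illustrative driver, assuming test_rank_attention_op.py is on the PYTHONPATH:

    import numpy as np
    from test_rank_attention_op import gen_rank_offset, np_rank_attention

    max_rank, x_feat, y_feat = 3, 10, 15
    ins_num, rank_offset = gen_rank_offset(5, max_rank)
    input = np.random.random((ins_num, x_feat))
    rank_para = np.random.random((max_rank * max_rank * x_feat, y_feat))
    out, input_help, param_help, ins_rank = np_rank_attention(
        input, np.array(rank_offset), rank_para, max_rank)

    # shapes follow the op's contract
    assert out.shape == (ins_num, y_feat)
    assert input_help.shape == (ins_num, max_rank * x_feat)
    assert param_help.shape == (ins_num * max_rank * x_feat, y_feat)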