diff --git a/paddle/fluid/operators/rank_attention_op.cc b/paddle/fluid/operators/rank_attention_op.cc
index 76a04014e4e11242915c4dd26f45ea83899abc05..460df0333f841953b4f2e506f67c04cb570dc927 100644
--- a/paddle/fluid/operators/rank_attention_op.cc
+++ b/paddle/fluid/operators/rank_attention_op.cc
@@ -34,6 +34,14 @@ class RankAttentionOp : public framework::OperatorWithKernel {
         ctx->HasInput("RankParam"), true,
         platform::errors::InvalidArgument(
             "Input(RankParam) of RankAttentionOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("InsRank"), true,
+        platform::errors::InvalidArgument(
+            "Output(InsRank) of RankAttentionOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("InputHelp"), true,
+        platform::errors::InvalidArgument(
+            "Output(InputHelp) of RankAttentionOp should not be null."));
     PADDLE_ENFORCE_EQ(
         ctx->HasOutput("Out"), true,
         platform::errors::InvalidArgument(
@@ -45,12 +53,16 @@ class RankAttentionOp : public framework::OperatorWithKernel {
     auto param_dims = ctx->GetInputDim("RankParam");
     auto para_col = param_dims[1];
     auto rank_offset_dims = ctx->GetInputDim("RankOffset");
+    auto x_fea_dim = x_dims[1];
+    auto block_matrix_row = max_rank * x_fea_dim;
 
     PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
                       platform::errors::InvalidArgument(
                           "Input(RankOffset) has wrong columns."));
 
     ctx->SetOutputDim("Out", {ins_num, para_col});
+    ctx->SetOutputDim("InputHelp", {ins_num, block_matrix_row});
+    ctx->SetOutputDim("InsRank", {ins_num, 1});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
@@ -77,6 +89,12 @@ class RankAttentionGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->HasInput("RankOffset"), true,
                       platform::errors::InvalidArgument(
                           "Input(RankOffset) should not be null"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("InputHelp"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(InputHelp) should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("InsRank"), true,
+        platform::errors::InvalidArgument("Input(InsRank) should not be null"));
 
     ctx->SetOutputDim(framework::GradVarName("RankParam"),
                       ctx->GetInputDim("RankParam"));
@@ -99,9 +117,15 @@ class RankAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) Input tensor of rank_attention_Op operator.");
     AddInput("RankParam",
              "(Tensor) Input tensor of rank_attention_Op operator.");
+    AddOutput("InputHelp", "Output tensor of rank_attention_Op operator.")
+        .AsDispensable();
     AddOutput("Out", "Output tensor of rank_attention_Op operator.");
+    AddOutput("InsRank", "Output tensor of rank_attention_Op operator.")
+        .AsDispensable();
     AddAttr<int>("MaxRank", "(int, default 3) max rank of rank_attention_Op")
         .SetDefault(3);
+    AddAttr<int>("MaxSize", "(int, default 0) max size of rank_attention_Op")
+        .SetDefault(0);
     AddComment(R"DOC(
 RankAttention Operator.
 This Op can calculate rank attention between input and rank_param,
@@ -123,7 +147,9 @@ class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("X", this->Input("X"));
     op->SetInput("RankOffset", this->Input("RankOffset"));
     op->SetInput("RankParam", this->Input("RankParam"));
+    op->SetInput("InputHelp", this->Output("InputHelp"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("InsRank", this->Output("InsRank"));
 
     op->SetOutput(framework::GradVarName("RankParam"),
                   this->InputGrad("RankParam"));
@@ -131,7 +157,8 @@ class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
 DECLARE_NO_NEED_BUFFER_VARS_INFERER(
-    RankAttentionGradOpNoNeedBufferVarsInference, "RankParam");
+    RankAttentionGradOpNoNeedBufferVarsInference, "X", "RankOffset",
+    "RankParam");
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/rank_attention_op.cu b/paddle/fluid/operators/rank_attention_op.cu
index 08e2a9ccca4dcb701e14117c38d229bf9d3bfc72..6c242e156a5b4becce6d686e2cccc18353caa4f7 100644
--- a/paddle/fluid/operators/rank_attention_op.cu
+++ b/paddle/fluid/operators/rank_attention_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <cublas.h>
+#include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/rank_attention.cu.h"
@@ -32,7 +33,10 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     auto *X = ctx.Input<Tensor>("X");
     auto *rank_offset = ctx.Input<Tensor>("RankOffset");
     auto *param = ctx.Input<Tensor>("RankParam");
+    auto *input_help = ctx.Output<Tensor>("InputHelp");
+    auto *ins_rank = ctx.Output<Tensor>("InsRank");
     int max_rank = ctx.Attr<int>("MaxRank");
+    int64_t max_size = ctx.Attr<int>("MaxSize");
     auto *Out = ctx.Output<Tensor>("Out");
 
     // check dims
@@ -56,37 +60,42 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     int block_matrix_row = max_rank * x_fea_dim;
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    auto stream = ctx.cuda_device_context().stream();
-    int device_id = platform::GetCurrentDeviceId();
-
-    T *param_help_data;
-    auto param_help_size = ins_num * block_matrix_row * para_col * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_help_data),
-                                 param_help_size, device_id);
-    platform::GpuMemsetAsync(param_help_data, 0, param_help_size, stream);
-
-    T *input_help_data;
-    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
-                                 input_help_size, device_id);
-    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
-
-    T *ins_rank_data;
-    auto ins_rank_size = ins_num * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
-                                 ins_rank_size, device_id);
-    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
+    int max_ins = std::max(ins_num, max_size);
+
+    Tensor param_help;
+    param_help = ctx.AllocateTmpTensor<T, DeviceContext>(
+        {max_ins * block_matrix_row, para_col}, dev_ctx);
+    param_help.mutable_data<T>(ctx.GetPlace());
+
+    input_help->Resize({max_ins, block_matrix_row});
+    ins_rank->Resize({max_ins, 1});
+    input_help->mutable_data<T>(ctx.GetPlace());
+    ins_rank->mutable_data<T>(ctx.GetPlace());
 
     Out->mutable_data<T>(ctx.GetPlace());
 
     // initialize
+    auto param_help_eigen = framework::EigenVector<T>::Flatten(param_help);
+    auto input_help_eigen = framework::EigenVector<T>::Flatten(*input_help);
+    auto ins_rank_eigen = framework::EigenVector<T>::Flatten(*ins_rank);
     auto out_eigen = framework::EigenVector<T>::Flatten(*Out);
+
     auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                        .eigen_device();
+
+    param_help_eigen.device(place) =
+        param_help_eigen.constant(static_cast<T>(0));
+    input_help_eigen.device(place) =
+        input_help_eigen.constant(static_cast<T>(0));
+    ins_rank_eigen.device(place) = ins_rank_eigen.constant(static_cast<T>(-1));
     out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
 
     // get data ptr
+    T *input_help_data = input_help->data<T>();
+    T *param_help_data = param_help.data<T>();
+    T *ins_rank_data = ins_rank->data<T>();
     T *out_data = Out->data<T>();
+
     expand_rank_attention_input<T>(
         ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
         input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
@@ -110,10 +119,6 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha,
                      input_help_data, param_help_data, beta, out_data, ins_num,
                      strideA, strideB);
-
-    platform::RecordedCudaFree(param_help_data, param_help_size, device_id);
-    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
-    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
   }
 };
 
@@ -121,10 +126,13 @@ template <typename DeviceContext, typename T>
 class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
-    auto *param = ctx.Input<Tensor>("RankParam");
+    auto *X = ctx.Input<Tensor>("X");                      // data not used
+    auto *rank_offset = ctx.Input<Tensor>("RankOffset");   // data not used
+    auto *param = ctx.Input<Tensor>("RankParam");          // data not used
+    auto *input_help = ctx.Input<Tensor>("InputHelp");
+    auto *ins_rank = ctx.Input<Tensor>("InsRank");
     auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int64_t max_size = ctx.Attr<int>("MaxSize");
     auto *drank_para =
         ctx.Output<Tensor>(framework::GradVarName("RankParam"));
 
@@ -142,38 +150,26 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
     auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                        .eigen_device();
+    int max_ins = std::max(ins_num, max_size);
 
     // initialize out grad
     drank_para->mutable_data<T>(ctx.GetPlace());
     auto drank_para_eigen = framework::EigenVector<T>::Flatten(*drank_para);
     drank_para_eigen.device(place) =
         drank_para_eigen.constant(static_cast<T>(0));
 
-    auto stream = ctx.cuda_device_context().stream();
-    int device_id = platform::GetCurrentDeviceId();
-
-    T *param_grad_data;
-    auto param_grad_size = ins_num * block_matrix_row * para_col * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_grad_data),
-                                 param_grad_size, device_id);
-    platform::GpuMemsetAsync(param_grad_data, 0, param_grad_size, stream);
-
-    T *input_help_data;
-    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
-                                 input_help_size, device_id);
-    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
-
-    T *ins_rank_data;
-    auto ins_rank_size = ins_num * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
-                                 ins_rank_size, device_id);
-    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
-
-    // expand input
-    expand_rank_attention_input<T>(
-        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
-        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
-        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
+    // copy data
+    Tensor param_grad;
+    param_grad = ctx.AllocateTmpTensor<T, DeviceContext>(
+        {max_ins * block_matrix_row, para_col}, dev_ctx);
+    param_grad.mutable_data<T>(ctx.GetPlace());
+    // initialize
+    auto param_grad_eigen = framework::EigenVector<T>::Flatten(param_grad);
+    param_grad_eigen.device(place) =
+        param_grad_eigen.constant(static_cast<T>(0));
+    // get data ptr
+    const T *input_help_data = input_help->data<T>();
+    const T *ins_rank_data = ins_rank->data<T>();
+    T *param_grad_data = param_grad.data<T>();
 
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
     T alpha = 1;
@@ -184,20 +180,14 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
     CBLAS_TRANSPOSE transB = CblasNoTrans;
     int64_t strideA = block_matrix_row;
     int64_t strideB = para_col;
-
     blas.BatchedGEMM(transA, transB, block_matrix_row, para_col, 1, alpha,
                      input_help_data, dout->data<T>(), beta, param_grad_data,
                      ins_num, strideA, strideB);
-
     // merge param_grad to get drank_para
     merge_rank_attention_param_grad<T>(
         ctx.cuda_device_context().stream(), param_grad_data,
         ins_num * block_matrix_row, para_col, drank_para->data<T>(), para_row,
         para_col, ins_rank_data, ins_num, max_rank, x_fea_dim);
-
-    platform::RecordedCudaFree(param_grad_data, param_grad_size, device_id);
-    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
-    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
   }
 };
 
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index c7c4ea8ec1cf8729a7017ee3a634590c5792cdfa..3b6372c000cbd25a8d6425a11e8762b0471c06e2 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1236,7 +1236,8 @@ def rank_attention(input,
                    rank_offset,
                    rank_param_shape,
                    rank_param_attr,
-                   max_rank=3):
+                   max_rank=3,
+                   max_size=0):
     """
     **Rank Attention layer**
     This Op can calculate rank attention between input and rank_param, and
@@ -1266,7 +1267,8 @@ def rank_attention(input,
                     name="ubm_rank_param.w_0",
                     initializer=
                     fluid.initializer.Xavier(uniform=False)),
-            max_rank=3)
+            max_rank=3,
+            max_size=0)
     """
     helper = LayerHelper('rank_attention', **locals())
     dtype = helper.input_dtype(input_param_name='input')
@@ -1278,6 +1280,8 @@ def rank_attention(input,
     rank_param.stop_gradient = False
 
     output = helper.create_variable_for_type_inference(dtype)
+    input_help = helper.create_variable_for_type_inference(
+        dtype=dtype, stop_gradient=True)
     ins_rank = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
 
@@ -1288,7 +1292,9 @@ def rank_attention(input,
             "RankOffset": rank_offset,
             "RankParam": rank_param
         },
-        outputs={"Out": output},
-        attrs={"MaxRank": max_rank})
-
+        outputs={"Out": output,
+                 "InputHelp": input_help,
+                 "InsRank": ins_rank},
+        attrs={"MaxRank": max_rank,
+               "MaxSize": max_size})
     return output
diff --git a/python/paddle/fluid/tests/unittests/test_rank_attention_op.py b/python/paddle/fluid/tests/unittests/test_rank_attention_op.py
index f9b5afb22d5f843512cbdd70c29d9656f197b89d..64d564c223f8d6e462d43dda2270d310d561ccad 100644
--- a/python/paddle/fluid/tests/unittests/test_rank_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rank_attention_op.py
@@ -22,10 +22,11 @@ from op_test import OpTest, skip_check_grad_ci
 import paddle.fluid.core as core
 
 
-def gen_input_help(input, rank_offset, max_rank):
+def gen_input_help(input, rank_offset, max_rank, max_size):
     input_row, input_col = input.shape
-    input_help = np.zeros((input_row * max_rank * input_col, ))
-    ins_rank = np.zeros((input_row, 1))
+    max_ins = np.max((max_size, input_row))
+    input_help = np.zeros((max_ins * max_rank * input_col))
+    ins_rank = np.zeros((max_ins, 1))
     ins_rank.fill(-1)
     output_col = max_rank * input_col
@@ -46,7 +47,7 @@ def gen_input_help(input, rank_offset, max_rank, max_size):
         rank_input_col_idx = output_col_idx % input_col
         index = rank_offset[output_row_idx, 2 * k + 2]
         input_help[idx] = input[index, rank_input_col_idx]
-    input_help = input_help.reshape([input_row, max_rank * input_col])
+    input_help = input_help.reshape([max_ins, max_rank * input_col])
     return input_help, ins_rank
 
 
@@ -83,7 +84,7 @@ def gen_param_help(input, rank_offset, param, max_rank):
     return output_param
 
 
-def np_rank_attention(input, rank_offset, rank_para, max_rank):
+def np_rank_attention(input, rank_offset, rank_para, max_rank, max_size):
     input_row, input_col = input.shape
     rank_offset_row, rank_offset_col = rank_offset.shape
     rank_para_row, rank_para_col = rank_para.shape
@@ -92,7 +93,8 @@ def np_rank_attention(input, rank_offset, rank_para, max_rank):
     assert (max_rank == ((rank_offset_col - 1) / 2))
     assert (rank_para_row == max_rank * max_rank * input_col)
 
-    input_help, ins_rank = gen_input_help(input, rank_offset, max_rank)
+    input_help, ins_rank = gen_input_help(input, rank_offset, max_rank,
+                                          max_size)
     param_help = gen_param_help(input, rank_offset, rank_para, max_rank)
     block_matrix_row = input_col * max_rank
 
@@ -159,14 +161,19 @@ class TestRankAttentionOpComplex(OpTest):
         ]
         rank_para = np.random.random(rank_para_shape).astype(self.dtype)
         np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
-            input, np.array(rank_offset), rank_para, self.max_rank)
+            input,
+            np.array(rank_offset), rank_para, self.max_rank, self.pv_num * 7)
         self.inputs = {
             "X": input,
             "RankOffset": np.array(rank_offset).astype("int32"),
             "RankParam": rank_para
         }
-        self.attrs = {'MaxRank': self.max_rank}
-        self.outputs = {"Out": np_out}
+        self.attrs = {'MaxRank': self.max_rank, 'MaxSize': self.pv_num * 7}
+        self.outputs = {
+            "Out": np_out,
+            "InputHelp": np_input_help,
+            "InsRank": np_ins_rank
+        }
 
     def test_check_output_gpu(self):
         if core.is_compiled_with_cuda():
@@ -195,14 +202,19 @@ class TestRankAttentionOpCpu(OpTest):
         ]
         rank_para = np.random.random(rank_para_shape).astype(self.dtype)
         np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
-            input, np.array(rank_offset), rank_para, self.max_rank)
+            input,
+            np.array(rank_offset), rank_para, self.max_rank, self.pv_num * 7)
         self.inputs = {
             "X": input,
             "RankOffset": np.array(rank_offset).astype("int32"),
             "RankParam": rank_para
         }
-        self.attrs = {'MaxRank': self.max_rank}
-        self.outputs = {"Out": np_out}
+        self.attrs = {'MaxRank': self.max_rank, 'MaxSize': self.pv_num * 7}
+        self.outputs = {
+            "Out": np_out,
+            "InputHelp": np_input_help,
+            "InsRank": np_ins_rank
+        }
 
     def test_check_output_cpu(self):
         try:
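
A minimal usage sketch of the updated fluid.contrib.layers.rank_attention interface follows. The shapes, the "ubm_rank_param.w_0" name, and the pv_num * 7 sizing mirror the unit test and docstring touched above; the pv_num value and rank_param column count are illustrative assumptions, not requirements of the patch.

# Usage sketch (illustrative, not part of the patch): call the contrib layer
# with the new max_size argument.
import paddle.fluid as fluid

max_rank = 3   # matches the MaxRank attribute default
pv_num = 2     # hypothetical instance-group count, as in the unit test
x_fea_dim = 2  # per-instance feature width

input = fluid.data(name="input", shape=[None, x_fea_dim], dtype="float32")
# RankOffset has 2 * max_rank + 1 columns, matching the InferShape check.
rank_offset = fluid.data(
    name="rank_offset", shape=[None, 2 * max_rank + 1], dtype="int32")

out = fluid.contrib.layers.rank_attention(
    input=input,
    rank_offset=rank_offset,
    # rank_param rows = max_rank * max_rank * x_fea_dim, as asserted in the test.
    rank_param_shape=[max_rank * max_rank * x_fea_dim, 3],
    rank_param_attr=fluid.ParamAttr(
        learning_rate=1.0,
        name="ubm_rank_param.w_0",
        initializer=fluid.initializer.Xavier(uniform=False)),
    max_rank=max_rank,
    # Reserve buffer space for up to pv_num * 7 instances, mirroring MaxSize
    # in the unit test.
    max_size=pv_num * 7)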