diff --git a/paddle/fluid/operators/concat_op_mlu.cc b/paddle/fluid/operators/concat_op_mlu.cc index 63f4ec46599bac4a3118d06330fc10f91cff061a..e8f6b2dc86952234a9625810863f8b27b0b641f7 100644 --- a/paddle/fluid/operators/concat_op_mlu.cc +++ b/paddle/fluid/operators/concat_op_mlu.cc @@ -74,6 +74,64 @@ class ConcatMLUKernel : public framework::OpKernel { output_desc.get(), GetBasePtr(out)); } }; + +template +class ConcatGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); + auto outs = + ctx.MultiOutput(framework::GradVarName("X")); + auto axis = ctx.Attr("axis"); + int split_num = ins.size(); + + PADDLE_ENFORCE_NOT_NULL(ins[0], + platform::errors::NotFound( + "The first input tensor is not initalized.")); + + if (ctx.HasInput("AxisTensor")) { + auto* axis_tensor = ctx.Input("AxisTensor"); + axis = GetDataFromTensor(axis_tensor)[0]; + } + + axis = ComputeAxis(static_cast(axis), + static_cast(ins[0]->dims().size())); + PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument( + "concat_grad: axis should be larger than or " + "equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT( + axis, out_grad->dims().size(), + platform::errors::InvalidArgument( + "concat_grad: axis should be less than ins[0]->dims()!" + "But received axis is %d, while ins[0]->dims()" + "size is %d.", + axis, out_grad->dims().size())); + // get output tensor that the name is not kEmptyVarName + std::vector outputs_vec; + std::vector output_descs; + std::vector descs_vec; + for (size_t j = 0; j < outs.size(); ++j) { + if (out_var_names[j] != framework::kEmptyVarName && + outs[j]->numel() != 0UL) { + outs[j]->mutable_data(ctx.GetPlace()); + output_descs.emplace_back(MLUCnnlTensorDesc(*outs[j])); + descs_vec.push_back(output_descs.back().get()); + outputs_vec.push_back(GetBasePtr(outs[j])); + } else { + outputs_vec.push_back(nullptr); + } + } + + MLUCnnlTensorDesc out_grad_desc(*out_grad); + MLUCnnl::Split(ctx, static_cast(split_num), static_cast(axis), + out_grad_desc.get(), GetBasePtr(out_grad), descs_vec.data(), + outputs_vec.data()); + } +}; } // namespace operators } // namespace paddle @@ -84,3 +142,9 @@ REGISTER_OP_MLU_KERNEL(concat, ops::ConcatMLUKernel, ops::ConcatMLUKernel, ops::ConcatMLUKernel, ops::ConcatMLUKernel, ops::ConcatMLUKernel); +REGISTER_OP_MLU_KERNEL(concat_grad, ops::ConcatGradMLUKernel, + ops::ConcatGradMLUKernel, + ops::ConcatGradMLUKernel, + ops::ConcatGradMLUKernel, + ops::ConcatGradMLUKernel, + ops::ConcatGradMLUKernel);