From 1ef1cace6bfc208519bac8f31d864bc6affc93fb Mon Sep 17 00:00:00 2001
From: Lux et Veritas <1004239791@qq.com>
Date: Thu, 29 Sep 2022 11:19:18 +0800
Subject: [PATCH] [MLU] add mlu kernel for add_reduce_max_grad (#45651)

Co-authored-by: liupeiyu
---
 .../operators/reduce_ops/reduce_max_op_mlu.cc | 110 ++++++++++++++++++
 1 file changed, 110 insertions(+)

diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc
index 1ece3bdf72..a23931c0aa 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc
@@ -92,6 +92,112 @@ class ReduceMaxMLUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class ReduceMaxGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Input<Tensor>("Out");
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto reduce_dims = context.Attr<std::vector<int>>("dim");
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    int in_dtype = context.Attr<int>("in_dtype");
+
+    PADDLE_ENFORCE_EQ(
+        in_dtype == -1,
+        true,
+        platform::errors::InvalidArgument(
+            "MLU only support in_dtype == -1 in reduce_max_grad op."));
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    x_grad->mutable_data<T>(context.GetPlace());
+
+    auto place = context.GetPlace();
+
+    // broadcast
+    auto x_dims_vec = phi::vectorize(x->dims());
+    if (reduce_all) {
+      reduce_dims.clear();
+      for (size_t d = 0; d < x_dims_vec.size(); ++d) {
+        reduce_dims.push_back(static_cast<int>(d));
+      }
+    }
+
+    Tensor tmp_out, tmp_out_grad;
+    auto tmp_out_dims_vec = x_dims_vec;
+    for (auto d : reduce_dims) {
+      if (d < 0) {
+        d += x_dims_vec.size();
+      }
+      tmp_out_dims_vec[d] = 1;
+    }
+
+    tmp_out.ShareDataWith(*out);
+    tmp_out.Resize(phi::make_ddim(tmp_out_dims_vec));
+    tmp_out_grad.ShareDataWith(*out_grad);
+    tmp_out_grad.Resize(phi::make_ddim(tmp_out_dims_vec));
+
+    Tensor transformed_out(x->type());
+    transformed_out.Resize(phi::make_ddim(x_dims_vec));
+    transformed_out.mutable_data<T>(place);
+
+    MLUCnnlTensorDesc tmp_out_desc(tmp_out);
+    MLUCnnlTensorDesc transformed_out_desc(transformed_out);
+
+    MLUCnnl::BroadcastTo(context,
+                         tmp_out_desc.get(),
+                         GetBasePtr(&tmp_out),
+                         transformed_out_desc.get(),
+                         GetBasePtr(&transformed_out));
+
+    Tensor transformed_out_grad(x->type());
+    transformed_out_grad.Resize(phi::make_ddim(x_dims_vec));
+    transformed_out_grad.mutable_data<T>(place);
+    MLUCnnlTensorDesc tmp_out_grad_desc(tmp_out_grad);
+    MLUCnnlTensorDesc transformed_out_grad_desc(transformed_out_grad);
+
+    MLUCnnl::BroadcastTo(context,
+                         tmp_out_grad_desc.get(),
+                         GetBasePtr(&tmp_out_grad),
+                         transformed_out_grad_desc.get(),
+                         GetBasePtr(&transformed_out_grad));
+
+    // compare
+    Tensor equal_cond;
+    equal_cond.mutable_data<bool>(x_grad->dims(), place);
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc equal_cond_desc(equal_cond);
+
+    MLUCnnl::Logic(context,
+                   CNNL_LOGIC_OP_EQ,
+                   x_desc.get(),
+                   GetBasePtr(x),
+                   transformed_out_desc.get(),
+                   GetBasePtr(&transformed_out),
+                   equal_cond_desc.get(),
+                   GetBasePtr(&equal_cond));
+
+    // select
+    Tensor t_zero;
+    t_zero.mutable_data<T>(x_grad->dims(), place);
+    FillMLUTensorWithHostValue<T>(context, static_cast<T>(0), &t_zero);
+    t_zero.Resize(x_grad->dims());
+
+    MLUCnnlTensorDesc t_zero_desc(t_zero);
+    MLUCnnlTensorDesc x_grad_desc(*x_grad);
+
+    MLUCnnl::Select(context,
+                    equal_cond_desc.get(),
+                    GetBasePtr(&equal_cond),
+                    transformed_out_grad_desc.get(),
+                    GetBasePtr(&transformed_out_grad),
+                    t_zero_desc.get(),
+                    GetBasePtr(&t_zero),
+                    x_grad_desc.get(),
+                    GetBasePtr(x_grad));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -102,3 +208,7 @@ REGISTER_OP_MLU_KERNEL(reduce_max,
                        ops::ReduceMaxMLUKernel<float>,
                        ops::ReduceMaxMLUKernel<plat::float16>,
                        ops::ReduceMaxMLUKernel<int>);
+REGISTER_OP_MLU_KERNEL(reduce_max_grad,
+                       ops::ReduceMaxGradMLUKernel<float>,
+                       ops::ReduceMaxGradMLUKernel<plat::float16>,
+                       ops::ReduceMaxGradMLUKernel<int>);
-- 
GitLab
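For readers who want to check the numerics without an MLU device, the following is a minimal, self-contained CPU sketch of the gradient rule the patch implements: broadcast the reduced max and its upstream gradient back to the input shape, compare for equality, and select the upstream gradient where the input equals the max, zero elsewhere. It is illustrative only, uses none of the Paddle or CNNL APIs, and the shapes and values are invented for the example.

    // Reference semantics of reduce_max_grad for a 2-D input reduced over axis 1.
    #include <cstdio>
    #include <vector>

    int main() {
      // x: 2x3 input; forward pass: out[i] = max over j of x[i][j].
      std::vector<std::vector<float>> x = {{1.f, 5.f, 3.f}, {4.f, 2.f, 4.f}};
      std::vector<float> out = {5.f, 4.f};       // row-wise max from the forward pass
      std::vector<float> out_grad = {1.f, 1.f};  // upstream gradient dL/d(out)

      // Backward: broadcast out/out_grad to x's shape, then route the gradient
      // only to positions equal to the max (the equal + select steps above).
      std::vector<std::vector<float>> x_grad(2, std::vector<float>(3, 0.f));
      for (size_t i = 0; i < x.size(); ++i) {
        for (size_t j = 0; j < x[i].size(); ++j) {
          x_grad[i][j] = (x[i][j] == out[i]) ? out_grad[i] : 0.f;
        }
      }

      // Prints: 0.0 1.0 0.0 / 1.0 0.0 1.0 -- tied maxima each receive the full
      // upstream gradient, matching the kernel's Logic(EQ) + Select behaviour.
      for (const auto& row : x_grad) {
        for (float v : row) std::printf("%.1f ", v);
        std::printf("\n");
      }
      return 0;
    }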