diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc
index f99b72faba4aeb0a1013645ae429f13dd7176505..04660fb501142505403276292d40f76f872bd43e 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_npu.cc
@@ -105,6 +105,85 @@ class ReduceMaxNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+// NPU kernel for reduce_max_grad: computes the gradient of reduce_max with
+// respect to X.
+//
+// Out and the gradient of Out are broadcast to the shape of X with
+// BroadcastTo. An Equal mask marks where X matches the broadcast maximum, and
+// SelectV2 copies the broadcast gradient to those positions and writes zero
+// everywhere else, so every element tying for the maximum receives the full
+// incoming gradient.
+//
+// Only in_dtype == -1 is accepted; any other value raises InvalidArgument.
+//
+// NOTE(review): the template arguments in this class were reconstructed from
+// usage (T for element data, bool for the mask, int for the attribute) --
+// confirm them against the rest of the file.
+template <typename DeviceContext, typename T>
+class ReduceMaxGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Input<Tensor>("Out");
+    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    int in_dtype = context.Attr<int>("in_dtype");
+
+    PADDLE_ENFORCE_EQ(
+        in_dtype, -1,
+        platform::errors::InvalidArgument(
+            "NPU only support in_dtype == -1 in reduce_max_grad op."));
+
+    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    x_grad->mutable_data<T>(context.GetPlace());
+
+    auto& dev_ctx =
+        context.template device_context<paddle::platform::NPUDeviceContext>();
+    auto place = context.GetPlace();
+    auto stream = dev_ctx.stream();
+
+    // Broadcast Out and its gradient to the shape of X. Each BroadcastTo is
+    // handed a freshly built shape vector: the previous code std::move'd one
+    // vector into the first call and then read it again, a use-after-move.
+    const auto& x_dims = x->dims();
+    Tensor transformed_out(x->type());
+    transformed_out.Resize(x_dims);
+    transformed_out.mutable_data<T>(place);
+    NpuOpRunner r_brd_out;
+    r_brd_out.SetType("BroadcastTo")
+        .AddInput(*out)
+        .AddInput(phi::vectorize(x_dims))
+        .AddOutput(transformed_out)
+        .Run(stream);
+    Tensor transformed_out_grad(x->type());
+    transformed_out_grad.Resize(x_dims);
+    transformed_out_grad.mutable_data<T>(place);
+    NpuOpRunner r_brd_out_grad;
+    r_brd_out_grad.SetType("BroadcastTo")
+        .AddInput(*out_grad)
+        .AddInput(phi::vectorize(x_dims))
+        .AddOutput(transformed_out_grad)
+        .Run(stream);
+
+    // Mask of the positions where X equals the broadcast maximum.
+    Tensor equal_cond;
+    equal_cond.mutable_data<bool>(x_grad->dims(), place);
+    const auto& r_equal =
+        NpuOpRunner("Equal", {*x, transformed_out}, {equal_cond}, {});
+    r_equal.Run(stream);
+
+    // Select the broadcast gradient where the mask holds, zero elsewhere.
+    Tensor t_zero;
+    t_zero.mutable_data<T>(x_grad->dims(), place);
+    FillNpuTensorWithConstant(&t_zero, static_cast<T>(0));
+    // NOTE(review): presumably re-applies the shape in case the fill changed it.
+    t_zero.Resize(x_grad->dims());
+
+    const auto& r_sel = NpuOpRunner(
+        "SelectV2", {equal_cond, transformed_out_grad, t_zero}, {*x_grad}, {});
+    r_sel.Run(stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -115,3 +194,11 @@ REGISTER_OP_NPU_KERNEL(
     ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
     ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int64_t>,
     ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int>);
+// NOTE(review): the element types registered here (and on the context lines
+// above) were reconstructed, not read from the original patch -- confirm
+// them against the forward reduce_max registration before applying.
+REGISTER_OP_NPU_KERNEL(
+    reduce_max_grad, ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, int>);