未验证 提交 78d5cf7b 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] add reduce_max_grad op (#42672)

上级 c714926d
...@@ -105,6 +105,68 @@ class ReduceMaxNPUKernel : public framework::OpKernel<T> { ...@@ -105,6 +105,68 @@ class ReduceMaxNPUKernel : public framework::OpKernel<T> {
} }
}; };
// NPU kernel computing the gradient of reduce_max with respect to X.
//
// Inputs : "X", "Out" (the forward result) and Out@GRAD.
// Output : X@GRAD, same shape as X.
// Attrs  : "dim" (reduced axes, may be negative), "reduce_all",
//          "in_dtype" (only -1 is supported on NPU).
//
// Algorithm: Out and Out@GRAD are broadcast back to the shape of X; every
// position where X equals the broadcast maximum receives the broadcast
// Out@GRAD value, all other positions receive zero (ties all receive the
// gradient, since the selection mask is a plain element-wise Equal).
template <typename DeviceContext, typename T>
class ReduceMaxGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Input<Tensor>("Out");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    const auto& reduce_dims = context.Attr<std::vector<int>>("dim");
    const bool reduce_all = context.Attr<bool>("reduce_all");
    int in_dtype = context.Attr<int>("in_dtype");

    PADDLE_ENFORCE_EQ(
        in_dtype == -1, true,
        platform::errors::InvalidArgument(
            "NPU only support in_dtype == -1 in reduce_max_grad op."));

    auto place = context.GetPlace();
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    x_grad->mutable_data<T>(place);

    auto& dev_ctx =
        context.template device_context<paddle::platform::NPUDeviceContext>();
    auto stream = dev_ctx.stream();

    const auto& x_dims = x->dims();

    // Shape of the reduction result as if the reduced axes had been kept
    // with extent 1. Out / Out@GRAD may have had those axes dropped, in which
    // case broadcasting them directly to x_dims would align the wrong axes
    // (broadcast matches trailing dimensions).
    auto keepdim_shape = phi::vectorize(x_dims);
    const int rank = static_cast<int>(keepdim_shape.size());
    if (reduce_all) {
      for (auto& extent : keepdim_shape) {
        extent = 1;
      }
    } else {
      for (int axis : reduce_dims) {
        if (axis < 0) {
          axis += rank;  // normalize negative axis
        }
        keepdim_shape[axis] = 1;
      }
    }
    const auto keepdim_ddim = phi::make_ddim(keepdim_shape);

    // Views over the same buffers, only the shape metadata differs.
    Tensor out_keepdim;
    out_keepdim.ShareDataWith(*out);
    out_keepdim.Resize(keepdim_ddim);
    Tensor out_grad_keepdim;
    out_grad_keepdim.ShareDataWith(*out_grad);
    out_grad_keepdim.Resize(keepdim_ddim);

    // broadcast Out to the shape of X
    // NOTE: AddInput takes the target shape as an rvalue vector; a fresh
    // temporary is passed each time so no moved-from vector is ever reused.
    Tensor transformed_out(x->type());
    transformed_out.Resize(x_dims);
    transformed_out.mutable_data<T>(place);
    NpuOpRunner r_brd_out;
    r_brd_out.SetType("BroadcastTo")
        .AddInput(out_keepdim)
        .AddInput(phi::vectorize(x_dims))
        .AddOutput(transformed_out)
        .Run(stream);

    // broadcast Out@GRAD to the shape of X
    Tensor transformed_out_grad(x->type());
    transformed_out_grad.Resize(x_dims);
    transformed_out_grad.mutable_data<T>(place);
    NpuOpRunner r_brd_out_grad;
    r_brd_out_grad.SetType("BroadcastTo")
        .AddInput(out_grad_keepdim)
        .AddInput(phi::vectorize(x_dims))
        .AddOutput(transformed_out_grad)
        .Run(stream);

    // compare: mask of positions where X attains the reduced maximum
    Tensor equal_cond;
    equal_cond.mutable_data<bool>(x_grad->dims(), place);
    const auto& r_equal =
        NpuOpRunner("Equal", {*x, transformed_out}, {equal_cond}, {});
    r_equal.Run(stream);

    // select: broadcast gradient where the mask holds, zero elsewhere
    Tensor t_zero;
    t_zero.mutable_data<T>(x_grad->dims(), place);
    FillNpuTensorWithConstant(&t_zero, static_cast<T>(0));
    const auto& r_sel = NpuOpRunner(
        "SelectV2", {equal_cond, transformed_out_grad, t_zero}, {*x_grad}, {});
    r_sel.Run(stream);
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -115,3 +177,8 @@ REGISTER_OP_NPU_KERNEL( ...@@ -115,3 +177,8 @@ REGISTER_OP_NPU_KERNEL(
ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, plat::float16>, ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int64_t>, ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int>); ops::ReduceMaxNPUKernel<plat::NPUDeviceContext, int>);
// Register the NPU kernels for the reduce_max_grad op, instantiating
// ReduceMaxGradNPUKernel for float, float16, int64_t and int element types.
REGISTER_OP_NPU_KERNEL(
reduce_max_grad, ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, float>,
ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::ReduceMaxGradNPUKernel<plat::NPUDeviceContext, int>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册