From 1b71a718f185421f4f7b77faa111781133465e82 Mon Sep 17 00:00:00 2001
From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com>
Date: Wed, 18 Aug 2021 16:44:58 +0800
Subject: [PATCH] [NPU] Add square grad (#34889)

* test=develop

* test=develop
---
 paddle/fluid/operators/activation_op_npu.cc   | 41 ++++++++++++++++++++++
 .../tests/unittests/npu/test_square_op_npu.py | 10 +++---
 2 files changed, 45 insertions(+), 6 deletions(-)
 mode change 100755 => 100644 paddle/fluid/operators/activation_op_npu.cc

diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc
old mode 100755
new mode 100644
index d815a3eeb4..5f2925784e
--- a/paddle/fluid/operators/activation_op_npu.cc
+++ b/paddle/fluid/operators/activation_op_npu.cc
@@ -386,6 +386,41 @@ class SquareNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+// NPU kernel for the gradient of square(x): dX = dOut * 2 * X.
+//
+// Reads input "X" and the gradient of "Out", and writes the gradient of "X".
+// The computation is issued as two NPU ops on the device stream: "Muls"
+// scales X by the constant factor, then "Mul" multiplies the result by dOut.
+template <typename DeviceContext, typename T>
+class SquareGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    // d(x^2)/dx = 2x; the constant is handed to "Muls" as its "value" attr.
+    const float factor = 2.0f;
+
+    auto place = ctx.GetPlace();
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    // Step 1: Compute x_muls_factor = factor * x
+    Tensor x_muls_factor(x->type());
+    x_muls_factor.mutable_data<T>(x->dims(), place);
+    const auto& runner_muls_1 =
+        NpuOpRunner("Muls", {*x}, {x_muls_factor}, {{"value", factor}});
+    runner_muls_1.Run(stream);
+
+    // Step 2: Compute dx = dout * factor * x
+    dx->mutable_data<T>(place);
+    const auto& runner_mul_2 =
+        NpuOpRunner("Mul", {*dout, x_muls_factor}, {*dx}, {});
+    runner_mul_2.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class SigmoidNPUKernel : public framework::OpKernel<T> {
  public:
@@ -869,6 +904,12 @@ REGISTER_OP_NPU_KERNEL(
                          paddle::platform::float16>,
     ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, int>);
 
+REGISTER_OP_NPU_KERNEL(
+    square_grad,
+    ops::SquareGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SquareGradNPUKernel<paddle::platform::NPUDeviceContext,
+                             paddle::platform::float16>);
+
REGISTER_OP_NPU_KERNEL( sigmoid, ops::SigmoidNPUKernel, ops::SigmoidNPUKernel