From e4247120732e665af866468b5a0971b7bb99f25d Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Wed, 17 Mar 2021 21:41:55 +0800
Subject: [PATCH] [NPU] Fix bug: Fix calculation errors of pow grad npu kernel (#31699)

---
 paddle/fluid/operators/activation_op_npu.cc | 31 ++++++++-------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc
index 2f899825c5c..1a843bfc991 100644
--- a/paddle/fluid/operators/activation_op_npu.cc
+++ b/paddle/fluid/operators/activation_op_npu.cc
@@ -92,7 +92,7 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
     Tensor x_power_mul_factor(x->type());
     x_power_mul_factor.mutable_data<T>(x->dims(), place);
     auto runner_mul_1 =
-        NpuOpRunner("Mul", {factor_bc_tensor, *x}, {x_power_mul_factor}, {});
+        NpuOpRunner("Mul", {factor_bc_tensor, x_pow}, {x_power_mul_factor}, {});
     runner_mul_1.Run(stream);
 
     // Step 4: Compute dx = dout * factor * x.pow(factor-1)
@@ -309,20 +309,17 @@ class SquareNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(
-    pow,
-    ops::PowNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    pow, ops::PowNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::PowNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    pow_grad,
-    ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    pow_grad, ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    relu,
-    ops::ReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    relu, ops::ReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::ReluNPUKernel<paddle::platform::NPUDeviceContext,
                        paddle::platform::float16>);
 
@@ -333,33 +330,28 @@ REGISTER_OP_NPU_KERNEL(
                            paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    sqrt,
-    ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    sqrt, ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext,
-                      paddle::platform::float16>);
+                       paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
     sqrt_grad,
     ops::SqrtGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SqrtGradNPUKernel<paddle::platform::NPUDeviceContext,
-                          paddle::platform::float16>);
+                           paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    log,
-    ops::LogNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    log, ops::LogNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::LogNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    log_grad,
-    ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    log_grad, ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);
 
-
 REGISTER_OP_NPU_KERNEL(
-    tanh,
-    ops::TanhNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    tanh, ops::TanhNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::TanhNPUKernel<paddle::platform::NPUDeviceContext,
                        paddle::platform::float16>);
 
@@ -370,7 +362,6 @@ REGISTER_OP_NPU_KERNEL(
                         paddle::platform::float16>);
 
 REGISTER_OP_NPU_KERNEL(
-    square,
-    ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    square, ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::SquareNPUKernel<paddle::platform::NPUDeviceContext,
                          paddle::platform::float16>);
-- 
GitLab
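
Why the change matters: per the Step 4 comment in the first hunk, the pow gradient is dx = dout * factor * x.pow(factor - 1). Multiplying factor_bc_tensor by *x (as before the fix) yields dout * factor * x instead, which only coincides with the correct gradient when factor == 2. Assuming x_pow holds x.pow(factor - 1) computed in an earlier step of the kernel (that step is not shown in this hunk), the fixed Mul restores the intended formula. A minimal scalar sketch of the math (plain C++, hypothetical sample values, not Paddle code), comparing the two formulas against a finite-difference reference:

// Scalar sketch of the pow backward formula; values are illustrative only.
// Correct gradient: dx = dout * factor * x^(factor - 1).
// The pre-fix kernel effectively computed dx = dout * factor * x.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 3.0;       // sample input
  const double factor = 3.0;  // pow exponent
  const double dout = 1.0;    // upstream gradient

  // What the fixed kernel computes: dout * factor * x^(factor - 1).
  const double dx_fixed = dout * factor * std::pow(x, factor - 1.0);

  // What the buggy kernel computed: dout * factor * x.
  const double dx_buggy = dout * factor * x;

  // Central finite-difference reference for d/dx x^factor at x.
  const double eps = 1e-6;
  const double dx_numeric =
      (std::pow(x + eps, factor) - std::pow(x - eps, factor)) / (2.0 * eps);

  std::printf("fixed   dx = %.6f\n", dx_fixed);    // ~27.000000
  std::printf("buggy   dx = %.6f\n", dx_buggy);    // 9.000000
  std::printf("numeric dx = %.6f\n", dx_numeric);  // ~27.000000
  return 0;
}

With these sample values the fixed and numeric gradients agree (about 27), while the pre-fix formula gives 9, matching the calculation error the patch title describes.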