未验证 提交 e4247120 编写于 作者: L liym27 提交者: GitHub

[NPU] Fix bug: Fix calculation errors of pow grad npu kernel (#31699)

上级 7ec8459c
......@@ -92,7 +92,7 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
Tensor x_power_mul_factor(x->type());
x_power_mul_factor.mutable_data<T>(x->dims(), place);
auto runner_mul_1 =
NpuOpRunner("Mul", {factor_bc_tensor, *x}, {x_power_mul_factor}, {});
NpuOpRunner("Mul", {factor_bc_tensor, x_pow}, {x_power_mul_factor}, {});
runner_mul_1.Run(stream);
// Step 4: Compute dx = dout * factor * x.pow(factor-1)
......@@ -309,20 +309,17 @@ class SquareNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
pow,
ops::PowNPUKernel<paddle::platform::NPUDeviceContext, float>,
pow, ops::PowNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::PowNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
pow_grad,
ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
pow_grad, ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::PowGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
relu,
ops::ReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
relu, ops::ReluNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ReluNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
......@@ -333,8 +330,7 @@ REGISTER_OP_NPU_KERNEL(
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
sqrt,
ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
sqrt, ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::SqrtNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
......@@ -345,21 +341,17 @@ REGISTER_OP_NPU_KERNEL(
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
log,
ops::LogNPUKernel<paddle::platform::NPUDeviceContext, float>,
log, ops::LogNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LogNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
log_grad,
ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
log_grad, ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LogGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
tanh,
ops::TanhNPUKernel<paddle::platform::NPUDeviceContext, float>,
tanh, ops::TanhNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::TanhNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
......@@ -370,7 +362,6 @@ REGISTER_OP_NPU_KERNEL(
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
square,
ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, float>,
square, ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::SquareNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册