From f2d1f2846ea5951daec871ac65bd813b40106b4d Mon Sep 17 00:00:00 2001 From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com> Date: Mon, 10 Apr 2023 09:39:08 +0800 Subject: [PATCH] [bug fix] fix pow composite (#52645) * [bug fix] fix pow composite * [bug fix] for ci --- paddle/fluid/operators/elementwise/elementwise_pow_op.cc | 5 ++--- .../prim/api/composite_backward/composite_backward_api.h | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc index 0273743c95a..b6389b97476 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc @@ -97,10 +97,9 @@ REGISTER_OPERATOR(elementwise_pow, ops::ElementwisePowOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwisePowOpGradMaker, - ops::ElementwisePowOpGradMaker); -REGISTER_OPERATOR(elementwise_pow_grad, - ops::ElementwiseOpGrad, + ops::ElementwisePowOpGradMaker, ops::ElementwisePowCompositeGradOpMaker); +REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad); REGISTER_OP_VERSION(elementwise_pow) .AddCheckpoint( diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 0ac61cd4a3e..286d3cae8de 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -398,7 +398,7 @@ void elementwise_pow_grad(const Tensor& x, // dy = lnx * x^y auto lnx = log(x); auto x_pow_y = elementwise_pow(x, y); - auto dy_res = lnx * x_pow_y; + auto dy_res = lnx * x_pow_y * out_grad; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); @@ -418,7 +418,7 @@ void elementwise_pow_grad(const Tensor& x, // dx = y * x^(y-1) auto tmp_z = y - 1.0; auto x_pow_z = elementwise_pow(x, tmp_z); - auto dx_res = y * x_pow_z; + auto dx_res = y * x_pow_z * out_grad; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); -- GitLab