未验证 提交 f2d1f284 编写于 作者: W wangzhen38 提交者: GitHub

[bug fix] fix pow composite (#52645)

* [bug fix] fix pow composite

* [bug fix] for ci
上级 58d5af00
...@@ -97,10 +97,9 @@ REGISTER_OPERATOR(elementwise_pow, ...@@ -97,10 +97,9 @@ REGISTER_OPERATOR(elementwise_pow,
ops::ElementwisePowOpMaker, ops::ElementwisePowOpMaker,
ops::ElementwiseOpInferVarType, ops::ElementwiseOpInferVarType,
ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>, ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>); ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>,
REGISTER_OPERATOR(elementwise_pow_grad,
ops::ElementwiseOpGrad,
ops::ElementwisePowCompositeGradOpMaker); ops::ElementwisePowCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad);
REGISTER_OP_VERSION(elementwise_pow) REGISTER_OP_VERSION(elementwise_pow)
.AddCheckpoint( .AddCheckpoint(
......
...@@ -398,7 +398,7 @@ void elementwise_pow_grad(const Tensor& x, ...@@ -398,7 +398,7 @@ void elementwise_pow_grad(const Tensor& x,
// dy = lnx * x^y // dy = lnx * x^y
auto lnx = log<T>(x); auto lnx = log<T>(x);
auto x_pow_y = elementwise_pow<T>(x, y); auto x_pow_y = elementwise_pow<T>(x, y);
auto dy_res = lnx * x_pow_y; auto dy_res = lnx * x_pow_y * out_grad;
if (x.dims() != y.dims()) { if (x.dims() != y.dims()) {
// Maybe need reduce here // Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
...@@ -418,7 +418,7 @@ void elementwise_pow_grad(const Tensor& x, ...@@ -418,7 +418,7 @@ void elementwise_pow_grad(const Tensor& x,
// dx = y * x^(y-1) // dx = y * x^(y-1)
auto tmp_z = y - 1.0; auto tmp_z = y - 1.0;
auto x_pow_z = elementwise_pow<T>(x, tmp_z); auto x_pow_z = elementwise_pow<T>(x, tmp_z);
auto dx_res = y * x_pow_z; auto dx_res = y * x_pow_z * out_grad;
if (y.dims() != x.dims()) { if (y.dims() != x.dims()) {
// Maybe need reduce here // Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册