未验证 提交 f2d1f284 编写于 作者: W wangzhen38 提交者: GitHub

[bug fix] fix pow composite (#52645)

* [bug fix] fix pow composite

* [bug fix] for ci
上级 58d5af00
......@@ -97,10 +97,9 @@ REGISTER_OPERATOR(elementwise_pow,
ops::ElementwisePowOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_pow_grad,
ops::ElementwiseOpGrad,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>,
ops::ElementwisePowCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad);
REGISTER_OP_VERSION(elementwise_pow)
.AddCheckpoint(
......
......@@ -398,7 +398,7 @@ void elementwise_pow_grad(const Tensor& x,
// dy = lnx * x^y
auto lnx = log<T>(x);
auto x_pow_y = elementwise_pow<T>(x, y);
auto dy_res = lnx * x_pow_y;
auto dy_res = lnx * x_pow_y * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
......@@ -418,7 +418,7 @@ void elementwise_pow_grad(const Tensor& x,
// dx = y * x^(y-1)
auto tmp_z = y - 1.0;
auto x_pow_z = elementwise_pow<T>(x, tmp_z);
auto dx_res = y * x_pow_z;
auto dx_res = y * x_pow_z * out_grad;
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册