未验证 提交 cf3ddf24 编写于 作者: J Jiabin Yang 提交者: GitHub

【Eager】fix multiply double grad error (#52870)

* fix multiply double grad error

* fix multiply dy only kenrel
上级 3c44e948
......@@ -529,6 +529,44 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dout, ddy_safe, dx, axis);
} else if ((!dx) && dy) {
DenseTensor tmp_a(ddout->dtype());
tmp_a.Resize(ddout->dims());
dev_ctx.template Alloc<T>(&tmp_a);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, x, ddy_safe, &tmp_a, axis);
auto ddout_t1 = phi::EigenVector<T>::Flatten(tmp_a);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, ddx_safe, y, ddout, axis);
auto ddout_t2 = phi::EigenVector<T>::Flatten(*ddout);
ddout_t2.device(place) = ddout_t2 + ddout_t1;
// NOTE: in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
// be called and can be ignored, the first branch has little effect
// on running speed.
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx,
ddx_safe,
ddy_safe,
dout,
dout,
axis,
nullptr,
dy,
MulGradDX<T>(),
MulGradDY<T>());
} else {
DenseTensor tmp_a(ddout->dtype());
tmp_a.Resize(ddout->dims());
......@@ -554,7 +592,7 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
}
}
} else {
if (dx && dy) {
VLOG(3) << "Calculating here with dx: " << dx << ", dy: " << dy;
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx,
ddx_safe,
......@@ -567,7 +605,6 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
MulGradDX<T>(),
MulGradDY<T>());
}
}
}
template <typename T, typename Context>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册