未验证 提交 268156f8 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick] Fix the calculation of y_grad in divide_backward (#53672)

上级 3a53b77e
......@@ -116,7 +116,7 @@ struct DivGradXYFunctor {
// dy = - dout * out / y
phi::Array<OutT, 2> outs;
outs[0] = a / c;
outs[1] = -a * b / c;
outs[1] = -a * ((b / c) / c);
return outs;
}
};
......@@ -129,7 +129,7 @@ struct DivGradXYFunctor<ComplexType<InT>, ComplexType<OutT>> {
const ComplexType<InT> c) {
phi::Array<ComplexType<OutT>, 2> outs;
ComplexType<InT> c_conj(c.real, -c.imag);
ComplexType<InT> out_div_c_conj((b / c).real, -(b / c).imag);
ComplexType<InT> out_div_c_conj(((b / c) / c).real, -((b / c) / c).imag);
outs[0] = a / c_conj;
outs[1] = -a * out_div_c_conj;
return outs;
......@@ -156,7 +156,7 @@ struct DivGradXFunctor<ComplexType<T>> {
template <typename T>
struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T a, const T b, const T c) const {
return -a * b / c;
return -a * ((b / c) / c);
}
};
......@@ -166,7 +166,7 @@ struct DivGradYFunctor<ComplexType<T>> {
inline HOSTDEVICE ComplexType<T> operator()(const ComplexType<T> a,
const ComplexType<T> b,
const ComplexType<T> c) const {
ComplexType<T> out_div_c_conj((b / c).real, -(b / c).imag);
ComplexType<T> out_div_c_conj(((b / c) / c).real, -((b / c) / c).imag);
return -a * out_div_c_conj;
}
};
......
......@@ -36,7 +36,7 @@ void DivideGradKernel(const Context& dev_ctx,
DenseTensor* dy) {
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &out, &y};
std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
place,
......@@ -51,7 +51,7 @@ void DivideGradKernel(const Context& dev_ctx,
GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &out, &y};
std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册