未验证 提交 268156f8 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick] Fix the calculation of y_grad in divide_backward (#53672)

上级 3a53b77e
...@@ -116,7 +116,7 @@ struct DivGradXYFunctor { ...@@ -116,7 +116,7 @@ struct DivGradXYFunctor {
// dy = - dout * out / y // dy = - dout * out / y
phi::Array<OutT, 2> outs; phi::Array<OutT, 2> outs;
outs[0] = a / c; outs[0] = a / c;
outs[1] = -a * b / c; outs[1] = -a * ((b / c) / c);
return outs; return outs;
} }
}; };
...@@ -129,7 +129,7 @@ struct DivGradXYFunctor<ComplexType<InT>, ComplexType<OutT>> { ...@@ -129,7 +129,7 @@ struct DivGradXYFunctor<ComplexType<InT>, ComplexType<OutT>> {
const ComplexType<InT> c) { const ComplexType<InT> c) {
phi::Array<ComplexType<OutT>, 2> outs; phi::Array<ComplexType<OutT>, 2> outs;
ComplexType<InT> c_conj(c.real, -c.imag); ComplexType<InT> c_conj(c.real, -c.imag);
ComplexType<InT> out_div_c_conj((b / c).real, -(b / c).imag); ComplexType<InT> out_div_c_conj(((b / c) / c).real, -((b / c) / c).imag);
outs[0] = a / c_conj; outs[0] = a / c_conj;
outs[1] = -a * out_div_c_conj; outs[1] = -a * out_div_c_conj;
return outs; return outs;
...@@ -156,7 +156,7 @@ struct DivGradXFunctor<ComplexType<T>> { ...@@ -156,7 +156,7 @@ struct DivGradXFunctor<ComplexType<T>> {
template <typename T> template <typename T>
struct DivGradYFunctor { struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T a, const T b, const T c) const { inline HOSTDEVICE T operator()(const T a, const T b, const T c) const {
return -a * b / c; return -a * ((b / c) / c);
} }
}; };
...@@ -166,7 +166,7 @@ struct DivGradYFunctor<ComplexType<T>> { ...@@ -166,7 +166,7 @@ struct DivGradYFunctor<ComplexType<T>> {
inline HOSTDEVICE ComplexType<T> operator()(const ComplexType<T> a, inline HOSTDEVICE ComplexType<T> operator()(const ComplexType<T> a,
const ComplexType<T> b, const ComplexType<T> b,
const ComplexType<T> c) const { const ComplexType<T> c) const {
ComplexType<T> out_div_c_conj((b / c).real, -(b / c).imag); ComplexType<T> out_div_c_conj(((b / c) / c).real, -((b / c) / c).imag);
return -a * out_div_c_conj; return -a * out_div_c_conj;
} }
}; };
......
...@@ -36,7 +36,7 @@ void DivideGradKernel(const Context& dev_ctx, ...@@ -36,7 +36,7 @@ void DivideGradKernel(const Context& dev_ctx,
DenseTensor* dy) { DenseTensor* dy) {
const auto place = dev_ctx.GetPlace(); const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) { if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &out, &y}; std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXAndYOut<ElementwiseType::kTernary, T>( GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, dev_ctx,
place, place,
...@@ -51,7 +51,7 @@ void DivideGradKernel(const Context& dev_ctx, ...@@ -51,7 +51,7 @@ void DivideGradKernel(const Context& dev_ctx,
GetGradXOrYOut<ElementwiseType::kBinary, T>( GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>()); dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) { } else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&dout, &out, &y}; std::vector<const DenseTensor*> ins = {&dout, &x, &y};
GetGradXOrYOut<ElementwiseType::kTernary, T>( GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>()); dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor<T>());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册