未验证 提交 5751b7f4 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Fix sqrt grad (#51045)

* fix sqrt grad

* fix sqrt grad
上级 1794927b
......@@ -278,7 +278,8 @@ void divide_grad(const Tensor& x,
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto x_grad_tmp = out_grad * 0.5 / out;
// This calculation is important for resnet.
auto x_grad_tmp = (0.5 / out) * out_grad;
set_output<T>(x_grad_tmp, x_grad);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册