提交 91ba7500 编写于 作者: Z zhoukunsheng

fix type conversion problem in rsqrt functor

上级 082822d4
......@@ -477,7 +477,7 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(-0.5) * dout * out.pow(3);
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册