提交 1f423f84 编写于 作者: P peizhilin

fix the huber loss compile issue on windows test=develop

上级 bf518ec8
...@@ -105,14 +105,16 @@ class HuberLossGradKernel : public framework::OpKernel<T> { ...@@ -105,14 +105,16 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0); auto x_grad = EigenVector<T>::Flatten(*out0);
x_grad.device(place) = x_grad.device(place) =
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0)); residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
x_grad.device(place) = out_grad * x_grad;
} }
if (out1) { if (out1) {
out1->mutable_data<T>(context.GetPlace()); out1->mutable_data<T>(context.GetPlace());
auto y_grad = EigenVector<T>::Flatten(*out1); auto y_grad = EigenVector<T>::Flatten(*out1);
y_grad.device(place) = y_grad.device(place) =
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0)); residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
y_grad.device(place) = out_grad * y_grad;
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册