提交 e49276e7 编写于 作者: P peizhilin

restore the huber_loss_op

test=develop
上级 01c00b07
@@ -104,19 +104,15 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
     if (out0) {
       out0->mutable_data<T>(context.GetPlace());
       auto x_grad = EigenVector<T>::Flatten(*out0);
-      // MSVC not treat it well when partial template arguments were specified
       x_grad.device(place) =
-          out_grad *
-          residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(-1.0)));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
     }
     if (out1) {
       out1->mutable_data<T>(context.GetPlace());
       auto y_grad = EigenVector<T>::Flatten(*out1);
-      // MSVC not treat it well when partial template arguments were specified
       y_grad.device(place) =
-          out_grad *
-          residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(1.0)));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
     }
   }
 };
 
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册