提交 e49276e7 编写于 作者: P peizhilin

restore the huber_loss_op

test=develop
上级 01c00b07
......@@ -104,19 +104,15 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
if (out0) {
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
// MSVC not treat it well when partial template arguments were specified
x_grad.device(place) =
out_grad *
residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(-1.0)));
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
}
if (out1) {
out1->mutable_data<T>(context.GetPlace());
auto y_grad = EigenVector<T>::Flatten(*out1);
// MSVC not treat it well when partial template arguments were specified
y_grad.device(place) =
out_grad *
residual.unaryExpr(HuberLossBackward<T>(delta, static_cast<T>(1.0)));
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册