未验证 提交 b1e83b33 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix huber loss op attr type, test=develop (#19937)

上级 cc157d59
......@@ -41,7 +41,7 @@ struct HuberLossForward {
T delta;
};
template <typename DeviceContext, typename T, typename AttrType = T>
template <typename DeviceContext, typename T>
class HuberLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -49,7 +49,7 @@ class HuberLossKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<Tensor>("Residual");
auto* out1 = context.Output<Tensor>("Out");
auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
auto delta = static_cast<T>(context.Attr<float>("delta"));
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
......@@ -86,7 +86,7 @@ struct HuberLossBackward {
T delta;
};
template <typename DeviceContext, typename T, typename AttrType = T>
template <typename DeviceContext, typename T>
class HuberLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -94,7 +94,7 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
auto delta = static_cast<T>(context.op().Attr<float>("delta"));
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册