Unverified commit b1e83b33, authored by Zeng Jinle and committed by GitHub

fix huber loss op attr type, test=develop (#19937)

Parent cc157d59
@@ -41,7 +41,7 @@ struct HuberLossForward {
   T delta;
 };
 
-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename DeviceContext, typename T>
 class HuberLossKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -49,7 +49,7 @@ class HuberLossKernel : public framework::OpKernel<T> {
     auto* in1 = context.Input<Tensor>("Y");
     auto* out0 = context.Output<Tensor>("Residual");
     auto* out1 = context.Output<Tensor>("Out");
-    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
+    auto delta = static_cast<T>(context.Attr<float>("delta"));
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
@@ -86,7 +86,7 @@ struct HuberLossBackward {
   T delta;
 };
 
-template <typename DeviceContext, typename T, typename AttrType = T>
+template <typename DeviceContext, typename T>
 class HuberLossGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
@@ -94,7 +94,7 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
     auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
-    auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
+    auto delta = static_cast<T>(context.op().Attr<float>("delta"));
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
...
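Why the fix: the "delta" attribute is registered as a float, but the kernels read it through a templated AttrType that defaulted to T, so a double-precision kernel would try to fetch a double attribute that does not exist and fail at runtime. The patch drops the AttrType parameter, always reads Attr<float>, and static_casts the value to the kernel's compute type T. The following standalone sketch illustrates that pattern; the variant-based AttributeMap and the GetAttr/ComputeDelta helpers are invented for illustration and are not Paddle's real API.

// Minimal C++17 sketch (not Paddle's actual classes) of a type-erased
// attribute map keyed by the registered C++ type of each attribute.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>

using Attribute = std::variant<int, float, double, std::string>;
using AttributeMap = std::map<std::string, Attribute>;

// Analogue of reading an op attribute: fetching with a static type that
// does not match the stored alternative throws at runtime.
template <typename T>
T GetAttr(const AttributeMap& attrs, const std::string& name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) {
    throw std::runtime_error("missing attribute: " + name);
  }
  return std::get<T>(it->second);  // std::bad_variant_access on type mismatch
}

// Mirrors the fixed kernel: read the attribute as float (its registered
// type), then cast to the kernel's compute type T.
template <typename T>
T ComputeDelta(const AttributeMap& attrs) {
  return static_cast<T>(GetAttr<float>(attrs, "delta"));
}

int main() {
  AttributeMap attrs{{"delta", 1.0f}};  // "delta" is stored as float

  // Works for both float and double kernels because the read type is float.
  std::cout << ComputeDelta<float>(attrs) << "\n";
  std::cout << ComputeDelta<double>(attrs) << "\n";

  // The pre-fix pattern (AttrType == T) amounts to this for a double kernel
  // and fails, because the stored alternative is float, not double.
  try {
    std::cout << GetAttr<double>(attrs, "delta") << "\n";
  } catch (const std::exception& e) {
    std::cout << "type mismatch: " << e.what() << "\n";
  }
  return 0;
}

Built with a C++17 compiler, the first two prints succeed for both float and double compute types, while the last fetch throws because the stored alternative is float; this mirrors the attribute-type mismatch the commit removes.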