/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/huber_loss_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

class HuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

27
  void InferShape(framework::InferShapeContext* ctx) const override {
28 29 30 31
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) must be initialized.");
    PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
                      "Input(Y) must be initialized.");
Y
yangyaming 已提交
32

33 34
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
35

36 37 38
    PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                      "The rank of Input(X) should be equal to "
                      "the rank of Input(Y).");
39 40 41
    bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
                               framework::contain_unknown_dim(y_dims);
    if (ctx->IsRuntime() || !contain_unknown_dim) {
42 43 44
      PADDLE_ENFORCE_EQ(
          x_dims, y_dims,
          "The Input(X) and Input(Label) should have the same shape.");
P
phlrain 已提交
45
    }
Y
yangyaming 已提交
46

47 48 49
    auto out_dims = y_dims;
    ctx->SetOutputDim("Residual", out_dims);
    ctx->SetOutputDim("Out", out_dims);
50
    ctx->ShareLoD("X", "Out");
Y
yangyaming 已提交
51 52 53 54 55 56
  }
};

// Declares the proto for huber_loss: inputs X (prediction) and Y (target),
// outputs Residual (intermediate cache) and Out (loss), and the "delta"
// hyper-parameter that controls the quadratic/linear switch point.
template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input value of huber loss op."
             "X is a N-D tensor with shape [N_1, N_2,..., N_n].");
    AddInput("Y",
             "The target value of huber loss op."
             "Y is a N-D tensor with shape [N_1, N_2,..., N_n].");
    // Marked AsIntermediate: produced by forward only so backward can reuse
    // it; not part of the op's user-visible result.
    AddOutput("Residual",
              "Intermediate tensor to cache residual value between Y and X."
              "The shape is same as Input(X) and will be reused in backward.")
        .AsIntermediate();
    AddOutput("Out",
              "The output N-D tensor with shape [N_1, N_2,..., N_n] "
              "which represents the huber loss.");
    AddAttr<AttrType>("delta", "Hyper parameter in huber loss.");
    AddComment(R"DOC(
HuberLoss Operator.

Huber loss is a loss function used in robust regression. We define X as the
input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust for outliers. If the
shape of X and Y are [batch_size, 1]. The equation is:

$$
Out_{\delta}(X, Y)_i =
\begin{cases}
0.5 * (Y_i - X_i)^2,
\quad |Y_i - X_i| \leq \delta \\
\delta * (|Y_i - X_i| - 0.5 * \delta),
\quad otherwise
\end{cases}
$$

In the above equation, $Out_\delta(X, Y)_i$, $X_i$ and $Y_i$ represent the ith
element of Out, X and Y.

)DOC");
  }
};

class HuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

101
  void InferShape(framework::InferShapeContext* ctx) const override {
102 103
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                      "Input(Out@GRAD) should not be null.");
104 105 106 107 108 109

    auto residual_dims = ctx->GetInputDim("Residual");

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
110
      ctx->SetOutputDim(x_grad_name, residual_dims);
111 112
    }
    if (ctx->HasOutput(y_grad_name)) {
113
      ctx->SetOutputDim(y_grad_name, residual_dims);
114
    }
Y
yangyaming 已提交
115 116 117
  }
};

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
class HuberLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("huber_loss_grad");
    op->SetInput("Residual", Output("Residual"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward op (with its grad-desc maker) and the backward op.
REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
                  ops::HuberLossGradOpDescMaker);
REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp);

// CPU kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(
    huber_loss, ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HuberLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    huber_loss_grad,
    ops::HuberLossGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HuberLossGradKernel<paddle::platform::CPUDeviceContext, double>);