/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/modified_huber_loss_op.h"

#include <memory>
namespace paddle {
namespace operators {

class ModifiedHuberLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

25
  void InferShape(framework::InferShapeContext* ctx) const override {
26 27
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ModifiedHuberLoss");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ModifiedHuberLoss");
28

Q
Qiao Longfei 已提交
29 30
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
31

32 33 34 35 36
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
                                            "Input(input) rank should be 2, "
                                            "but received input rank(%d) != 2",
                                            x_dims.size()));

P
phlrain 已提交
37 38
    if (ctx->IsRuntime() ||
        (framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
39 40 41 42 43 44
      PADDLE_ENFORCE_EQ(
          x_dims, y_dims,
          platform::errors::InvalidArgument(
              "The Input(input) and Input(label) should have the same "
              "shape, but received input shape [%s] != label shape [%s]",
              x_dims, y_dims));
P
phlrain 已提交
45 46 47
    }

    if (ctx->IsRuntime()) {
48 49 50 51 52
      PADDLE_ENFORCE_EQ(x_dims[1], 1,
                        platform::errors::InvalidArgument(
                            "The second dimension of Input(input) should be 1, "
                            "but received second dimension of input (%d) != 1",
                            x_dims[1]));
P
phlrain 已提交
53
    }
54

Q
Qiao Longfei 已提交
55 56
    ctx->SetOutputDim("IntermediateVal", x_dims);
    ctx->SetOutputDim("Out", {x_dims[0], 1});
57 58 59 60 61
  }
};

// Declares the inputs, outputs and user-facing documentation of the
// modified_huber_loss operator.
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of modified huber loss op. "
             "X is 2-D tensor with shape [batch_size, 1].");
    AddInput("Y",
             "The target labels of modified huber loss op. "
             "The shape of Y is the same as X. Values of Y must be 0 or 1.");
    // Marked intermediate: produced by the forward pass only so the backward
    // kernel can reuse it.
    AddOutput("IntermediateVal",
              "Variable to save intermediate result which will be reused in "
              "backward processing.")
        .AsIntermediate();
    AddOutput("Out", "Classification loss for X.");
    AddComment(R"DOC(
Modified Huber Loss Operator.

This operator is used in binary classification problem. The shape of
input X and target Y are both [N, 1] and so is the shape of the output loss.
Since target Y is not differentiable, calculating gradient for Y is illegal.
The formula of modified huber loss is:

$$
L(y, f(x)) = 
\begin{cases}
(\max(0, 1 - yf(x)))^2,  \text{if} \  yf(x) >= -1    \\
             -4yf(x),    \quad \text{otherwise}
\end{cases}
$$

Make sure the values of target label Y are in {0, 1} here. This operator will
scale values of Y to {-1, +1} when computing losses and gradients.

)DOC");
  }
};

class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

101
  void InferShape(framework::InferShapeContext* ctx) const override {
102 103 104 105 106
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ModifiedHuberLossGrad");
    OP_INOUT_CHECK(ctx->HasInput("IntermediateVal"), "Input", "IntermediateVal",
                   "ModifiedHuberLossGrad");
    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "ModifiedHuberLossGrad");
Q
Qiao Longfei 已提交
107

108
    auto y_dims = ctx->GetInputDim("Y");
Q
Qiao Longfei 已提交
109 110
    auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
111

P
phlrain 已提交
112 113
    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(
114
          intermediate_dims, y_dims,
115 116 117 118 119 120 121 122 123 124 125 126 127 128
          platform::errors::InvalidArgument(
              "The shape of Intermediate variable which will be reused in "
              "backward processing should the same as "
              "the shape of Input(label), but received Intermediate variable "
              "shape [%s] != label shape [%s]",
              intermediate_dims, y_dims));

      PADDLE_ENFORCE_EQ(
          out_grad_dims, y_dims,
          platform::errors::InvalidArgument(
              "The shape of output gradient should be the same as "
              "the shape of Input(label), but received the output gradient "
              "shape [%s] != label shape [%s]",
              out_grad_dims, y_dims));
P
phlrain 已提交
129
    }
130

Q
Qiao Longfei 已提交
131
    if (ctx->HasOutput(framework::GradVarName("X"))) {
132
      ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
Q
Qiao Longfei 已提交
133
    }
134 135 136
  }
};

H
hong 已提交
137 138
template <typename T>
class ModifiedHuberLossGradOpMaker : public framework::SingleGradOpMaker<T> {
139
 public:
H
hong 已提交
140
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
141 142

 protected:
143
  void Apply(GradOpPtr<T> op) const override {
144
    op->SetType("modified_huber_loss_grad");
H
hong 已提交
145 146 147 148 149
    op->SetInput("Y", this->Input("Y"));
    op->SetInput("IntermediateVal", this->Output("IntermediateVal"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
150 151 152
  }
};

153 154 155 156
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
H
hong 已提交
157 158 159 160 161
REGISTER_OPERATOR(
    modified_huber_loss, ops::ModifiedHuberLossOp,
    ops::ModifiedHuberLossOpMaker,
    ops::ModifiedHuberLossGradOpMaker<paddle::framework::OpDesc>,
    ops::ModifiedHuberLossGradOpMaker<paddle::imperative::OpBase>);
162
REGISTER_OPERATOR(modified_huber_loss_grad, ops::ModifiedHuberLossGradOp);
163 164 165

REGISTER_OP_CPU_KERNEL(
    modified_huber_loss,
Q
QI JUN 已提交
166
    ops::ModifiedHuberLossKernel<paddle::platform::CPUDeviceContext, float>);
167 168
REGISTER_OP_CPU_KERNEL(modified_huber_loss_grad,
                       ops::ModifiedHuberLossGradCPUKernel<float>);