提交 fd5199fd 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #3989 from pkuyym/fix-3923-r

Add huber loss operator.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/huber_loss_op.h"
namespace paddle {
namespace operators {
class HuberLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims);
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) must be 2 and the shape is "
"[batch_size, 1].");
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Each row of Input(X) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
ctx->SetOutputDim("Residual", x_dims);
ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx->ShareLoD("X", "Out");
}
};
template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HuberLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input value of huber loss op."
"X is a 2-D tensor with shape [batch_size, 1].");
AddInput("Y",
"The target value of huber loss op."
"Y is a 2-D tensor with shape [batch_size, 1].");
AddOutput("Residual",
"Intermediate tensor to cache residual value between Y and X."
"The shape is same as Input(X) and will be reused in backward.")
.AsIntermediate();
AddOutput("Out",
"The output tensor with shape [batch_size, 1] which represents "
"the huber loss.");
AddAttr<AttrType>("delta", "Hyper parameter in huber loss.");
AddComment(R"DOC(
Huber loss is a loss function used in robust regression. We define X as the
input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust for outliers. The
shape of X and Y are [batch_size, 1]. The equation is:
L_{\delta}(y, f(x)) =
\begin{cases}
0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\
\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise
\end{cases}
)DOC");
}
};
class HuberLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Residual"),
"Input(Residual) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto residual_dims = ctx->GetInputDim("Residual");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(residual_dims, x_dims);
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(huber_loss,
ops::HuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
huber_loss_grad,
ops::HuberLossGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/huber_loss_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(huber_loss,
ops::HuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
huber_loss_grad,
ops::HuberLossGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct HuberLossForward {
HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {}
HOSTDEVICE T operator()(const T& val) const {
T abs_val = std::abs(val);
if (abs_val <= delta) {
return static_cast<T>(0.5) * val * val;
} else {
return delta * (abs_val - static_cast<T>(0.5) * delta);
}
}
T delta;
};
template <typename Place, typename T, typename AttrType = T>
class HuberLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<Tensor>("Residual");
auto* out1 = context.Output<Tensor>("Out");
auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
auto place = context.GetEigenDevice<Place>();
auto x = EigenVector<T>::Flatten(*in0);
auto y = EigenVector<T>::Flatten(*in1);
out0->mutable_data<T>(context.GetPlace());
auto residual = EigenVector<T>::Flatten(*out0);
residual.device(place) = y - x;
out1->mutable_data<T>(context.GetPlace());
auto loss = EigenVector<T>::Flatten(*out1);
loss.device(place) = residual.unaryExpr(HuberLossForward<T>(delta));
}
};
template <typename T>
struct HuberLossBackward {
HOSTDEVICE HuberLossBackward(const T& delta, T sign)
: sign(sign), delta(delta) {}
HOSTDEVICE T operator()(const T& val) const {
T abs_val = std::abs(val);
if (abs_val <= delta) {
return sign * val;
} else {
if (val > 0) {
return sign * delta;
} else {
return -1 * sign * delta;
}
}
}
T sign;
T delta;
};
template <typename Place, typename T, typename AttrType = T>
class HuberLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Residual");
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
auto place = context.GetEigenDevice<Place>();
auto residual = EigenVector<T>::Flatten(*in0);
auto out_grad = EigenVector<T>::Flatten(*in1);
if (out0) {
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
x_grad.device(place) =
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
}
if (out1) {
out1->mutable_data<T>(context.GetPlace());
auto y_grad = EigenVector<T>::Flatten(*out1);
y_grad.device(place) =
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def huber_loss_forward(val, delta):
abs_val = abs(val)
if abs_val <= delta:
return 0.5 * val * val
else:
return delta * (abs_val - 0.5 * delta)
class TestHuberLossOp(OpTest):
def setUp(self):
self.op_type = 'huber_loss'
samples_num = 64
delta = 1.0
self.inputs = {
'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
}
residual = self.inputs['Y'] - self.inputs['X']
loss = np.vectorize(huber_loss_forward)(residual, delta)
self.attrs = {'delta': delta}
self.outputs = {
'Residual': residual,
'Out': loss.reshape((samples_num, 1))
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual'))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册