Commit 51f11489 authored by Yang yaming, committed by GitHub

Merge pull request #3987 from pkuyym/fix-3923-c

Add modified huber loss operator
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/modified_huber_loss_op.h"
namespace paddle {
namespace operators {
class ModifiedHuberLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& context) const override {
PADDLE_ENFORCE_NOT_NULL(context.InputVar("X"), "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(context.InputVar("Y"), "Y must be initialized.");
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
"The shape of X and Y must be the same.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2.");
PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1.");
context.Output<framework::LoDTensor>("IntermediateVal")->Resize(x->dims());
context.Output<framework::LoDTensor>("Out")->Resize({x->dims()[0], 1});
}
};
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ModifiedHuberLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input tensor of modified huber loss op."
"X is 2-D tensor with shape [batch_size, 1].");
AddInput("Y",
"The target labels of modified huber loss op."
"The shape of Y is same as X. Values of Y must be 0 or 1.");
AddOutput("IntermediateVal",
"Variable to save intermediate result which will be reused in "
"backward processing.")
.AsIntermediate();
AddOutput("Out", "Classification loss for X.");
AddComment(R"DOC(
Modified huber loss is used in binary classification problems. The shapes of
input X and target Y are both [N, 1], and so is the shape of the output loss.
Since target Y is not differentiable, calculating the gradient for Y is illegal.
The formulation of modified huber loss is:
L(y, f(x)) = max(0, 1 - yf(x))^2  for yf(x) >= -1,
             -4yf(x)              otherwise.
Make sure the values of target label Y are in {0, 1} here. The operator will
scale the values of Y to {-1, +1} when computing losses and gradients.
)DOC");
}
};
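To make the piecewise formula in the DOC string above concrete, here is a minimal NumPy sketch of the forward pass. It is an illustration only, not part of the operator; the function and variable names are hypothetical:

import numpy as np

def modified_huber_loss(x, y):
    # x: predictions, shape [N, 1]; y: labels in {0, 1}, shape [N, 1]
    y_scaled = 2 * y - 1               # map {0, 1} to {-1, +1}
    inter_val = x * y_scaled           # yf(x), saved as IntermediateVal
    loss = np.where(inter_val < -1, -4 * inter_val,
                    np.where(inter_val < 1, (1 - inter_val) ** 2, 0.))
    return inter_val, loss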
class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* intermediate_val = context.Input<Tensor>("IntermediateVal");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized.");
PADDLE_ENFORCE_NOT_NULL(intermediate_val,
"Intermediate value must not be null.");
PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@Grad) must not be null.");
PADDLE_ENFORCE_EQ(
intermediate_val->dims(), x->dims(),
"The shape of X and intermediate value must be the same.");
PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims(),
"The shape of Input(Out@Grad) and X must be the same.");
if (x_grad) x_grad->Resize(x->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(modified_huber_loss, ops::ModifiedHuberLossOp,
ops::ModifiedHuberLossOpMaker, modified_huber_loss_grad,
ops::ModifiedHuberLossGradOp);
REGISTER_OP_CPU_KERNEL(
modified_huber_loss,
ops::ModifiedHuberLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(modified_huber_loss_grad,
ops::ModifiedHuberLossGradCPUKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/tuple.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/modified_huber_loss_op.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
struct ModifiedHuberLossBackward {
template <typename Tuple>
HOSTDEVICE void operator()(Tuple t) const {
auto inter_val = thrust::get<1>(t);
auto y_val = thrust::get<2>(t);
auto out_grad = thrust::get<3>(t);
if (inter_val < -1) {
thrust::get<0>(t) = -4 * (2 * y_val - 1) * out_grad;
} else if (inter_val < 1) {
thrust::get<0>(t) = -2 * (1 - inter_val) * (2 * y_val - 1) * out_grad;
} else {
thrust::get<0>(t) = 0;
}
}
};
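The three branches above are the derivative of the forward loss chained with the incoming gradient: y_val is a {0, 1} label, so 2 * y_val - 1 recovers the {-1, +1} label y, and with s = yx denoting the saved intermediate value, dL/dx = -4y for s < -1, dL/dx = -2(1 - s)y for -1 <= s < 1, and 0 for s >= 1. A vectorized NumPy sketch of the same computation (illustrative names, assuming matching shapes):

import numpy as np

def modified_huber_loss_grad(y, inter_val, out_grad):
    # y: {0, 1} labels; inter_val: yf(x) saved by the forward pass;
    # out_grad: gradient flowing back from Out
    y_scaled = 2 * y - 1
    return np.where(inter_val < -1, -4 * y_scaled * out_grad,
                    np.where(inter_val < 1,
                             -2 * (1 - inter_val) * y_scaled * out_grad, 0.))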
template <typename T>
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y");
auto* in1 = context.Input<Tensor>("IntermediateVal");
auto* in2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
if (out0) {
auto counts = framework::product(in1->dims());
auto y_ptr = thrust::device_pointer_cast(in0->data<T>());
auto inter_val_ptr = thrust::device_pointer_cast(in1->data<T>());
auto out_grad_ptr = thrust::device_pointer_cast(in2->data<T>());
thrust::device_ptr<T> x_grad_ptr(
out0->mutable_data<T>(context.GetPlace()));
auto iter_begin = thrust::make_zip_iterator(
thrust::make_tuple(x_grad_ptr, inter_val_ptr, y_ptr, out_grad_ptr));
auto iter_end = thrust::make_zip_iterator(
thrust::make_tuple(x_grad_ptr + counts, inter_val_ptr + counts,
y_ptr + counts, out_grad_ptr + counts));
thrust::for_each(iter_begin, iter_end, ModifiedHuberLossBackward());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
modified_huber_loss,
ops::ModifiedHuberLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(modified_huber_loss_grad,
ops::ModifiedHuberLossGradGPUKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct CheckLabelValue {
HOSTDEVICE T operator()(const T& val) const {
// assert fires if a label is neither 0 nor 1; pass the value through
PADDLE_ASSERT(val == static_cast<T>(0) || val == static_cast<T>(1));
return val;
}
};
template <typename T>
struct ModifiedHuberLossForward {
HOSTDEVICE T operator()(const T& val) const {
if (val < -1) {
return -4 * val;
} else if (val < 1) {
return (1 - val) * (1 - val);
} else {
return static_cast<T>(0);
}
}
};
template <typename Place, typename T>
class ModifiedHuberLossKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<framework::LoDTensor>("IntermediateVal");
auto* out1 = context.Output<framework::LoDTensor>("Out");
out0->mutable_data<T>(context.GetPlace());
out1->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
auto x = EigenVector<T>::Flatten(*in0);
auto y = EigenVector<T>::Flatten(*in1);
auto inter_val = EigenVector<T>::Flatten(*out0);
// make sure the values of Y are in {0, 1}; unaryExpr alone builds a lazy
// expression, so evaluate it into inter_val (overwritten just below)
inter_val.device(place) = y.unaryExpr(CheckLabelValue<T>());
// scale y to {-1, +1} and compute x * y
inter_val.device(place) = x * (2 * y - static_cast<T>(1));
auto loss = EigenVector<T>::Flatten(*out1);
loss.device(place) = inter_val.unaryExpr(ModifiedHuberLossForward<T>());
}
};
// CPU backward kernel
template <typename T>
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y");
auto* in1 = context.Input<framework::LoDTensor>("IntermediateVal");
auto* in2 =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* out0 =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (out0) {
const T* y_ptr = in0->data<T>();
const T* inter_val_ptr = in1->data<T>();
const T* out_grad_ptr = in2->data<T>();
size_t counts = static_cast<size_t>(framework::product(in1->dims()));
T* x_grad_ptr = out0->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < counts; ++i) {
if (inter_val_ptr[i] < -1) {
x_grad_ptr[i] = -4 * (2 * y_ptr[i] - 1) * out_grad_ptr[i];
} else if (inter_val_ptr[i] < 1) {
x_grad_ptr[i] = -2 * (1 - inter_val_ptr[i]) * (2 * y_ptr[i] - 1) *
out_grad_ptr[i];
} else {
x_grad_ptr[i] = 0;
}
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def modified_huber_loss_forward(val):
    if val < -1:
        return -4 * val
    elif val < 1:
        return (1 - val) * (1 - val)
    else:
        return 0
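As a quick sanity check of the helper: val = -2 falls in the first branch and gives a loss of -4 * (-2) = 8; val = 0.5 gives (1 - 0.5)^2 = 0.25; and val = 1.5 gives 0, matching the C++ ModifiedHuberLossForward functor.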
class TestModifiedHuberLossOp(OpTest):
    def setUp(self):
        self.op_type = 'modified_huber_loss'
        samples_num = 32
        self.inputs = {
            'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'),
            'Y': np.random.choice([0, 1], samples_num).reshape(
                (samples_num, 1)).astype('float32')
        }
        product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1)
        loss = np.vectorize(modified_huber_loss_forward)(product_res)
        self.outputs = {
            'IntermediateVal': product_res,
            'Out': loss.reshape((samples_num, 1))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.005)


if __name__ == '__main__':
    unittest.main()