diff --git a/paddle/operators/label_smooth_op.cc b/paddle/operators/label_smooth_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c89082f44b360cbd171eccb212674040b8688a46 --- /dev/null +++ b/paddle/operators/label_smooth_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/label_smooth_op.h" + +namespace paddle { +namespace operators { + +class LabelSmoothOp : public framework::OperatorWithKernel { + public: + LabelSmoothOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LabelSmoothOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LabelSmoothOp should not be null."); + auto in_dims = ctx->GetInputDim("X"); + if (ctx->HasInput("PriorDist")) { + auto noise_dims = ctx->GetInputDim("PriorDist"); + auto noise_numel = paddle::framework::product(noise_dims); + PADDLE_ENFORCE( + in_dims[1] == noise_numel, + "The number of elements in Input(PriorDist) must be equal to the " + "dimension of each label."); + } + ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", in_dims); + } +}; + +class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor) The input labels of LabelSmooth operator. This " + "input can be batched labels in one-hot encoding or output from " + "softmax, with shape [N x K], where N is the batch size and K is " + "the number of classes"); + AddInput("PriorDist", + "(Tensor, optional)" + "The prior distribution to be added to the smoothed label. It is " + "fixed during training and the number of elements should be equal " + "to the dimension K of each label. Default is uniform " + "distribution and each element will be set to 1/K if not provided " + "in input.") + .AsDispensable(); + AddOutput("Out", + "(loDTensor) The smoothed label of LabelSmooth operator. It has" + "the same shape and LoD with the Input(LoDTensor)."); + AddAttr("epsilon", + "(float, default 0.0f)" + "The smoothing parameter of LabelSmooth operator.") + .SetDefault(0.0f); + AddComment(R"DOC( +LabelSmooth Operator. + +Label smoothing is a mechanism to regularize the classifier layer. In machine +learning, optimizing the log-likelihood of the correct label directly may +cause two problems. First, it may result in overfitting: if the model learns +to assign full probability to the ground-truth label for each training example, +it is not guaranteed to generalize. Second, it encourages the differences +between the largest logit and all others to become large, reducing the ability +of the model to adapt. Label smoothing is proposed to encourage the model to +be less confident, which replaces the ground-truth label $y$ with the weighted +sum of itself and some fixed distribution $\mu$, i.e. + +$$ + \tilde{y} = (1 - \epsilon) * y + \epsilon * \mu, +$$ + +where $(1 - \epsilon)$ and $\epsilon$ are the weights respectively, and +$\tilde{y}$ is the smoothed label. Usually uniform distribution is used for +$\mu$. This change in the ground-truth label is called label-smoothing +regularization or LSR. + +See more details about label smoothing in https://arxiv.org/abs/1512.00567. + +)DOC"); + } +}; + +class LabelSmoothGradOp : public framework::OperatorWithKernel { + public: + LabelSmoothGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker, + label_smooth_grad, ops::LabelSmoothGradOp); +REGISTER_OP_CPU_KERNEL( + label_smooth, + ops::LabelSmoothKernel, + ops::LabelSmoothKernel); +REGISTER_OP_CPU_KERNEL( + label_smooth_grad, + ops::LabelSmoothGradKernel, + ops::LabelSmoothGradKernel); diff --git a/paddle/operators/label_smooth_op.cu b/paddle/operators/label_smooth_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a0cec12bc58a56e4b0c3bd6fbc6c4754ef81fa4 --- /dev/null +++ b/paddle/operators/label_smooth_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/label_smooth_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + label_smooth, + ops::LabelSmoothKernel, + ops::LabelSmoothKernel); +REGISTER_OP_CUDA_KERNEL( + label_smooth_grad, + ops::LabelSmoothGradKernel, + ops::LabelSmoothGradKernel); diff --git a/paddle/operators/label_smooth_op.h b/paddle/operators/label_smooth_op.h new file mode 100644 index 0000000000000000000000000000000000000000..87bc9f793e3b4e249142710243c45d51f3a913b2 --- /dev/null +++ b/paddle/operators/label_smooth_op.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LabelSmoothKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* in_t = ctx.Input("X"); + auto* dist_t = ctx.Input("PriorDist"); + auto label_dim = in_t->dims()[1]; + out_t->mutable_data(ctx.GetPlace()); + + auto epsilon = ctx.Attr("epsilon"); + auto out = framework::EigenVector::Flatten(*out_t); + auto in = framework::EigenVector::Flatten(*in_t); + auto& dev = *ctx.template device_context().eigen_device(); + if (dist_t) { + auto dist = framework::EigenVector::Flatten(*dist_t); + out.device(dev) = + static_cast(1 - epsilon) * in + + epsilon * dist.broadcast(Eigen::DSizes(in_t->numel())); + } else { + out.device(dev) = static_cast(1 - epsilon) * in + + static_cast(epsilon / label_dim); + } + } +}; + +template +class LabelSmoothGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* d_in_t = ctx.Output(framework::GradVarName("X")); + d_in_t->mutable_data(ctx.GetPlace()); + + auto d_out = framework::EigenVector::Flatten(*d_out_t); + auto d_in = framework::EigenVector::Flatten(*d_in_t); + + auto epsilon = ctx.Attr("epsilon"); + auto& dev = *ctx.template device_context().eigen_device(); + d_in.device(dev) = static_cast(1 - epsilon) * d_out; + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_label_smooth_op.py b/python/paddle/v2/fluid/tests/test_label_smooth_op.py new file mode 100644 index 0000000000000000000000000000000000000000..19a4df57446c0c83b415909df3e0246bf2716881 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_label_smooth_op.py @@ -0,0 +1,55 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestLabelSmoothOp(OpTest): + def config(self): + self.op_type = "label_smooth" + self.epsilon = 0.1 + batch_size, self.label_dim = 5, 10 + self.label = np.zeros((batch_size, self.label_dim)).astype("float64") + nonzero_index = np.random.randint(self.label_dim, size=(batch_size)) + self.label[np.arange(batch_size), nonzero_index] = 1 + + def setUp(self): + self.config() + smoothed_label = (1 - self.epsilon + ) * self.label + self.epsilon / self.label_dim + self.inputs = {'X': self.label} + self.attrs = {'epsilon': self.epsilon} + self.outputs = {'Out': smoothed_label} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp): + def setUp(self): + self.config() + dist = np.random.random((1, self.label_dim)) + smoothed_label = (1 - self.epsilon) * self.label + self.epsilon * dist + self.inputs = {'X': self.label, 'PriorDist': dist} + self.attrs = {'epsilon': self.epsilon} + self.outputs = {'Out': smoothed_label} + + +if __name__ == '__main__': + unittest.main()