Commit 5ed07ef1 authored by Yibing Liu

Add more comments and allow the prior distribution to be supplied from outside

Parent fcff9758
......@@ -31,6 +31,14 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LabelSmoothOp should not be null.");
auto in_dims = ctx->GetInputDim("X");
if (ctx->HasInput("PriorDist")) {
auto noise_dims = ctx->GetInputDim("PriorDist");
auto noise_numel = paddle::framework::product(noise_dims);
PADDLE_ENFORCE(
in_dims[1] == noise_numel,
"The number of elements in Input(PriorDist) must be equal to the "
"dimension of each label.");
}
ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", in_dims);
}
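Note that the new check only constrains the element count of Input(PriorDist), not its exact shape, so a [1 x K] or [K] tensor both pass. A minimal NumPy sketch of the constraint (shapes and values are illustrative only):
import numpy as np
# Labels are [N x K]; PriorDist must contain exactly K elements.
N, K = 5, 10
label = np.zeros((N, K), dtype="float64")   # batched one-hot labels
prior_dist = np.full((1, K), 1.0 / K)       # K elements -> passes the check
assert prior_dist.size == label.shape[1]    # mirrors in_dims[1] == noise_numel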
......@@ -40,8 +48,22 @@ class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input label of LabelSmooth operator.");
AddOutput("Out", "The smoothed label of LabelSmooth operator.");
AddInput("X",
"(LoDTensor) The input labels of LabelSmooth operator. This "
"input can be batched labels in one-hot encoding or output from "
"softmax, with shape [N x K], where N is the batch size and K is "
"the number of classes");
AddInput("PriorDist",
"(Tensor, optional)"
"The prior distribution to be added to the smoothed label. It is "
"fixed during training and the number of elements should be equal "
"to the dimension K of each label. Default is uniform "
"distribution and each element will be set to 1/K if not provided "
"in input.")
.AsDispensable();
AddOutput("Out",
"(loDTensor) The smoothed label of LabelSmooth operator. It has"
"the same shape and LoD with the Input(LoDTensor).");
AddAttr<float>("epsilon",
"(float, default 0.0f)"
"The smoothing parameter of LabelSmooth operator.")
......@@ -49,6 +71,28 @@ class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
LabelSmooth Operator.
Label smoothing is a mechanism to regularize the classifier layer. In machine
learning, optimizing the log-likelihood of the correct label directly may
cause two problems. First, it may result in overfitting: if the model learns
to assign full probability to the ground-truth label for each training example,
it is not guaranteed to generalize. Second, it encourages the differences
between the largest logit and all others to become large, reducing the ability
of the model to adapt. Label smoothing is proposed to encourage the model to
be less confident, which replaces the ground-truth label $y$ with the weighted
sum of itself and some fixed distribution $\mu$, i.e.,
$$
\tilde{y} = (1 - \epsilon) * y + \epsilon * \mu,
$$
where $(1 - \epsilon)$ and $\epsilon$ are the weights respectively, and
$\tilde{y}$ is the smoothed label. Usually, the uniform distribution is used for
$\mu$. This change in the ground-truth label is called label-smoothing
regularization or LSR.
See more details about label smoothing in https://arxiv.org/abs/1512.00567.
)DOC");
}
};
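As a quick numeric illustration of the formula in the comment above (the values here are made up for the example), take epsilon = 0.1 and a uniform prior over K = 4 classes:
import numpy as np
epsilon = 0.1
y = np.array([0., 0., 1., 0.])               # one-hot ground-truth label
mu = np.full_like(y, 1.0 / y.size)           # uniform prior, each element 1/K
y_tilde = (1 - epsilon) * y + epsilon * mu   # -> [0.025, 0.025, 0.925, 0.025]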
......
......@@ -26,6 +26,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[1];
out_t->mutable_data<T>(ctx.GetPlace());
......@@ -33,8 +34,15 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
out.device(dev) =
static_cast<T>(1 - epsilon) * in + static_cast<T>(epsilon / label_dim);
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) =
static_cast<T>(1 - epsilon) * in +
epsilon * dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel()));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
}
}
};
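The kernel now branches on whether PriorDist is fed: with a prior, the fixed K-element distribution is broadcast over the batch; otherwise it falls back to the uniform epsilon / K term. A rough NumPy sketch of the intended semantics (the function name and arguments are illustrative, not part of the operator):
import numpy as np
def label_smooth_ref(label, epsilon, prior_dist=None):
    # label: [N x K]; prior_dist: K elements or None
    if prior_dist is not None:
        # tile the fixed prior across the batch, like dist.broadcast(...)
        return (1 - epsilon) * label + epsilon * prior_dist.reshape(1, -1)
    # uniform prior: every class receives epsilon / K
    return (1 - epsilon) * label + epsilon / label.shape[1]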
......
......@@ -18,16 +18,20 @@ from op_test import OpTest
class TestLabelSmoothOp(OpTest):
def setUp(self):
def config(self):
self.op_type = "label_smooth"
epsilon = 0.1
batch_size, label_dim = 5, 10
label = np.zeros((batch_size, label_dim)).astype("float64")
nonzero_index = np.random.randint(label_dim, size=(batch_size))
label[np.arange(batch_size), nonzero_index] = 1
smoothed_label = (1 - epsilon) * label + epsilon / label_dim
self.inputs = {'X': label}
self.attrs = {'epsilon': epsilon}
self.epsilon = 0.1
batch_size, self.label_dim = 5, 10
self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
self.label[np.arange(batch_size), nonzero_index] = 1
def setUp(self):
self.config()
smoothed_label = (1 - self.epsilon
) * self.label + self.epsilon / self.label_dim
self.inputs = {'X': self.label}
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}
def test_check_output(self):
......@@ -37,5 +41,15 @@ class TestLabelSmoothOp(OpTest):
self.check_grad(["X"], "Out")
class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
def setUp(self):
self.config()
dist = np.random.random((1, self.label_dim))
smoothed_label = (1 - self.epsilon) * self.label + self.epsilon * dist
self.inputs = {'X': self.label, 'PriorDist': dist}
self.attrs = {'epsilon': self.epsilon}
self.outputs = {'Out': smoothed_label}
if __name__ == '__main__':
unittest.main()
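Both test classes also exercise the gradient of X (the inherited check_grad above); since Out = (1 - epsilon) * X + epsilon * prior, the analytic gradient w.r.t. X is the constant (1 - epsilon), and PriorDist receives no gradient. A small sketch of that expectation (shapes and values are illustrative):
import numpy as np
epsilon = 0.1
d_out = np.ones((5, 10))        # upstream gradient dL/dOut
d_x = (1 - epsilon) * d_out     # expected dL/dX; PriorDist gets no gradient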