From 4f6e9e3ac30ed4b3489394b79f0b1dd607432b93 Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Wed, 19 Dec 2018 10:48:55 +0800 Subject: [PATCH] teacher student sigmoid loss --- .../teacher_student_sigmoid_loss_op.cc | 256 ++++++++++++++++++ .../teacher_student_sigmoid_loss_op.h | 25 ++ python/paddle/fluid/layers/nn.py | 42 +++ .../test_teacher_student_sigmoid_loss_op.py | 70 +++++ 4 files changed, 393 insertions(+) create mode 100644 paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc create mode 100644 paddle/fluid/operators/teacher_student_sigmoid_loss_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc new file mode 100644 index 0000000000..98eafb9f84 --- /dev/null +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -0,0 +1,256 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/teacher_student_sigmoid_loss_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "The 2nd dimension of " + "Input(Label) should be 1."); + ctx->SetOutputDim("Y", {x_dims[0], 1}); + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of + // teacher_student_sigmoid_loss + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class TeacherStudentSigmoidLossGradientOp + : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + PADDLE_ENFORCE_EQ(dy_dims[1], 1, + "The 2nd dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[1], 1, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(Label) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // teacher_student_sigmoid_loss + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class TeacherStudentSigmoidLossOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape [N x 1]," + " where N is the batch size and D is the output. " + "This input is a probability computed by the previous operator, " + "which is almost always the result of a softmax operator."); + AddInput("Label", + "(Tensor), the ground truth which is a 2-D tensor. " + "Label is a Tensor with shape [N x 1]. "); + AddOutput("Y", + "(Tensor, default Tensor), a 2-D tensor with shape " + "[N x 1]. The teacher student sigmoid loss."); + AddAttr("soft_max_up_bound", "fp32, default 15.0").SetDefault(15.0); + AddAttr("soft_max_lower_bound", "fp32, default -15.0") + .SetDefault(-15.0); + AddComment(R"DOC( +TeacherStudentSigmoidLoss Operator. +TeacherStudentSigmoidLoss Operator. + +It's similarity to SigmoidCrossEntropyWithLogits Operator. The difference is that +we add another label(z') to original. + loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x))) + z is click or not + z' is value q of feed_fine + label = {-2, -1, [0, 2]} + when z' is not exist, clk = 0 : label = -2; + when z' is not exist, clk = 1 : label = -1; + when z' is exist , clk = 0 : label = 0 + z'; + when z' is exist , clk = 1 : label = 1 + z'; + +)DOC"); + } +}; + +// template +template +class TeacherStudentSigmoidLossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), + "This kernel only runs on CPU."); + + Tensor* y = context.Output("Y"); + const Tensor* x = context.Input("X"); + const Tensor* labels = context.Input("Label"); + T* y_data = y->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + const T* label_data = labels->data(); + int64_t batch_size = x->dims()[0]; + // loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + + // log(1 + exp(-abs(x))) + // z is click or not + // z' is value q of feed_fine + // label = {-2, -1, [0, 2]} + // when z' is not exist, clk = 0 : label = -2; + // when z' is not exist, clk = 1 : label = -1; + // when z' is exist , clk = 0 : label = 0 + z'; + // when z' is exist , clk = 1 : label = 1 + z'; + for (int i = 0; i < batch_size; ++i) { + if (label_data[i] < -1.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) + + log(1.0 + exp(-fabs(x_data[i]))); + } else if (label_data[i] < 0.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) - x_data[i] + + log(1.0 + exp(-fabs(x_data[i]))); + } else if (label_data[i] < 1.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) + + log(1.0 + exp(-fabs(x_data[i]))) + + (x_data[i] > 0 ? x_data[i] : 0.0) - + x_data[i] * label_data[i] + + log(1.0 + exp(-fabs(x_data[i]))); + } else { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) - x_data[i] + + log(1.0 + exp(-fabs(x_data[i]))) + + (x_data[i] > 0 ? x_data[i] : 0.0) - + x_data[i] * (label_data[i] - 1.0) + + log(1.0 + exp(-fabs(x_data[i]))); + } + } + } +}; + +template +class TeacherStudentSigmoidLossGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + const T* x_data = x->data(); + + Tensor* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + const Tensor* labels = context.Input("Label"); + const T* label_data = labels->data(); + + T soft_max_up_bound = + static_cast(context.Attr("soft_max_up_bound")); + T soft_max_lower_bound = + static_cast(context.Attr("soft_max_lower_bound")); + + int64_t batch_size = x->dims()[0]; + + const framework::Tensor* dOut = + context.Input(framework::GradVarName("Y")); + + const T* dout_data = dOut->data(); + + for (int i = 0; i < batch_size; ++i) { + T sum_val = x_data[i]; + if (sum_val > soft_max_up_bound) { + sum_val = soft_max_up_bound; + } else { + if (sum_val < soft_max_lower_bound) { + sum_val = soft_max_lower_bound; + } + } + + T pred = 1.0 / (1.0 + exp(-sum_val)); + if (label_data[i] < -1.0) { + dx_data[i] = 0.0 - pred; + } else if (label_data[i] < 0.0) { + dx_data[i] = 1.0 - pred; + } else { + dx_data[i] = label_data[i] - 2.0 * pred; + } + if (sum_val >= soft_max_up_bound || sum_val <= soft_max_lower_bound) { + dx_data[i] = 0; + } + dx_data[i] *= dout_data[i] * -1; + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(teacher_student_sigmoid_loss, + ops::TeacherStudentSigmoidLossOp, + ops::TeacherStudentSigmoidLossOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OPERATOR(teacher_student_sigmoid_loss_grad, + ops::TeacherStudentSigmoidLossGradientOp); + +REGISTER_OP_CPU_KERNEL(teacher_student_sigmoid_loss, + ops::TeacherStudentSigmoidLossOpKernel, + ops::TeacherStudentSigmoidLossOpKernel); + +REGISTER_OP_CPU_KERNEL(teacher_student_sigmoid_loss_grad, + ops::TeacherStudentSigmoidLossGradOpKernel, + ops::TeacherStudentSigmoidLossGradOpKernel); diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h new file mode 100644 index 0000000000..77b2760e9c --- /dev/null +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9e6cd1a0ab..68243cf744 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -176,6 +176,7 @@ __all__ = [ 'get_tensor_from_selected_rows', 'lstm', 'psroi_pool', + 'teacher_student_sigmoid_loss', ] kIgnoreIndex = -100 @@ -9184,6 +9185,47 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss +def teacher_student_sigmoid_loss(input, + label, + soft_max_up_bound=15.0, + soft_max_lower_bound=-15.0): + """ + **Teacher Student Log Loss Layer** + + This layer accepts input predictions and target label and returns the + teacher_student loss. + + .. math:: + loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x))) + + Args: + input (Variable|list): a 2-D tensor with shape [N x 1], where N is the + batch size. This input is a probability computed + by the previous operator. + label (Variable|list): the ground truth which is a 2-D tensor with + shape [N x 1], where N is the batch size. + soft_max_up_bound (float): if input > soft_max_up_bound, will be bound + soft_max_lower_bound (float): if input < soft_max_lower_bound, will be bound + + Returns: + Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss. + + Examples: + .. code-block:: python + cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label) + """ + helper = LayerHelper('teacher_student_sigmoid_loss', **locals()) + out = helper.create_variable(dtype=input.dtype) + helper.append_op( + type='teacher_student_sigmoid_loss', + inputs={'X': [input], + 'Label': [label]}, + outputs={'Y': [out]}, + attrs={"soft_max_lower_bound": float(soft_max_lower_bound), \ + "soft_max_up_bound": float(soft_max_up_bound)}) + return out + + def add_position_encoding(input, alpha, beta, name=None): """ **Add Position Encoding Layer** diff --git a/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py b/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py new file mode 100644 index 0000000000..faa5163b32 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from math import log +from math import exp +from op_test import OpTest +from scipy.special import logit +from scipy.special import expit +import unittest + + +class TestTeacherStudentSigmoidLossOp(OpTest): + """ + Test teacher_student_sigmoid_loss with discrete one-hot labels. + """ + + def setUp(self): + """ + ut + """ + self.op_type = "teacher_student_sigmoid_loss" + batch_size = 16 + num_classes = 1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Label': np.random.uniform(0, 2, (batch_size, num_classes)) + .astype("float32") + } + outs = [] + for index, label in enumerate(self.inputs["Label"]): + x = self.inputs["X"][index] + if label < -1.0: + outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x)))) + elif label < 0.0: + outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x)))) + elif label < 1.0: + outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x))) + \ + max(x, 0.0) - x * label + log(1.0 + exp(-abs(x)))) + #print "33 python x:", x, "python label:", label, "term1:", max(x, 0.0) + log(1.0 + exp(-abs(x))), "term2:", max(x, 0.0) - x * label + log(1.0 + exp(-abs(x))) + else: + outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x))) + \ + max(x, 0.0) - x * (label - 1.0) + log(1.0 + exp(-abs(x)))) + #print "44 python x:", x, "python label:", label, "term1:", max(x, 0.0) - x + log(1.0 + exp(-abs(x))), "term2:", max(x, 0.0) - x * (label - 1.0) + log(1.0 + exp(-abs(x))) + self.outputs = {'Y': np.array(outs)} + + def test_check_output(self): + """ + ut + """ + self.check_output() + + def test_check_grad(self): + """ + ut + """ + self.check_grad(["X"], "Y", numeric_grad_delta=0.005) -- GitLab