diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 2722ea078ebdf9a88fe2286fb4050fca652ffb7f..fd4cf92d85d5daa891d602d4365122c870920bba 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -66,6 +66,7 @@ paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr
 paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
+paddle.fluid.layers.bpr_loss ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9258d7c7e83122149c7cbc42e4a4bdd84903ce67
--- /dev/null
+++ b/paddle/fluid/operators/bpr_loss_op.cc
@@ -0,0 +1,145 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/bpr_loss_op.h"
+
+namespace paddle {
+namespace operators {
+
+class BprLossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) should have the same shape "
+                      "except the last dimension.");
+
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
+    ctx->ShareLoD("X", /*->*/ "Y");
+  }
+
+ protected:
+  // Explicitly set that the data type of the computation kernel of Seq-bpr
+  // is determined by its input "X".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class BprLossGradientOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                      "The last dimension of Input(Label) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD("X", framework::GradVarName("X"));
+  }
+
+ protected:
+  // Explicitly set that the data type of the computation kernel of Seq-bpr
+  // is determined by its input "X".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size equals the number of classes. It holds real-valued "
+             "logits.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except the last dimension, whose size is 1.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the sequence BPR loss.");
+    AddComment(R"DOC(
+Bayesian Personalized Ranking Loss Operator.
+
+This operator implements a pairwise ranking loss. Label is the desired item.
+The loss at a given point in one session is defined as:
+$Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j<N_{i},~j\neq Label[i]}\log(\sigma(X[i, Label[i]]-X[i, j]))$
+
+Learn more details by reading the paper "Session-based Recommendations with
+Recurrent Neural Networks" (https://arxiv.org/abs/1511.06939).
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPUCtx = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
+REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
+                       ops::BprLossOpKernel<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(bpr_loss_grad,
+                       ops::BprLossGradientOpKernel<CPUCtx, float>,
+                       ops::BprLossGradientOpKernel<CPUCtx, double>);
diff --git a/paddle/fluid/operators/bpr_loss_op.h b/paddle/fluid/operators/bpr_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..e223be7af82146e7c69c7c5aab8f08d0fe0d1710
--- /dev/null
+++ b/paddle/fluid/operators/bpr_loss_op.h
@@ -0,0 +1,118 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+// TODO: find a way to adapt TolerableValue using BLAS or Eigen.
+template <typename T>
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
+    return x;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class BprLossOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* label = ctx.Input<Tensor>("Label");
+    auto* y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
+    int rank = x->dims().size();
+
+    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+    Tensor labels_2d = framework::ReshapeToMatrix(*label, rank - 1);
+    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+
+    const framework::Tensor* logits = &x_2d;
+    const framework::Tensor* labels = &labels_2d;
+    framework::Tensor* out = &y_2d;
+
+    const int step_size = logits->dims()[0];
+    const int class_num = logits->dims()[1];
+    const T* logits_data = logits->data<T>();
+    T* loss_data = out->data<T>();
+
+    const int64_t* label_data = labels->data<int64_t>();
+    for (int i = 0; i < step_size; ++i) {
+      int lbl_pos = label_data[i];
+      PADDLE_ENFORCE_GE(lbl_pos, 0);
+      PADDLE_ENFORCE_LT(lbl_pos, class_num);
+      int index_pos = i * class_num + lbl_pos;
+      T sum = static_cast<T>(0);
+      for (int j = 0; j < class_num; j++) {
+        if (j == lbl_pos) continue;
+        int index_neg = i * class_num + j;
+        // Accumulate log(sigmoid(x_pos - x_neg)) = -log(1 + exp(x_neg - x_pos))
+        // over every negative item j != Label[i].
+        sum += TolerableValue<T>()(-std::log(
+            1.0f + TolerableValue<T>()(std::exp(logits_data[index_neg] -
+                                                logits_data[index_pos]))));
+      }
+      loss_data[i] = -sum / (class_num - 1);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class BprLossGradientOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* label = ctx.Input<Tensor>("Label");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    const size_t step_size = static_cast<size_t>(x->dims()[0]);
+    const size_t num_classes = static_cast<size_t>(x->dims()[1]);
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* dy_data = dy->data<T>();
+    const T* x_data = x->data<T>();
+    const int64_t* label_data = label->data<int64_t>();
+
+    for (size_t sample_id = 0; sample_id < step_size; sample_id++) {
+      for (size_t x_offset = sample_id * num_classes;
+           x_offset < (sample_id + 1) * num_classes; x_offset++) {
+        dx_data[x_offset] = static_cast<T>(0);
+      }
+      auto p_index = sample_id * num_classes + label_data[sample_id];
+      for (size_t ni = 0; ni < num_classes; ni++) {
+        if (label_data[sample_id] == static_cast<int64_t>(ni)) continue;
+        auto n_index = sample_id * num_classes + ni;
+        // Each negative contributes -dy * sigmoid(x_neg - x_pos) / (N - 1)
+        // to dX at the positive item; the negative item gets the opposite sign.
+        auto grad_ = -dy_data[sample_id] /
+                     ((num_classes - 1) *
+                      (1.0f + TolerableValue<T>()(std::exp(x_data[p_index] -
+                                                           x_data[n_index]))));
+        dx_data[p_index] += grad_;
+        dx_data[n_index] -= grad_;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index fac7538a6ada56e3722e2540519df863bf7cac71..e25eaaa9fda6add9d8e81d9e6bdfb711cee3648e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -41,6 +41,7 @@ __all__ = [
     'crf_decoding',
     'cos_sim',
     'cross_entropy',
+    'bpr_loss',
     'square_error_cost',
     'chunk_eval',
     'sequence_conv',
@@ -1348,6 +1349,44 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
     return out
 
 
+def bpr_loss(input, label, name=None):
+    """
+    Bayesian Personalized Ranking Loss Operator.
+
+    This operator implements a pairwise ranking loss. Label is the desired item.
+    The loss at a given point in one session is defined as:
+    $Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j<N_{i},~j\neq Label[i]}\log(\sigma(X[i, Label[i]]-X[i, j]))$
+
+    Learn more details by reading the paper "Session-based Recommendations
+    with Recurrent Neural Networks" (https://arxiv.org/abs/1511.06939).
+
+    Args:
+        input (Variable|list): a 2-D tensor with shape [N x D], where N is the
+            batch size and D is the number of classes.
+            This input holds the logits, not probabilities.
+        label (Variable|list): the ground truth which is a 2-D tensor. `label`
+            is a tensor with shape [N x 1].
+        name (str|None): A name for this layer (optional). If set to None, the
+            layer will be named automatically. Default: None.
+    Returns:
+        A 2-D tensor with shape [N x 1], the BPR loss.
+
+    Examples:
+        .. code-block:: python
+
+          cost = fluid.layers.bpr_loss(input=predict, label=label)
+    """
+
+    helper = LayerHelper('bpr_loss', **locals())
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='bpr_loss',
+        inputs={'X': [input],
+                'Label': [label]},
+        outputs={'Y': [out]})
+    return out
+
+
 def square_error_cost(input, label):
     """
     **Square error cost layer**
diff --git a/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dc5fbd237d17f2d4e45b06e5806fff5cbf58fe
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest, randomize_probability
+
+
+class TestBprLossOp1(OpTest):
+    """Test BprLoss with discrete labels.
+ """ + + def setUp(self): + self.op_type = "bpr_loss" + batch_size = 40 + class_num = 5 + X = randomize_probability(batch_size, class_num, dtype='float64') + label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64") + bpr_loss_result = [] + for i in range(batch_size): + sum = 0.0 + for j in range(class_num): + if j == label[i][0]: + continue + sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label[i][0]]))) + bpr_loss_result.append(-sum / (class_num - 1)) + bpr_loss = np.asmatrix([[x] for x in bpr_loss_result], dtype="float64") + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": bpr_loss} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.001) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index be51fb06a37a376f6f410336184c95981ded35dc..10e8bb5a86691d8654c5ae48794e49f30f47500d 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -846,6 +846,15 @@ class TestBook(unittest.TestCase): out = layers.cross_entropy(x, label, False, 4) self.assertIsNotNone(out) + def test_bpr_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[30, 10], dtype="float32") + label = layers.data(name="label", shape=[30, 1], dtype="int32") + out = layers.bpr_loss(x, label) + self.assertIsNotNone(out) + print(str(program)) + def test_expand(self): program = Program() with program_guard(program):