From 570d89ec84296dd46725be4f854808e0f1fb5f1c Mon Sep 17 00:00:00 2001
From: frankwhzhang
Date: Thu, 6 Dec 2018 16:52:59 +0800
Subject: [PATCH] add bpr_loss operator, test=develop

---
 paddle/fluid/operators/bpr_loss_op.cc          | 149 ++++++++++++++++++
 paddle/fluid/operators/bpr_loss_op.h           | 142 +++++++++++++++++
 python/paddle/fluid/layers/nn.py               |  13 ++
 .../fluid/tests/unittests/test_bpr_loss_op.py  |  53 +++++++
 4 files changed, 357 insertions(+)
 create mode 100644 paddle/fluid/operators/bpr_loss_op.cc
 create mode 100644 paddle/fluid/operators/bpr_loss_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_bpr_loss_op.py

diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc
new file mode 100644
index 00000000000..3e6445dbc26
--- /dev/null
+++ b/paddle/fluid/operators/bpr_loss_op.cc
@@ -0,0 +1,149 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/bpr_loss_op.h"
+
+namespace paddle {
+namespace operators {
+
+class BprLossOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label_Pos"),
+                   "Input(Label_Pos) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_pos_dims = ctx->GetInputDim("Label_Pos");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(
+        rank, label_pos_dims.size(),
+        "Input(X) and Input(Label_Pos) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_pos_dims, 0, rank - 1),
+                      "Input(X) and Input(Label_Pos) should have the same "
+                      "shape except the last dimension.");
+
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
+    ctx->ShareLoD("X", /*->*/ "Y");
+  }
+
+ protected:
+  // Explicitly set that the data type of the computation kernel of Seq-bpr
+  // is determined by its input "X".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class BprLossGradientOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Label_Pos"),
+                   "Input(Label_Pos) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_pos_dims = ctx->GetInputDim("Label_Pos");
+    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(
+        label_pos_dims.size(), rank,
+        "Input(Label_Pos) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_pos_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label_Pos) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    PADDLE_ENFORCE_EQ(label_pos_dims[rank - 1], 1,
+                      "The last dimension of Input(Label_Pos) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD("X", framework::GradVarName("X"));
+  }
+
+ protected:
+  // Explicitly set that the data type of the computation kernel of Seq-bpr
+  // is determined by its input "X".
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
+};
+
+class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "real number.");
+    AddInput(
+        "Label_Pos",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except the last dimension, whose size is 1.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the sequence bpr loss.");
+    AddComment(R"DOC(
+BprLoss Operator.
+
+This operator computes a pairwise ranking (BPR) loss. Label_Pos gives the index
+of the positive (desired) item, and the loss for one instance in a session is:
+$Y[i] = -\frac{1}{N_{i}-1} \sum_{j=0, j \neq Label[i]}^{N_{i}-1}\log(\sigma(X[i, Label[i]]-X[i, j]))$
+
+where $N_{i}$ is the number of classes and $\sigma$ is the sigmoid function.
+See the original Bayesian Personalized Ranking (BPR) paper for more details.
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPUCtx = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
+REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
+                       ops::BprLossOpKernel<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(bpr_loss_grad,
+                       ops::BprLossGradientOpKernel<CPUCtx, float>,
+                       ops::BprLossGradientOpKernel<CPUCtx, double>);
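For readers of the operator above, the forward loss documented in the DOC block can be reproduced with a few lines of NumPy. This is an illustrative sketch only; the helper name and the (batch, classes) / (batch, 1) shapes are assumptions, not part of the patch:

```python
import numpy as np

def bpr_loss_reference(x, label_pos):
    """Reference forward pass: x is (batch, classes), label_pos is (batch, 1) int64."""
    batch, classes = x.shape
    y = np.zeros((batch, 1), dtype=x.dtype)
    for i in range(batch):
        p = int(label_pos[i, 0])          # index of the positive item
        s = 0.0
        for j in range(classes):
            if j == p:
                continue
            # -log(sigmoid(x_p - x_j)) == log(1 + exp(x_j - x_p))
            s += np.log(1.0 + np.exp(x[i, j] - x[i, p]))
        y[i, 0] = s / (classes - 1)       # average over the negative items
    return y
```

This mirrors the CPU kernel in bpr_loss_op.h below, which skips the positive index and divides by class_num - 1.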
diff --git a/paddle/fluid/operators/bpr_loss_op.h b/paddle/fluid/operators/bpr_loss_op.h
new file mode 100644
index 00000000000..4103686de77
--- /dev/null
+++ b/paddle/fluid/operators/bpr_loss_op.h
@@ -0,0 +1,142 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
+    return x;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class BprLossOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* labels_pos = ctx.Input<Tensor>("Label_Pos");
+    auto* y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
+    int rank = x->dims().size();
+
+    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+    Tensor labels_pos_2d = framework::ReshapeToMatrix(*labels_pos, rank - 1);
+    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+
+    const framework::Tensor* prob = &x_2d;
+    const framework::Tensor* labels_pos_mat = &labels_pos_2d;
+    framework::Tensor* out = &y_2d;
+
+    const int step_size = prob->dims()[0];
+    const int class_num = prob->dims()[1];
+    const T* prob_data = prob->data<T>();
+    T* loss_data = out->data<T>();
+
+    const int64_t* label_pos_data = labels_pos_mat->data<int64_t>();
+    for (int i = 0; i < step_size; ++i) {
+      int lbl_pos = label_pos_data[i];
+      PADDLE_ENFORCE_GE(lbl_pos, 0);
+      PADDLE_ENFORCE_LT(lbl_pos, class_num);
+      int index_pos = i * class_num + lbl_pos;
+      T sum = static_cast<T>(0);
+      for (int j = 0; j < class_num; j++) {
+        if (j == lbl_pos) continue;
+        int index_neg = i * class_num + j;
+        sum += TolerableValue<T>()(-std::log(
+            1.0f + TolerableValue<T>()(std::exp(prob_data[index_neg] -
+                                                prob_data[index_pos]))));
+      }
+      loss_data[i] = -sum / (class_num - 1);
+    }
+  }
+};
+
+template <typename T>
+class XeGradFunctor {
+ public:
+  XeGradFunctor(T* dx,
+                const T* dy,               // NOLINT
+                const T* x,                // NOLINT
+                const int64_t* label_pos,  // NOLINT
+                size_t num_classes)
+      : dx_(dx),
+        dy_(dy),
+        x_(x),
+        label_pos_(label_pos),
+        num_classes_(num_classes) {}
+
+  HOSTDEVICE void operator()(size_t sample_id) {
+    for (size_t x_offset = sample_id * num_classes_;
+         x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
+      dx_[x_offset] = static_cast<T>(0);
+    }
+    auto p_index = sample_id * num_classes_ + label_pos_[sample_id];
+    for (size_t ni = 0; ni < num_classes_; ni++) {
+      if (label_pos_[sample_id] == ni) continue;
+      auto n_index = sample_id * num_classes_ + ni;
+      auto grad_ =
+          -dy_[sample_id] /
+          ((num_classes_ - 1) *
+           (1.0f + TolerableValue<T>()(std::exp(x_[p_index] - x_[n_index]))));
+      dx_[p_index] += grad_;
+      dx_[n_index] -= grad_;
+    }
+  }
+
+ private:
+  T* dx_;
+  const T* dy_;
+  const T* x_;
+  const int64_t* label_pos_;
+  size_t num_classes_;
+};
+
+template <typename DeviceContext, typename T>
+class BprLossGradientOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* label_pos = ctx.Input<Tensor>("Label_Pos");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+
+    int rank = x->dims().size();
+    int64_t class_num = x->dims()[rank - 1];
+    XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
+                             label_pos->data<int64_t>(),
+                             static_cast<size_t>(class_num));
+    platform::ForRange<DeviceContext> for_range(
+        ctx.template device_context<DeviceContext>(),
+        static_cast<size_t>(dy->numel()));
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
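The backward functor above first zeroes the gradient row for a sample and then, for every negative class, accumulates one pairwise term into both the positive and the negative positions. A rough NumPy equivalent of what XeGradFunctor computes (an illustrative sketch; the helper name and shapes are assumptions) is:

```python
import numpy as np

def bpr_loss_grad_reference(x, label_pos, dy):
    """Reference backward pass: returns d(loss)/d(x) for x of shape (batch, classes)."""
    batch, classes = x.shape
    dx = np.zeros_like(x)
    for i in range(batch):
        p = int(label_pos[i, 0])
        for j in range(classes):
            if j == p:
                continue
            # grad = -dy / ((classes - 1) * (1 + exp(x_p - x_j)))
            g = -dy[i, 0] / ((classes - 1) * (1.0 + np.exp(x[i, p] - x[i, j])))
            dx[i, p] += g   # the positive item is pushed up
            dx[i, j] -= g   # each negative item is pushed down
    return dx
```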
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 4df74edfceb..6d05ca8461b 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -41,6 +41,7 @@ __all__ = [
     'crf_decoding',
     'cos_sim',
     'cross_entropy',
+    'bpr_loss',
     'square_error_cost',
     'chunk_eval',
     'sequence_conv',
@@ -1175,6 +1176,18 @@ def cross_entropy(input, label, soft_label=False, ignore_index=-100):
     return out
 
 
+def bpr_loss(input, label_pos):
+
+    helper = LayerHelper('bpr_loss', **locals())
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='bpr_loss',
+        inputs={'X': [input],
+                'Label_Pos': [label_pos]},
+        outputs={'Y': [out]})
+    return out
+
+
 def square_error_cost(input, label):
     """
     **Square error cost layer**
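The new Python wrapper is added without a docstring, so here is a hypothetical usage sketch of the layer as defined above. Variable names and shapes are illustrative, not part of the patch:

```python
import paddle.fluid as fluid

# Scores for 5 candidate items per instance and the index of the positive item.
scores = fluid.layers.data(name='scores', shape=[5], dtype='float32')
label_pos = fluid.layers.data(name='label_pos', shape=[1], dtype='int64')

cost = fluid.layers.bpr_loss(input=scores, label_pos=label_pos)  # shape: (batch, 1)
avg_cost = fluid.layers.mean(cost)
```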
diff --git a/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
new file mode 100644
index 00000000000..7e18913a03b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest, randomize_probability
+
+
+class TestBprLossOp1(OpTest):
+    """Test BprLoss with discrete index labels (one positive item per instance).
+    """
+
+    def setUp(self):
+        self.op_type = "bpr_loss"
+        batch_size = 3
+        class_num = 5
+        X = randomize_probability(batch_size, class_num, dtype='float64')
+        label_pos = np.random.randint(
+            0, class_num, (batch_size, 1), dtype="int64")
+        bpr_loss_result = []
+        for i in range(batch_size):
+            sum = 0.0
+            for j in range(class_num):
+                if j == label_pos[i][0]:
+                    continue
+                sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label_pos[i][0]])))
+            bpr_loss_result.append(-sum / (class_num - 1))
+        bpr_loss = np.asmatrix([[x] for x in bpr_loss_result], dtype="float64")
+        self.inputs = {"X": X, "Label_Pos": label_pos}
+        self.outputs = {"Y": bpr_loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
+
+
+if __name__ == "__main__":
+    unittest.main()
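Beyond the OpTest above, the operator can also be sanity-checked end to end through the executor. The snippet below is an assumed workflow, not part of the test suite; it feeds random data and compares the fetched loss against the NumPy reference sketched earlier:

```python
import numpy as np
import paddle.fluid as fluid

scores = fluid.layers.data(name='scores', shape=[5], dtype='float32')
label_pos = fluid.layers.data(name='label_pos', shape=[1], dtype='int64')
cost = fluid.layers.bpr_loss(input=scores, label_pos=label_pos)

exe = fluid.Executor(fluid.CPUPlace())
x_np = np.random.rand(3, 5).astype('float32')
pos_np = np.random.randint(0, 5, (3, 1)).astype('int64')
out, = exe.run(fluid.default_main_program(),
               feed={'scores': x_np, 'label_pos': pos_np},
               fetch_list=[cost])
# 'out' should match bpr_loss_reference(x_np, pos_np) up to float32 precision.
```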
-- 
GitLab