From 96500af64b07913b8cd3be09dceb8fe02db86168 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 14 Sep 2017 11:12:37 +0800 Subject: [PATCH] add rank_loss operator --- paddle/operators/rank_loss_op.cc | 103 +++++++++++++++++++++++++++++++ paddle/operators/rank_loss_op.cu | 22 +++++++ paddle/operators/rank_loss_op.h | 90 +++++++++++++++++++++++++++ paddle/pybind/pybind.cc | 1 + 4 files changed, 216 insertions(+) create mode 100644 paddle/operators/rank_loss_op.cc create mode 100644 paddle/operators/rank_loss_op.cu create mode 100644 paddle/operators/rank_loss_op.h diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc new file mode 100644 index 0000000000..14cddb609f --- /dev/null +++ b/paddle/operators/rank_loss_op.cc @@ -0,0 +1,103 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/rank_loss_op.h" + +namespace paddle { +namespace operators { + +class RankLossOp : public framework::OperatorWithKernel { + public: + RankLossOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null"); + auto p_dims = ctx.Input("P")->dims(); + auto oi_dims = ctx.Input("Oi")->dims(); + auto oj_dims = ctx.Input("Oj")->dims(); + PADDLE_ENFORCE_EQ(oi_dims, oj_dims, + "Input(Oi) and Input(Oj) must have the same size"); + PADDLE_ENFORCE_EQ( + p_dims, oi_dims, + "Input(P) must have the same size with Input(Oi) & Input(Oj)"); + ctx.Output("Out")->Resize(p_dims); + } +}; + +class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RankLossOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("P", "The first input of RankLoss operator."); + AddInput("Oi", "The second input of RankLoss operator."); + AddInput("Oj", "The third input of RankLoss operator."); + AddOutput("Out", "The output tensor of RankLoss operator."); + AddComment(R"DOC(RankLoss operator + +A rank loss operator for learning to rank (LTR) task. This operator contains +three inputs: P, Oi, and Oj, and the rank cost can be expressed as + +\f[ + C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ + o_{i,j} = o_i - o_j \\ + \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} +\f] + +[1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to + Rank useing Gradient Descent. +)DOC"); + } +}; + +class RankLossGradOp : public framework::OperatorWithKernel { + public: + RankLossGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto dims = ctx.Input("P")->dims(); + ctx.Output(framework::GradVarName("P"))->Resize(dims); + ctx.Output(framework::GradVarName("Oi"))->Resize(dims); + ctx.Output(framework::GradVarName("Oj"))->Resize(dims); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(rank_loss, ops::RankLossOp, ops::RankLossOpMaker, rank_loss_grad, + ops::RankLossGradOp); +REGISTER_OP_CPU_KERNEL(rank_loss, + ops::RankLossKernel); +REGISTER_OP_CPU_KERNEL( + rank_loss_grad, ops::RankLossGradKernel); diff --git a/paddle/operators/rank_loss_op.cu b/paddle/operators/rank_loss_op.cu new file mode 100644 index 0000000000..779588ff36 --- /dev/null +++ b/paddle/operators/rank_loss_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/rank_loss_op.h" + +REGISTER_OP_GPU_KERNEL( + rank_loss, + paddle::operators::RankLossKernel); +REGISTER_OP_GPU_KERNEL( + rank_loss_grad, + paddle::operators::RankLossGradKernel); diff --git a/paddle/operators/rank_loss_op.h b/paddle/operators/rank_loss_op.h new file mode 100644 index 0000000000..d21871107a --- /dev/null +++ b/paddle/operators/rank_loss_op.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class RankLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out = ctx.Output("Out"); + auto* p_t = ctx.Input("P"); + auto* oi_t = ctx.Input("Oi"); + auto* oj_t = ctx.Input("Oj"); + out->mutable_data(ctx.GetPlace()); + + auto& dev = ctx.GetEigenDevice(); + auto out_eig = framework::EigenVector::Flatten(*out); + auto p_eig = framework::EigenVector::Flatten(*p_t); + auto oi_eig = framework::EigenVector::Flatten(*oi_t); + auto oj_eig = framework::EigenVector::Flatten(*oj_t); + + framework::Tensor o_t; + o_t.Resize(oi_t->dims()); + o_t.mutable_data(ctx.GetPlace()); + auto o_eig = framework::EigenVector::Flatten(o_t); + o_eig.device(dev) = oi_eig - oj_eig; + + out_eig.device(dev) = (1. + (o_eig).exp()).log() - p_eig * o_eig; + } +}; + +template +class RankLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_oi = ctx.Output(framework::GradVarName("Oi")); + auto* d_oj = ctx.Output(framework::GradVarName("Oj")); + auto* d_p = ctx.Output(framework::GradVarName("P")); + + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* p_t = ctx.Input("P"); + auto* oi_t = ctx.Input("Oi"); + auto* oj_t = ctx.Input("Oj"); + + d_oi->mutable_data(ctx.GetPlace()); + d_oj->mutable_data(ctx.GetPlace()); + d_p->mutable_data(ctx.GetPlace()); + + auto& dev = ctx.GetEigenDevice(); + auto d_out_eig = framework::EigenVector::Flatten(*d_out); + auto p_eig = framework::EigenVector::Flatten(*p_t); + auto oi_eig = framework::EigenVector::Flatten(*oi_t); + auto oj_eig = framework::EigenVector::Flatten(*oj_t); + + auto d_oi_eig = framework::EigenVector::Flatten(*d_oi); + auto d_oj_eig = framework::EigenVector::Flatten(*d_oj); + + framework::Tensor o_t; + o_t.Resize(oi_t->dims()); + o_t.mutable_data(ctx.GetPlace()); + auto o_eig = framework::EigenVector::Flatten(o_t); + o_eig.device(dev) = oi_eig - oj_eig; + + // dOi & dOj + d_oi_eig.device(dev) = + d_out_eig * (o_eig.exp() / (1. + o_eig.exp()) - p_eig); + d_oj_eig.device(dev) = -d_oi_eig; + // dP + framework::EigenVector::Flatten(*d_p).device(dev) = -o_eig; + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index ef62d6e997..1805a830b3 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,6 +56,7 @@ USE_OP(top_k); USE_OP(squared_l2_distance); USE_OP(sum); USE_OP(reshape); +USE_OP(rank_loss); namespace paddle { namespace framework { -- GitLab