diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc
index 66571bd9a6378463d011d17f28532620e787588b..fd3ac86939fcb47f79d538c4047719473cc0cf70 100644
--- a/paddle/operators/rank_loss_op.cc
+++ b/paddle/operators/rank_loss_op.cc
@@ -28,18 +28,21 @@ class RankLossOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     // input check
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null");
-    auto p_dims = ctx.Input<framework::Tensor>("P")->dims();
-    auto oi_dims = ctx.Input<framework::Tensor>("Oi")->dims();
-    auto oj_dims = ctx.Input<framework::Tensor>("Oj")->dims();
-    PADDLE_ENFORCE_EQ(oi_dims, oj_dims,
-                      "Input(Oi) and Input(Oj) must have the same size");
-    PADDLE_ENFORCE_EQ(
-        p_dims, oi_dims,
-        "Input(P) must have the same size with Input(Oi) & Input(Oj)");
-    ctx.Output<framework::Tensor>("Out")->Resize(p_dims);
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) shouldn't be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
+                            "Input(Left) shouldn't be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
+                            "Input(Right) shouldn't be null");
+    auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
+    auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
+    auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
+    PADDLE_ENFORCE((label_dims.size() == 1) && (left_dims.size() == 1) &&
+                       (right_dims.size() == 1),
+                   "The rank of all inputs must be 1.");
+    PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
+                   "All inputs must have the same size.");
+    ctx.Output<framework::Tensor>("Out")->Resize(label_dims);
   }
 };
 
@@ -48,14 +51,23 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
   RankLossOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("P", "The desired target values for posteriors.");
-    AddInput("Oi", "The model output for item i.");
-    AddInput("Oj", "The model output for item j.");
-    AddOutput("Out", "The output tensor of RankLoss operator.");
+    AddInput("Label",
+             "The label indicating whether A ranks higher than B, 1-D tensor.");
+    AddInput("Left", "The output of RankNet for doc A, 1-D tensor.");
+    AddInput("Right", "The output of RankNet for doc B, 1-D tensor.");
+    AddOutput("Out", "The output loss of RankLoss operator, 1-D tensor.");
     AddComment(R"DOC(RankLoss operator
 
-A rank loss operator for learning to rank (LTR) task. This operator contains
-three inputs: P, Oi, and Oj, and the rank cost can be expressed as
+Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with
+one training sample consisting of a pair of docs A and B, and a label P
+indicating whether A is ranked higher than B:
+
+P = {0, 1} or {0, 0.5, 1}, where 0.5 means there is no information about
+the relative rank of the input pair.
+
+The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
+(P_{i,j}), which represent the RankNet outputs for the two docs and the label,
+respectively, and yields the rank loss C_{i,j} by the following expression
 
 \f[
 C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
@@ -63,10 +75,11 @@ three inputs: P, Oi, and Oj, and the rank cost can be expressed as
 \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
 \f]
 
-A detailed explanation about these notations can be found in
+The operator can take either a single sample or a batch as input.
 
 [1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to
-     Rank useing Gradient Descent.
+     Rank using Gradient Descent.
+     http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf
 )DOC");
   }
 };
@@ -81,15 +94,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) shouldn't be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
+                            "Input(Left) shouldn't be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
+                            "Input(Right) shouldn't be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) shouldn't be null.");
-    auto dims = ctx.Input<framework::Tensor>("P")->dims();
-    ctx.Output<framework::Tensor>(framework::GradVarName("P"))->Resize(dims);
-    ctx.Output<framework::Tensor>(framework::GradVarName("Oi"))->Resize(dims);
-    ctx.Output<framework::Tensor>(framework::GradVarName("Oj"))->Resize(dims);
+    auto dims = ctx.Input<framework::Tensor>("Left")->dims();
+    auto *left_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Left"));
+    auto *right_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Right"));
+    if (left_grad) {
+      left_grad->Resize(dims);
+    }
+    if (right_grad) {
+      right_grad->Resize(dims);
+    }
   }
 };
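
For reference while reading the kernels below: the cost in the DOC block is a logistic loss on the score difference o_{i,j} = o_i - o_j. A minimal NumPy sketch of the forward computation (illustrative only; the helper name and sample values are not part of the patch):

    import numpy as np

    def rank_loss(label, left, right):
        # C = log(1 + exp(o)) - P * o, with o = o_i - o_j
        o = left - right
        return np.log(1.0 + np.exp(o)) - label * o

    label = np.array([1.0, 0.0, 0.5])  # P in {0, 0.5, 1}
    left = np.array([0.8, 0.2, 0.5])   # o_i, RankNet output for doc A
    right = np.array([0.3, 0.6, 0.5])  # o_j, RankNet output for doc B
    print(rank_loss(label, left, right))  # approx. [0.474, 0.513, 0.693]

Note that a label of 0.5 with equal scores gives log(2) ~= 0.693, the maximum-uncertainty case.
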
diff --git a/paddle/operators/rank_loss_op.h b/paddle/operators/rank_loss_op.h
index d21871107a1c68b117eff5457858fbd145749366..9776d123fe4b0cb0cd16a15770fcf42a966fa011 100644
--- a/paddle/operators/rank_loss_op.h
+++ b/paddle/operators/rank_loss_op.h
@@ -24,25 +24,20 @@ template <typename Place, typename T>
 class RankLossKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    auto* p_t = ctx.Input<framework::Tensor>("P");
-    auto* oi_t = ctx.Input<framework::Tensor>("Oi");
-    auto* oj_t = ctx.Input<framework::Tensor>("Oj");
-    out->mutable_data<T>(ctx.GetPlace());
+    auto* out_t = ctx.Output<framework::Tensor>("Out");
+    auto* label_t = ctx.Input<framework::Tensor>("Label");
+    auto* left_t = ctx.Input<framework::Tensor>("Left");
+    auto* right_t = ctx.Input<framework::Tensor>("Right");
+    out_t->mutable_data<T>(ctx.GetPlace());
 
-    auto& dev = ctx.GetEigenDevice<Place>();
-    auto out_eig = framework::EigenVector<T>::Flatten(*out);
-    auto p_eig = framework::EigenVector<T>::Flatten(*p_t);
-    auto oi_eig = framework::EigenVector<T>::Flatten(*oi_t);
-    auto oj_eig = framework::EigenVector<T>::Flatten(*oj_t);
-
-    framework::Tensor o_t;
-    o_t.Resize(oi_t->dims());
-    o_t.mutable_data<T>(ctx.GetPlace());
-    auto o_eig = framework::EigenVector<T>::Flatten(o_t);
-    o_eig.device(dev) = oi_eig - oj_eig;
+    auto out = framework::EigenVector<T>::Flatten(*out_t);
+    auto label = framework::EigenVector<T>::Flatten(*label_t);
+    auto left = framework::EigenVector<T>::Flatten(*left_t);
+    auto right = framework::EigenVector<T>::Flatten(*right_t);
 
-    out_eig.device(dev) = (1. + (o_eig).exp()).log() - p_eig * o_eig;
+    auto& dev = ctx.GetEigenDevice<Place>();
+    out.device(dev) =
+        (1. + (left - right).exp()).log() - label * (left - right);
   }
 };
 
@@ -50,40 +45,35 @@ template <typename Place, typename T>
 class RankLossGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_oi = ctx.Output<framework::Tensor>(framework::GradVarName("Oi"));
-    auto* d_oj = ctx.Output<framework::Tensor>(framework::GradVarName("Oj"));
-    auto* d_p = ctx.Output<framework::Tensor>(framework::GradVarName("P"));
-
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* p_t = ctx.Input<framework::Tensor>("P");
-    auto* oi_t = ctx.Input<framework::Tensor>("Oi");
-    auto* oj_t = ctx.Input<framework::Tensor>("Oj");
+    auto* d_left_t =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Left"));
+    auto* d_right_t =
+        ctx.Output<framework::Tensor>(framework::GradVarName("Right"));
 
-    d_oi->mutable_data<T>(ctx.GetPlace());
-    d_oj->mutable_data<T>(ctx.GetPlace());
-    d_p->mutable_data<T>(ctx.GetPlace());
+    auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* label_t = ctx.Input<framework::Tensor>("Label");
+    auto* left_t = ctx.Input<framework::Tensor>("Left");
+    auto* right_t = ctx.Input<framework::Tensor>("Right");
 
     auto& dev = ctx.GetEigenDevice<Place>();
-    auto d_out_eig = framework::EigenVector<T>::Flatten(*d_out);
-    auto p_eig = framework::EigenVector<T>::Flatten(*p_t);
-    auto oi_eig = framework::EigenVector<T>::Flatten(*oi_t);
-    auto oj_eig = framework::EigenVector<T>::Flatten(*oj_t);
-
-    auto d_oi_eig = framework::EigenVector<T>::Flatten(*d_oi);
-    auto d_oj_eig = framework::EigenVector<T>::Flatten(*d_oj);
-
-    framework::Tensor o_t;
-    o_t.Resize(oi_t->dims());
-    o_t.mutable_data<T>(ctx.GetPlace());
-    auto o_eig = framework::EigenVector<T>::Flatten(o_t);
-    o_eig.device(dev) = oi_eig - oj_eig;
-
-    // dOi & dOj
-    d_oi_eig.device(dev) =
-        d_out_eig * (o_eig.exp() / (1. + o_eig.exp()) - p_eig);
-    d_oj_eig.device(dev) = -d_oi_eig;
-    // dP
-    framework::EigenVector<T>::Flatten(*d_p).device(dev) = -o_eig;
+    auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
+    auto label = framework::EigenVector<T>::Flatten(*label_t);
+    auto left = framework::EigenVector<T>::Flatten(*left_t);
+    auto right = framework::EigenVector<T>::Flatten(*right_t);
+
+    // compute d_left
+    if (d_left_t) {
+      d_left_t->mutable_data<T>(ctx.GetPlace());
+      auto d_left = framework::EigenVector<T>::Flatten(*d_left_t);
+      d_left.device(dev) = d_out * (1. / (1. + (right - left).exp()) - label);
+    }
+    // compute d_right
+    if (d_right_t) {
+      d_right_t->mutable_data<T>(ctx.GetPlace());
+      auto d_right = framework::EigenVector<T>::Flatten(*d_right_t);
+      d_right.device(dev) =
+          -d_out * (1.0 / (1. + (right - left).exp()) - label);
+    }
   }
 };
 }  // namespace operators
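
The two branches in RankLossGradKernel implement the analytic derivative of the cost: with o = o_i - o_j, dC/do_i = e^o / (1 + e^o) - P = 1 / (1 + e^{o_j - o_i}) - P, and dC/do_j = -dC/do_i, which is why d_right is computed as the negation of the d_left expression. A NumPy sketch of the same math (illustrative only; not part of the patch):

    import numpy as np

    def rank_loss_grad(label, left, right, d_out):
        # dC/do_i = 1 / (1 + exp(o_j - o_i)) - P, scaled by the upstream grad
        d_left = d_out * (1.0 / (1.0 + np.exp(right - left)) - label)
        # dC/do_j is the exact negation of dC/do_i
        return d_left, -d_left
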
diff --git a/python/paddle/v2/framework/tests/test_rank_loss_op.py b/python/paddle/v2/framework/tests/test_rank_loss_op.py
index 48354b7f7bdb4972b3124fd358577d1c4db9cde3..c4d74e1c0402a80fce6a98f2261612d09c06a9cd 100644
--- a/python/paddle/v2/framework/tests/test_rank_loss_op.py
+++ b/python/paddle/v2/framework/tests/test_rank_loss_op.py
@@ -3,24 +3,29 @@ import numpy as np
 from op_test import OpTest
 
 
-class TestReshapeOp(OpTest):
+class TestRankLossOp(OpTest):
     def setUp(self):
         self.op_type = "rank_loss"
-        num = 5
-        # P = {0, 1.0} or {0, 0.5, 1.0}
-        P = np.random.randint(0, 2, size=(num, num)).astype("float32")
-        Oi = np.random.random((num, num)).astype("float32")
-        Oj = np.random.random((num, num)).astype("float32")
-        O = Oi - Oj
-        Out = np.log(1.0 + np.exp(O)) - P * O
-        self.inputs = {'P': P, 'Oi': Oi, 'Oj': Oj}
-        self.outputs = {'Out': Out}
+        batch_size = 5
+        # labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
+        label = np.random.randint(0, 2, size=(batch_size, )).astype("float32")
+        left = np.random.random((batch_size, )).astype("float32")
+        right = np.random.random((batch_size, )).astype("float32")
+        loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
+        self.inputs = {'Label': label, 'Left': left, 'Right': right}
+        self.outputs = {'Out': loss}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["Oj"], "Out")
+        self.check_grad(["Left", "Right"], "Out")
+
+    def test_check_grad_ignore_left(self):
+        self.check_grad(["Right"], "Out", no_grad_set=set('Left'))
+
+    def test_check_grad_ignore_right(self):
+        self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
 
 
 if __name__ == '__main__':
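
The check_grad calls compare the operator's analytic gradient against a numeric one. The same cross-check can be reproduced by hand with a central finite difference (an illustrative sketch; eps and the sample values are arbitrary):

    import numpy as np

    def rank_loss(label, left, right):
        o = left - right
        return np.log(1.0 + np.exp(o)) - label * o

    label = np.array([1.0, 0.0, 0.5])
    left = np.array([0.8, 0.2, 0.5])
    right = np.array([0.3, 0.6, 0.5])
    eps = 1e-5

    # analytic gradient w.r.t. left (upstream gradient taken as ones)
    d_left = 1.0 / (1.0 + np.exp(right - left)) - label

    # central finite difference w.r.t. left; should agree elementwise
    fd = (rank_loss(label, left + eps, right) -
          rank_loss(label, left - eps, right)) / (2.0 * eps)
    assert np.allclose(d_left, fd, atol=1e-6)
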