diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index fd3ac86939fcb47f79d538c4047719473cc0cf70..d98fd54f22841453582f989776bdc0c7cdc1f0fa 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -1,4 +1,3 @@ - /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); @@ -37,11 +36,10 @@ class RankLossOp : public framework::OperatorWithKernel { auto label_dims = ctx.Input("Label")->dims(); auto left_dims = ctx.Input("Left")->dims(); auto right_dims = ctx.Input("Right")->dims(); - PADDLE_ENFORCE((label_dims.size() == 1) && (left_dims.size() == 1) && - (right_dims.size() == 1), - "The rank of all inputs must be 1."); PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), "All inputs must have the same size"); + PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), + "All inputs must be row vector with size batch_sizex1."); ctx.Output("Out")->Resize(label_dims); } }; @@ -52,10 +50,10 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Label", - "The label indicating A ranked higher than B or not, 1-D tensor."); - AddInput("Left", "The output of RankNet for doc A, 1-D tensor."); - AddInput("Right", "The output of RankNet for doc B, 1-D tensor"); - AddOutput("Out", "The output loss of RankLoss operator, 1-D tensor."); + "The label indicating A ranked higher than B or not, row vector."); + AddInput("Left", "The output of RankNet for doc A, vector."); + AddInput("Right", "The output of RankNet for doc B, vetor"); + AddOutput("Out", "The output loss of RankLoss operator, vector."); AddComment(R"DOC(RankLoss operator Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with diff --git a/python/paddle/v2/framework/tests/test_rank_loss_op.py b/python/paddle/v2/framework/tests/test_rank_loss_op.py index c4d74e1c0402a80fce6a98f2261612d09c06a9cd..0e41ab1b3fd8fa8b62c5f3b914b752918119a265 100644 --- a/python/paddle/v2/framework/tests/test_rank_loss_op.py +++ b/python/paddle/v2/framework/tests/test_rank_loss_op.py @@ -8,9 +8,9 @@ class TestRankLossOp(OpTest): self.op_type = "rank_loss" batch_size = 5 # labels_{i} = {0, 1.0} or {0, 0.5, 1.0} - label = np.random.randint(0, 2, size=(batch_size, )).astype("float32") - left = np.random.random((batch_size, )).astype("float32") - right = np.random.random((batch_size, )).astype("float32") + label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32") + left = np.random.random((batch_size, 1)).astype("float32") + right = np.random.random((batch_size, 1)).astype("float32") loss = np.log(1.0 + np.exp(left - right)) - label * (left - right) self.inputs = {'Label': label, 'Left': left, 'Right': right} self.outputs = {'Out': loss}