From 137e6336efbead4c4bc8ed627401eda7dd142089 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 26 Sep 2019 15:15:23 +0800 Subject: [PATCH] Remove constraint that last dimension is forced to be 1 in rank_loss (#19997) * fix input shape check test=develop * move PADDLE_ENFORCE test=develop --- paddle/fluid/operators/rank_loss_op.cc | 75 ++++++++++++++----- .../tests/unittests/test_rank_loss_op.py | 52 +++++++++++-- 2 files changed, 103 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/rank_loss_op.cc b/paddle/fluid/operators/rank_loss_op.cc index 45daa6b9556..2e53ecc040b 100644 --- a/paddle/fluid/operators/rank_loss_op.cc +++ b/paddle/fluid/operators/rank_loss_op.cc @@ -27,20 +27,55 @@ class RankLossOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { - // input check - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true, + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true, + "Input(Left) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true, + "Input(Right) shouldn't be null."); auto label_dims = ctx->GetInputDim("Label"); auto left_dims = ctx->GetInputDim("Left"); auto right_dims = ctx->GetInputDim("Right"); - - PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), - "All inputs must have the same size."); - PADDLE_ENFORCE( - (label_dims.size() == 2) && (label_dims[1] == 1), - "All inputs must be 2-D tensors with shape [batch_size x 1]."); + // check label_dims valid + PADDLE_ENFORCE_GE(label_dims.size(), 1, + "The dimension size of Input(Label) must be greater than " + "or equal to 1."); + PADDLE_ENFORCE_LE( + label_dims.size(), 2, + "The dimension size of Input(Label) must be less than or equal to 2."); + if (label_dims.size() == 2U) { + PADDLE_ENFORCE_EQ(label_dims[1], 1, + "The last dimension of Input(Label) must be 1."); + } + // check left_dims valid + PADDLE_ENFORCE_GE(left_dims.size(), 1, + "The dimension size of Input(Left) must be greater than " + "or equal to 1."); + PADDLE_ENFORCE_LE( + left_dims.size(), 2, + "The dimension size of Input(Left) must be less than or equal to 2."); + if (left_dims.size() == 2U) { + PADDLE_ENFORCE_EQ(left_dims[1], 1, + "The last dimension of Input(Left) must be 1."); + } + // check right_dims valid + PADDLE_ENFORCE_GE(right_dims.size(), 1, + "The dimension size of Input(Right) must be greater than " + "or equal to 1."); + PADDLE_ENFORCE_LE( + right_dims.size(), 2, + "The dimension size of Input(Right) must be less than or equal to 2."); + if (right_dims.size() == 2U) { + PADDLE_ENFORCE_EQ(right_dims[1], 1, + "The last dimension of Input(Right) must be 1."); + } + PADDLE_ENFORCE_EQ(label_dims[0], left_dims[0], + "The first dimension of Input(Label) and Input(Left) " + "must have the same value."); + PADDLE_ENFORCE_EQ(label_dims[0], right_dims[0], + "The first dimension of Input(Label) and Input(Right) " + "must have the same value."); ctx->SetOutputDim("Out", label_dims); } }; @@ -98,21 +133,25 @@ class RankLossGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto dims = ctx->GetInputDim("Left"); + PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true, + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Left"), true, + "Input(Left) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Right"), true, + "Input(Right) shouldn't be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + "Input(Out@GRAD) shouldn't be null."); + auto left_dims = ctx->GetInputDim("Left"); + auto right_dims = ctx->GetInputDim("Right"); auto left_grad_name = framework::GradVarName("Left"); auto right_grad_name = framework::GradVarName("Right"); if (ctx->HasOutput(left_grad_name)) { - ctx->SetOutputDim(left_grad_name, dims); + ctx->SetOutputDim(left_grad_name, left_dims); } if (ctx->HasOutput(right_grad_name)) { - ctx->SetOutputDim(right_grad_name, dims); + ctx->SetOutputDim(right_grad_name, right_dims); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_rank_loss_op.py b/python/paddle/fluid/tests/unittests/test_rank_loss_op.py index c9fa24b103d..733962dff2b 100644 --- a/python/paddle/fluid/tests/unittests/test_rank_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_rank_loss_op.py @@ -22,14 +22,24 @@ from op_test import OpTest class TestRankLossOp(OpTest): def setUp(self): self.op_type = "rank_loss" - batch_size = 5 + shape = (5, 1) # labels_{i} = {0, 1.0} or {0, 0.5, 1.0} - label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32") - left = np.random.random((batch_size, 1)).astype("float32") - right = np.random.random((batch_size, 1)).astype("float32") + label_shape, left_shape, right_shape = self.set_shape() + label = np.random.randint(0, 2, size=shape).astype("float32") + left = np.random.random(shape).astype("float32") + right = np.random.random(shape).astype("float32") loss = np.log(1.0 + np.exp(left - right)) - label * (left - right) - self.inputs = {'Label': label, 'Left': left, 'Right': right} - self.outputs = {'Out': loss} + loss = np.reshape(loss, label_shape) + self.inputs = { + 'Label': label.reshape(label_shape), + 'Left': left.reshape(left_shape), + 'Right': right.reshape(right_shape) + } + self.outputs = {'Out': loss.reshape(label_shape)} + + def set_shape(self): + batch_size = 5 + return (batch_size, 1), (batch_size, 1), (batch_size, 1) def test_check_output(self): self.check_output() @@ -44,5 +54,35 @@ class TestRankLossOp(OpTest): self.check_grad(["Left"], "Out", no_grad_set=set('Right')) +class TestRankLossOp1(TestRankLossOp): + def set_shape(self): + batch_size = 5 + return (batch_size), (batch_size, 1), (batch_size, 1) + + +class TestRankLossOp2(TestRankLossOp): + def set_shape(self): + batch_size = 5 + return (batch_size, 1), (batch_size), (batch_size, 1) + + +class TestRankLossOp3(TestRankLossOp): + def set_shape(self): + batch_size = 5 + return (batch_size, 1), (batch_size, 1), (batch_size) + + +class TestRankLossOp4(TestRankLossOp): + def set_shape(self): + batch_size = 5 + return (batch_size), (batch_size), (batch_size, 1) + + +class TestRankLossOp5(TestRankLossOp): + def set_shape(self): + batch_size = 5 + return (batch_size), (batch_size), (batch_size) + + if __name__ == '__main__': unittest.main() -- GitLab