From dc202c2555ad50e5c8cea7ce59aade99d6fecc24 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 16 Apr 2019 22:09:34 +0800 Subject: [PATCH] fix/positive negative pair op (#16895) * fix infershape in runtime * fix infershape in runtime test=develop * fix infershape in runtime --- .../operators/positive_negative_pair_op.cc | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index 99256e408d4..e917e778e41 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { auto query_dim = ctx->GetInputDim("QueryID"); PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor."); PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor."); - PADDLE_ENFORCE_EQ( - label_dim[0], score_dim[0], - "Tensor Score and Label should have the same height (batch size)."); - PADDLE_ENFORCE_EQ(label_dim[1], 1, - "The width of Label should be 1, i.e. each item should " - "have a scalar label."); - PADDLE_ENFORCE(query_dim == label_dim, - "QueryID should have the same shape as Label."); - if (ctx->HasInput("Weight")) { - PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim, - "Weight should have the same shape as Label."); + + if (ctx->IsRuntime() || + (score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) { + PADDLE_ENFORCE_EQ( + label_dim[0], score_dim[0], + "Tensor Score and Label should have the same height (batch size)."); + + PADDLE_ENFORCE_EQ(label_dim[1], 1, + "The width of Label should be 1, i.e. each item should " + "have a scalar label."); + + PADDLE_ENFORCE(query_dim == label_dim, + "QueryID should have the same shape as Label."); + + if (ctx->HasInput("Weight")) { + PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim, + "Weight should have the same shape as Label."); + } + + int column = ctx->Attrs().Get("column"); + auto depth = score_dim[1]; + PADDLE_ENFORCE(column < depth && column >= -depth, + "Attribute column should be in the range of [-%l, %l)", + depth, depth); } - int column = ctx->Attrs().Get("column"); - auto depth = score_dim[1]; - PADDLE_ENFORCE(column < depth && column >= -depth, - "Attribute column should be in the range of [-%l, %l)", - depth, depth); ctx->SetOutputDim("PositivePair", scalar_dim); ctx->SetOutputDim("NegativePair", scalar_dim); -- GitLab