提交 dc202c25 编写于 作者: T tangwei12 提交者: phlrain

fix/positive negative pair op (#16895)

* fix infershape in runtime

* fix infershape in runtime
test=develop

* fix infershape in runtime
上级 b20586cf
...@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
auto query_dim = ctx->GetInputDim("QueryID"); auto query_dim = ctx->GetInputDim("QueryID");
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor."); PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor."); PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
label_dim[0], score_dim[0], if (ctx->IsRuntime() ||
"Tensor Score and Label should have the same height (batch size)."); (score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
PADDLE_ENFORCE_EQ(label_dim[1], 1, PADDLE_ENFORCE_EQ(
"The width of Label should be 1, i.e. each item should " label_dim[0], score_dim[0],
"have a scalar label."); "Tensor Score and Label should have the same height (batch size).");
PADDLE_ENFORCE(query_dim == label_dim,
"QueryID should have the same shape as Label."); PADDLE_ENFORCE_EQ(label_dim[1], 1,
if (ctx->HasInput("Weight")) { "The width of Label should be 1, i.e. each item should "
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim, "have a scalar label.");
"Weight should have the same shape as Label.");
PADDLE_ENFORCE(query_dim == label_dim,
"QueryID should have the same shape as Label.");
if (ctx->HasInput("Weight")) {
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
"Weight should have the same shape as Label.");
}
int column = ctx->Attrs().Get<int>("column");
auto depth = score_dim[1];
PADDLE_ENFORCE(column < depth && column >= -depth,
"Attribute column should be in the range of [-%l, %l)",
depth, depth);
} }
int column = ctx->Attrs().Get<int>("column");
auto depth = score_dim[1];
PADDLE_ENFORCE(column < depth && column >= -depth,
"Attribute column should be in the range of [-%l, %l)",
depth, depth);
ctx->SetOutputDim("PositivePair", scalar_dim); ctx->SetOutputDim("PositivePair", scalar_dim);
ctx->SetOutputDim("NegativePair", scalar_dim); ctx->SetOutputDim("NegativePair", scalar_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册