提交 f0985cec 编写于 作者: S superjomn

fix logical op infershape

test=develop
上级 1c526e1d
...@@ -71,8 +71,14 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase { ...@@ -71,8 +71,14 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase {
"Input(Y) of %s operator must not be null", comment.type); "Input(Y) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X"); auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y"); auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same"); int product_x = framework::product(dim_x);
int product_y = framework::product(dim_y);
bool check = ctx->IsRuntime() && product_x >= 0 && product_y >= 0;
if (check) {
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
"The number of elements in X and Y should be same");
}
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册