未验证 提交 d966faae 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16852 from sneaxiy/fix_merge_lod_tensor_op_infer_shape

Fix merge_lod_tensor_op infer shape
...@@ -164,7 +164,9 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase { ...@@ -164,7 +164,9 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase {
auto mask_dim = context->GetInputDim("Mask"); auto mask_dim = context->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(mask_dim.size(), 2); PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
if (context->IsRuntime() || mask_dim[1] > 0) {
PADDLE_ENFORCE_EQ(mask_dim[1], 1); PADDLE_ENFORCE_EQ(mask_dim[1], 1);
}
context->SetOutputDim("Out", context->GetInputDim("InTrue")); context->SetOutputDim("Out", context->GetInputDim("InTrue"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册