diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index da7fa1b81d601f4dd03d6716de601a4b1abc7fa0..5edc233f6f73262c3d1b803aae0089f5b15d403d 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -164,7 +164,9 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase { auto mask_dim = context->GetInputDim("Mask"); PADDLE_ENFORCE_EQ(mask_dim.size(), 2); - PADDLE_ENFORCE_EQ(mask_dim[1], 1); + if (context->IsRuntime() || mask_dim[1] > 0) { + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + } context->SetOutputDim("Out", context->GetInputDim("InTrue")); }