提交 87be315c 编写于 作者: H Hongyu Liu 提交者: phlrain

Merge pull request #16897 from velconia/fix_split_lod_tensor_op_infer_shape

Fix infer shape of split lod tensor op
上级 2055c16d
...@@ -157,7 +157,9 @@ class SplitLoDTensorInferShape : public framework::InferShapeBase { ...@@ -157,7 +157,9 @@ class SplitLoDTensorInferShape : public framework::InferShapeBase {
auto mask_dim = context->GetInputDim("Mask"); auto mask_dim = context->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(mask_dim.size(), 2); PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
PADDLE_ENFORCE_EQ(mask_dim[1], 1); if (context->IsRuntime()) {
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
}
context->SetOutputDim("OutTrue", context->GetInputDim("X")); context->SetOutputDim("OutTrue", context->GetInputDim("X"));
context->SetOutputDim("OutFalse", context->GetInputDim("X")); context->SetOutputDim("OutFalse", context->GetInputDim("X"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册