提交 ece74510 编写于 作者: Z zhaoyuchen2018 提交者: phlrain

Merge pull request #16857 from zhaoyuchen2018/sumreshape

Fix sum infershape issue
上级 a9539cbf
......@@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel {
if (framework::product(in_dim) == 0) {
in_dim = x_dim;
} else {
PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(in_dim, x_dim,
"Input tensors must have same shape");
} else {
PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
"Input tensors must have same shape size");
// if in_dim or x_dim has -1, not check equal
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] == -1 || in_dim[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
"Input tensors must have same shape if not -1");
}
}
}
}
ctx->SetOutputDim("Out", in_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册