From ece745104de2797e71eee0ebf0f203f158b3a9bc Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Tue, 16 Apr 2019 09:55:20 +0800 Subject: [PATCH] Merge pull request #16857 from zhaoyuchen2018/sumreshape Fix sum infershape issue --- paddle/fluid/operators/sum_op.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 1391148cc..67f7510e8 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel { if (framework::product(in_dim) == 0) { in_dim = x_dim; } else { - PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(in_dim, x_dim, + "Input tensors must have same shape"); + } else { + PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(), + "Input tensors must have same shape size"); + // if in_dim or x_dim has -1, not check equal + for (int i = 0; i < x_dim.size(); ++i) { + if (x_dim[i] == -1 || in_dim[i] == -1) { + continue; + } + PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i], + "Input tensors must have same shape if not -1"); + } + } } } ctx->SetOutputDim("Out", in_dim); -- GitLab