diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 1391148ccf5d13082cb31ef2e143249e8ef95bfc..67f7510e874d4b3dcb857510e42cbfa7081becfe 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel { if (framework::product(in_dim) == 0) { in_dim = x_dim; } else { - PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(in_dim, x_dim, + "Input tensors must have same shape"); + } else { + PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(), + "Input tensors must have same shape size"); + // if in_dim or x_dim has -1, not check equal + for (int i = 0; i < x_dim.size(); ++i) { + if (x_dim[i] == -1 || in_dim[i] == -1) { + continue; + } + PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i], + "Input tensors must have same shape if not -1"); + } + } } } ctx->SetOutputDim("Out", in_dim);