diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 4bec1baeaaee94942be33a86ff2165dd98da5818..0900ed2ff2f5d46c9705885e0847c92249091afc 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -348,17 +348,30 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } else { // If is not in runtime, we will set default value(-1) for ScalarArray - int64_t num_ele = 1; + int64_t num_ele = 0; std::vector vars; vars.reserve(infershape_inputs.size()); for (size_t i = 0; i < infershape_inputs.size(); i++) { vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); } - for (auto& var : vars) { - const auto& tensor_dims = var->GetShape(); + + if (vars.size() == 1) { + num_ele = 1; + const auto& tensor_dims = vars[0]->GetShape(); for (size_t i = 0; i < tensor_dims.size(); ++i) { num_ele *= tensor_dims[i]; } + } else { + for (auto& var : vars) { + const auto& tensor_dims = var->GetShape(); + PADDLE_ENFORCE_EQ(tensor_dims.size(), 1, + platform::errors::InvalidArgument( + "The shape is constructed by multi-tensor, " + "every tensor's dims should be 1. But your " + "shape has tensor that dims is %s.", + tensor_dims.size())); + num_ele += tensor_dims[0]; + } } phi::ScalarArray tensor_attr(std::vector(num_ele, -1)); tensor_attr.SetFromTensor(true);