From fed6de40475ea796a195c9471f86b193ec62c11c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 25 Feb 2022 10:31:54 +0800 Subject: [PATCH] [Bug Fixes]Fix Bugs when construct infermeta by using shape(Vector) (#39904) * fix bugs * fix bugs --- paddle/fluid/framework/infershape_utils.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 4bec1baeaae..0900ed2ff2f 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -348,17 +348,30 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } else { // If is not in runtime, we will set default value(-1) for ScalarArray - int64_t num_ele = 1; + int64_t num_ele = 0; std::vector vars; vars.reserve(infershape_inputs.size()); for (size_t i = 0; i < infershape_inputs.size(); i++) { vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); } - for (auto& var : vars) { - const auto& tensor_dims = var->GetShape(); + + if (vars.size() == 1) { + num_ele = 1; + const auto& tensor_dims = vars[0]->GetShape(); for (size_t i = 0; i < tensor_dims.size(); ++i) { num_ele *= tensor_dims[i]; } + } else { + for (auto& var : vars) { + const auto& tensor_dims = var->GetShape(); + PADDLE_ENFORCE_EQ(tensor_dims.size(), 1, + platform::errors::InvalidArgument( + "The shape is constructed by multi-tensor, " + "every tensor's dims should be 1. But your " + "shape has tensor that dims is %s.", + tensor_dims.size())); + num_ele += tensor_dims[0]; + } } phi::ScalarArray tensor_attr(std::vector(num_ele, -1)); tensor_attr.SetFromTensor(true); -- GitLab