未验证 提交 fed6de40 编写于 作者: Y YuanRisheng 提交者: GitHub

[Bug Fixes]Fix Bugs when construct infermeta by using shape(Vector<Tensor>) (#39904)

* fix bugs

* fix bugs
上级 ef96ffb6
......@@ -348,17 +348,30 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
} else {
// If is not in runtime, we will set default value(-1) for ScalarArray
int64_t num_ele = 1;
int64_t num_ele = 0;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
if (vars.size() == 1) {
num_ele = 1;
const auto& tensor_dims = vars[0]->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
} else {
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
PADDLE_ENFORCE_EQ(tensor_dims.size(), 1,
platform::errors::InvalidArgument(
"The shape is constructed by multi-tensor, "
"every tensor's dims should be 1. But your "
"shape has tensor that dims is %s.",
tensor_dims.size()));
num_ele += tensor_dims[0];
}
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册