diff --git a/lite/core/mir/elimination/assign_value_eliminate_pass.cc b/lite/core/mir/elimination/assign_value_eliminate_pass.cc index 0d2c2207875830c2347d8bf337931b1eb6aa6462..e011e2242abd3a263c248fba9294535249b6ef4e 100644 --- a/lite/core/mir/elimination/assign_value_eliminate_pass.cc +++ b/lite/core/mir/elimination/assign_value_eliminate_pass.cc @@ -44,12 +44,15 @@ class Eliminator : public FuseBase { auto* scope = assign_node->stmt()->op()->scope(); auto* op_info = assign_node->stmt()->op()->op_info(); auto shape = op_info->GetAttr>("shape"); + std::vector out_shape; + for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); auto dtype = op_info->GetAttr("dtype"); auto fp32_values = op_info->GetAttr>("fp32_values"); auto int32_values = op_info->GetAttr>("int32_values"); auto* out = matched.at("out"); - auto* out_tensor = scope->FindVar(out->arg()->name) - ->GetMutable(); + auto* out_tensor = + scope->FindVar(out->arg()->name)->GetMutable(); + out_tensor->Resize(out_shape); if (dtype == static_cast(lite::core::FluidType::INT32)) { TensorFromVector(int32_values, out_tensor); } else if (dtype == static_cast(lite::core::FluidType::FP32)) { @@ -58,7 +61,6 @@ class Eliminator : public FuseBase { LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype; } GraphSafeRemoveNodes(graph, {matched.at("assign_value")}); - } };