From af3c5d921d430518345bb704396da88f63eff2de Mon Sep 17 00:00:00 2001 From: yongqiangma Date: Thu, 9 Apr 2020 13:30:25 +0000 Subject: [PATCH] resize assign_value output tensor. test=develop --- lite/core/mir/elimination/assign_value_eliminate_pass.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lite/core/mir/elimination/assign_value_eliminate_pass.cc b/lite/core/mir/elimination/assign_value_eliminate_pass.cc index 0d2c220787..e011e2242a 100644 --- a/lite/core/mir/elimination/assign_value_eliminate_pass.cc +++ b/lite/core/mir/elimination/assign_value_eliminate_pass.cc @@ -44,12 +44,15 @@ class Eliminator : public FuseBase { auto* scope = assign_node->stmt()->op()->scope(); auto* op_info = assign_node->stmt()->op()->op_info(); auto shape = op_info->GetAttr>("shape"); + std::vector out_shape; + for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]); auto dtype = op_info->GetAttr("dtype"); auto fp32_values = op_info->GetAttr>("fp32_values"); auto int32_values = op_info->GetAttr>("int32_values"); auto* out = matched.at("out"); - auto* out_tensor = scope->FindVar(out->arg()->name) - ->GetMutable(); + auto* out_tensor = + scope->FindVar(out->arg()->name)->GetMutable(); + out_tensor->Resize(out_shape); if (dtype == static_cast(lite::core::FluidType::INT32)) { TensorFromVector(int32_values, out_tensor); } else if (dtype == static_cast(lite::core::FluidType::FP32)) { @@ -58,7 +61,6 @@ class Eliminator : public FuseBase { LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype; } GraphSafeRemoveNodes(graph, {matched.at("assign_value")}); - } }; -- GitLab