diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 58eb463e9228797cfaf1d5b5c6e4e38e3716f335..772c5d40bffb5fec7eeb565fc0a46bf35b7ee59c 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -915,6 +915,13 @@ void NetworkImplDft::update_output() { void NetworkImplDft::output_tensor_copy_optimize( Var var, std::shared_ptr tensor) { + size_t index; + for (index = 0; index < m_load_result.output_var_list.size(); ++index) { + if (m_load_result.output_var_list[index].node() == var.node()) { + break; + } + } + LITE_ASSERT(index != m_load_result.output_var_list.size()); LITE_ASSERT( !(m_user_config->options.force_output_use_user_specified_memory && m_user_config->options.force_output_dynamic_alloc), @@ -924,23 +931,26 @@ void NetworkImplDft::output_tensor_copy_optimize( bool in_record = m_user_config->options.comp_node_seq_record_level > 0; TensorHelper::implement(tensor) ->cast_final_safe() - .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) { - dft_tensor->device_share_host_memory(); - auto dv = dft_tensor->dev_tensor().get(); - dv->comp_node(var.node()->comp_node(), true); - var.node()->init_mem_plan(dv); - if (in_record) { - auto&& device_tensor = var.node()->mutable_dev_tensor(); - device_tensor.only_reset_raw_storage(dv->storage()); - } else { - var.node()->reset_dev_tensor_from_tensor(*dv); - } - }); + .set_reset_callback( + [this, index, in_record](TensorImplDft* dft_tensor) { + auto var = this->m_load_result.output_var_list[index]; + dft_tensor->device_share_host_memory(); + auto dv = dft_tensor->dev_tensor().get(); + dv->comp_node(var.node()->comp_node(), true); + var.node()->init_mem_plan(dv); + if (in_record) { + auto&& device_tensor = var.node()->mutable_dev_tensor(); + device_tensor.only_reset_raw_storage(dv->storage()); + } else { + var.node()->reset_dev_tensor_from_tensor(*dv); + } + }); } if (m_user_config->options.force_output_dynamic_alloc) { TensorHelper::implement(tensor) ->cast_final_safe() - .set_get_memory_callback([var](TensorImplDft* dft_tensor) { + .set_get_memory_callback([this, index](TensorImplDft* dft_tensor) { + auto var = this->m_load_result.output_var_list[index]; if (dft_tensor->is_host()) { auto host_tensor = dft_tensor->m_host_tensor; *host_tensor =