From 67327685b188df6f29e97b1a3a1805b42354d272 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 17 Nov 2022 15:50:28 +0800 Subject: [PATCH] fix(lite): capture output by ref in io callback GitOrigin-RevId: 6d23ec2f89ff5bcc2208beec3425d1893ea3754c --- lite/src/mge/network_impl.cpp | 36 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 58eb463e9..772c5d40b 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -915,6 +915,13 @@ void NetworkImplDft::update_output() { void NetworkImplDft::output_tensor_copy_optimize( Var var, std::shared_ptr tensor) { + size_t index; + for (index = 0; index < m_load_result.output_var_list.size(); ++index) { + if (m_load_result.output_var_list[index].node() == var.node()) { + break; + } + } + LITE_ASSERT(index != m_load_result.output_var_list.size()); LITE_ASSERT( !(m_user_config->options.force_output_use_user_specified_memory && m_user_config->options.force_output_dynamic_alloc), @@ -924,23 +931,26 @@ void NetworkImplDft::output_tensor_copy_optimize( bool in_record = m_user_config->options.comp_node_seq_record_level > 0; TensorHelper::implement(tensor) ->cast_final_safe() - .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) { - dft_tensor->device_share_host_memory(); - auto dv = dft_tensor->dev_tensor().get(); - dv->comp_node(var.node()->comp_node(), true); - var.node()->init_mem_plan(dv); - if (in_record) { - auto&& device_tensor = var.node()->mutable_dev_tensor(); - device_tensor.only_reset_raw_storage(dv->storage()); - } else { - var.node()->reset_dev_tensor_from_tensor(*dv); - } - }); + .set_reset_callback( + [this, index, in_record](TensorImplDft* dft_tensor) { + auto var = this->m_load_result.output_var_list[index]; + dft_tensor->device_share_host_memory(); + auto dv = dft_tensor->dev_tensor().get(); + dv->comp_node(var.node()->comp_node(), true); + var.node()->init_mem_plan(dv); + if (in_record) { + auto&& device_tensor = var.node()->mutable_dev_tensor(); + device_tensor.only_reset_raw_storage(dv->storage()); + } else { + var.node()->reset_dev_tensor_from_tensor(*dv); + } + }); } if (m_user_config->options.force_output_dynamic_alloc) { TensorHelper::implement(tensor) ->cast_final_safe() - .set_get_memory_callback([var](TensorImplDft* dft_tensor) { + .set_get_memory_callback([this, index](TensorImplDft* dft_tensor) { + auto var = this->m_load_result.output_var_list[index]; if (dft_tensor->is_host()) { auto host_tensor = dft_tensor->m_host_tensor; *host_tensor = -- GitLab