提交 67327685，作者：Megvii Engine Team

fix(lite): capture output by ref in io callback

GitOrigin-RevId: 6d23ec2f89ff5bcc2208beec3425d1893ea3754c
上级 042c7fd1
@@ -915,6 +915,13 @@ void NetworkImplDft::update_output() {
void NetworkImplDft::output_tensor_copy_optimize( void NetworkImplDft::output_tensor_copy_optimize(
Var var, std::shared_ptr<Tensor> tensor) { Var var, std::shared_ptr<Tensor> tensor) {
size_t index;
for (index = 0; index < m_load_result.output_var_list.size(); ++index) {
if (m_load_result.output_var_list[index].node() == var.node()) {
break;
}
}
LITE_ASSERT(index != m_load_result.output_var_list.size());
LITE_ASSERT( LITE_ASSERT(
!(m_user_config->options.force_output_use_user_specified_memory && !(m_user_config->options.force_output_use_user_specified_memory &&
m_user_config->options.force_output_dynamic_alloc), m_user_config->options.force_output_dynamic_alloc),
@@ -924,23 +931,26 @@ void NetworkImplDft::output_tensor_copy_optimize(
bool in_record = m_user_config->options.comp_node_seq_record_level > 0; bool in_record = m_user_config->options.comp_node_seq_record_level > 0;
TensorHelper::implement(tensor) TensorHelper::implement(tensor)
->cast_final_safe<TensorImplDft>() ->cast_final_safe<TensorImplDft>()
.set_reset_callback([var, in_record](TensorImplDft* dft_tensor) { .set_reset_callback(
dft_tensor->device_share_host_memory(); [this, index, in_record](TensorImplDft* dft_tensor) {
auto dv = dft_tensor->dev_tensor().get(); auto var = this->m_load_result.output_var_list[index];
dv->comp_node(var.node()->comp_node(), true); dft_tensor->device_share_host_memory();
var.node()->init_mem_plan(dv); auto dv = dft_tensor->dev_tensor().get();
if (in_record) { dv->comp_node(var.node()->comp_node(), true);
auto&& device_tensor = var.node()->mutable_dev_tensor(); var.node()->init_mem_plan(dv);
device_tensor.only_reset_raw_storage(dv->storage()); if (in_record) {
} else { auto&& device_tensor = var.node()->mutable_dev_tensor();
var.node()->reset_dev_tensor_from_tensor(*dv); device_tensor.only_reset_raw_storage(dv->storage());
} } else {
}); var.node()->reset_dev_tensor_from_tensor(*dv);
}
});
} }
if (m_user_config->options.force_output_dynamic_alloc) { if (m_user_config->options.force_output_dynamic_alloc) {
TensorHelper::implement(tensor) TensorHelper::implement(tensor)
->cast_final_safe<TensorImplDft>() ->cast_final_safe<TensorImplDft>()
.set_get_memory_callback([var](TensorImplDft* dft_tensor) { .set_get_memory_callback([this, index](TensorImplDft* dft_tensor) {
auto var = this->m_load_result.output_var_list[index];
if (dft_tensor->is_host()) { if (dft_tensor->is_host()) {
auto host_tensor = dft_tensor->m_host_tensor; auto host_tensor = dft_tensor->m_host_tensor;
*host_tensor = *host_tensor =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册