From 87447df01ccf9906df058567cfb1c03b61776d58 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 17 Nov 2022 22:16:44 +0800 Subject: [PATCH] fix(lite): fix lite zero copy GitOrigin-RevId: 3ceacb452438c1ad5987890bc513fc0c97787b82 --- .gitattributes | 1 + lite/src/mge/network_impl.cpp | 20 +++++++++++--------- lite/src/mge/tensor_impl.cpp | 9 ++++++--- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.gitattributes b/.gitattributes index 40d05d0e3..8ba02f1c4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -30,3 +30,4 @@ lite/test/resource/lite/resnet50_input.npy filter=lfs diff=lfs merge=lfs -text lite/test/resource/lite/resnet50.mge filter=lfs diff=lfs merge=lfs -text lite/test/resource/lite/resnet50_uint8.mge filter=lfs diff=lfs merge=lfs -text lite/test/resource/lite/cat.ppm filter=lfs diff=lfs merge=lfs -text +lite/test/resource/lite/ax_model_zero_copy.mge filter=lfs diff=lfs merge=lfs -text diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 772c5d40b..e074cef13 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -585,17 +585,18 @@ void NetworkImplDft::configure_after_loaded() { cross_compnode_model_detect(); //! update the IO of the network - update_io(); + update_input(); + replace_dev_input_pass(); + if (!m_user_config->discrete_input_name.empty()) { + replace_src_discrete_input_opr_pass(); + } + update_output(); //! replace the IO when there is device input or output compile_graph(); } void NetworkImplDft::compile_graph() { - replace_dev_input_pass(); - if (!m_user_config->discrete_input_name.empty()) { - replace_src_discrete_input_opr_pass(); - } make_output_spec(); m_execute_func = m_load_result.graph_compile(m_output_spec); } @@ -882,24 +883,25 @@ void NetworkImplDft::update_output() { } else { for (auto&& out : m_load_result.output_var_list) { std::shared_ptr lite_tensor = nullptr; + auto device = get_device_from_locator(out.node()->comp_node().locator()); auto it = std::find_if( m_network_io->outputs.begin(), m_network_io->outputs.end(), [&out](const IOInner io) { return io.name == out.node()->name(); }); if (it != m_network_io->outputs.end()) { if (it->is_host) { it->lite_tensor = std::make_shared( - device_id, stream_id, device_type, true); + device_id, stream_id, device, true); } else { it->lite_tensor = - std::make_shared(device_id, stream_id, device_type); + std::make_shared(device_id, stream_id, device); } try_infer_tensor_layout(it->lite_tensor, out); lite_tensor = it->lite_tensor; } else { IOInner output; output.name = out.node()->name(); - output.lite_tensor = std::make_shared( - device_id, stream_id, device_type, true); + output.lite_tensor = + std::make_shared(device_id, stream_id, device, true); m_network_io->outputs.push_back({output}); try_infer_tensor_layout(output.lite_tensor, out); lite_tensor = output.lite_tensor; diff --git a/lite/src/mge/tensor_impl.cpp b/lite/src/mge/tensor_impl.cpp index 4816a9892..30203932b 100644 --- a/lite/src/mge/tensor_impl.cpp +++ b/lite/src/mge/tensor_impl.cpp @@ -452,9 +452,12 @@ void TensorImplDft::device_share_host_memory() { m_host_tensor->comp_node(), m_host_tensor->layout()); } if (m_host_tensor->raw_ptr() != m_dev_tensor->raw_ptr()) { - auto&& storage = - mgb::DeviceTensorStorage::make_proxy(m_host_tensor->storage()); - m_dev_tensor->only_reset_raw_storage(storage); + auto& host_storage = m_host_tensor->storage(); + mgb::DeviceTensorStorage device_storage; + device_storage.reset( + host_storage.comp_node(), host_storage.size(), + host_storage.raw_storage()); + m_dev_tensor->only_reset_raw_storage(device_storage); } } } -- GitLab