提交 87447df0 编写于 作者: M Megvii Engine Team

fix(lite): fix lite zero copy

GitOrigin-RevId: 3ceacb452438c1ad5987890bc513fc0c97787b82
上级 f16e4311
@@ -30,3 +30,4 @@ lite/test/resource/lite/resnet50_input.npy filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/resnet50.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/resnet50_uint8.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/cat.ppm filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_model_zero_copy.mge filter=lfs diff=lfs merge=lfs -text
@@ -585,17 +585,18 @@ void NetworkImplDft::configure_after_loaded() {
     cross_compnode_model_detect();
     //! update the IO of the network
-    update_io();
+    update_input();
+    replace_dev_input_pass();
+    if (!m_user_config->discrete_input_name.empty()) {
+        replace_src_discrete_input_opr_pass();
+    }
+    update_output();
     //! replace the IO when there is device input or output
     compile_graph();
 }
 void NetworkImplDft::compile_graph() {
-    replace_dev_input_pass();
-    if (!m_user_config->discrete_input_name.empty()) {
-        replace_src_discrete_input_opr_pass();
-    }
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
 }
@@ -882,24 +883,25 @@ void NetworkImplDft::update_output() {
     } else {
         for (auto&& out : m_load_result.output_var_list) {
             std::shared_ptr<Tensor> lite_tensor = nullptr;
+            auto device = get_device_from_locator(out.node()->comp_node().locator());
             auto it = std::find_if(
                     m_network_io->outputs.begin(), m_network_io->outputs.end(),
                     [&out](const IOInner io) { return io.name == out.node()->name(); });
             if (it != m_network_io->outputs.end()) {
                 if (it->is_host) {
                     it->lite_tensor = std::make_shared<Tensor>(
-                            device_id, stream_id, device_type, true);
+                            device_id, stream_id, device, true);
                 } else {
                     it->lite_tensor =
-                            std::make_shared<Tensor>(device_id, stream_id, device_type);
+                            std::make_shared<Tensor>(device_id, stream_id, device);
                 }
                 try_infer_tensor_layout(it->lite_tensor, out);
                 lite_tensor = it->lite_tensor;
             } else {
                 IOInner output;
                 output.name = out.node()->name();
-                output.lite_tensor = std::make_shared<Tensor>(
-                        device_id, stream_id, device_type, true);
+                output.lite_tensor =
+                        std::make_shared<Tensor>(device_id, stream_id, device, true);
                 m_network_io->outputs.push_back({output});
                 try_infer_tensor_layout(output.lite_tensor, out);
                 lite_tensor = output.lite_tensor;
......
@@ -452,9 +452,12 @@ void TensorImplDft::device_share_host_memory() {
                     m_host_tensor->comp_node(), m_host_tensor->layout());
         }
         if (m_host_tensor->raw_ptr() != m_dev_tensor->raw_ptr()) {
-            auto&& storage =
-                    mgb::DeviceTensorStorage::make_proxy(m_host_tensor->storage());
-            m_dev_tensor->only_reset_raw_storage(storage);
+            auto& host_storage = m_host_tensor->storage();
+            mgb::DeviceTensorStorage device_storage;
+            device_storage.reset(
+                    host_storage.comp_node(), host_storage.size(),
+                    host_storage.raw_storage());
+            m_dev_tensor->only_reset_raw_storage(device_storage);
         }
     }
 }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册