Commit 87447df0 authored by Megvii Engine Team

fix(lite): fix lite zero copy

GitOrigin-RevId: 3ceacb452438c1ad5987890bc513fc0c97787b82
Parent f16e4311
......@@ -30,3 +30,4 @@ lite/test/resource/lite/resnet50_input.npy filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/resnet50.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/resnet50_uint8.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/cat.ppm filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_model_zero_copy.mge filter=lfs diff=lfs merge=lfs -text
......@@ -585,17 +585,18 @@ void NetworkImplDft::configure_after_loaded() {
     cross_compnode_model_detect();

     //! update the IO of the network
-    update_io();
+    update_input();
+    replace_dev_input_pass();
+    if (!m_user_config->discrete_input_name.empty()) {
+        replace_src_discrete_input_opr_pass();
+    }
+    update_output();

     //! replace the IO when there is device input or output
     compile_graph();
 }

 void NetworkImplDft::compile_graph() {
-    replace_dev_input_pass();
-    if (!m_user_config->discrete_input_name.empty()) {
-        replace_src_discrete_input_opr_pass();
-    }
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
 }
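
This hunk moves the device-input rewrite passes (replace_dev_input_pass() and, for discrete inputs, replace_src_discrete_input_opr_pass()) out of compile_graph() so they run between update_input() and update_output(); the output bookkeeping then sees the graph after device inputs have already been substituted. A minimal, self-contained sketch of that ordering follows; the stub functions stand in for the real NetworkImplDft stages, their bodies are assumptions, and only the call order comes from the diff:

    #include <cstdio>

    // Stub stages; names mirror the diff, bodies are placeholders.
    static void update_input()           { std::puts("bind user-facing input tensors"); }
    static void replace_dev_input_pass() { std::puts("rewrite graph for device inputs"); }
    static void update_output()          { std::puts("bind outputs against rewritten graph"); }
    static void compile_graph()          { std::puts("make_output_spec + graph_compile"); }

    int main() {
        update_input();
        replace_dev_input_pass();  // must precede update_output() for zero copy
        update_output();
        compile_graph();           // no longer runs the rewrite passes itself
        return 0;
    }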
......@@ -882,24 +883,25 @@ void NetworkImplDft::update_output() {
     } else {
         for (auto&& out : m_load_result.output_var_list) {
             std::shared_ptr<Tensor> lite_tensor = nullptr;
+            auto device = get_device_from_locator(out.node()->comp_node().locator());
             auto it = std::find_if(
                     m_network_io->outputs.begin(), m_network_io->outputs.end(),
                     [&out](const IOInner io) { return io.name == out.node()->name(); });
             if (it != m_network_io->outputs.end()) {
                 if (it->is_host) {
                     it->lite_tensor = std::make_shared<Tensor>(
-                            device_id, stream_id, device_type, true);
+                            device_id, stream_id, device, true);
                 } else {
                     it->lite_tensor =
-                            std::make_shared<Tensor>(device_id, stream_id, device_type);
+                            std::make_shared<Tensor>(device_id, stream_id, device);
                 }
                 try_infer_tensor_layout(it->lite_tensor, out);
                 lite_tensor = it->lite_tensor;
             } else {
                 IOInner output;
                 output.name = out.node()->name();
-                output.lite_tensor = std::make_shared<Tensor>(
-                        device_id, stream_id, device_type, true);
+                output.lite_tensor =
+                        std::make_shared<Tensor>(device_id, stream_id, device, true);
                 m_network_io->outputs.push_back({output});
                 try_infer_tensor_layout(output.lite_tensor, out);
                 lite_tensor = output.lite_tensor;
......
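
Here each output tensor's device is derived per output var from its comp node locator (via get_device_from_locator) instead of the network-wide device_type, so a cross-compnode model gets every output allocated on the device it actually lives on. A toy sketch of the idea follows; DeviceType and Locator are simplified stand-ins for the MegEngine types, not the real API:

    #include <cassert>

    // Simplified stand-ins for CompNode::Locator and the lite device enum.
    enum class DeviceType { CPU, CUDA };

    struct Locator {
        DeviceType device;  // assumption: the locator records the backing device
        int stream;
    };

    // Toy analogue of get_device_from_locator().
    static DeviceType device_from_locator(const Locator& loc) {
        return loc.device;
    }

    int main() {
        // Two outputs of one model living on different comp nodes: each output
        // tensor is now created on its own device, not on a global device_type.
        Locator cpu_out{DeviceType::CPU, 0};
        Locator gpu_out{DeviceType::CUDA, 1};
        assert(device_from_locator(cpu_out) == DeviceType::CPU);
        assert(device_from_locator(gpu_out) == DeviceType::CUDA);
        return 0;
    }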
......@@ -452,9 +452,12 @@ void TensorImplDft::device_share_host_memory() {
                 m_host_tensor->comp_node(), m_host_tensor->layout());
     }
     if (m_host_tensor->raw_ptr() != m_dev_tensor->raw_ptr()) {
-        auto&& storage =
-                mgb::DeviceTensorStorage::make_proxy(m_host_tensor->storage());
-        m_dev_tensor->only_reset_raw_storage(storage);
+        auto& host_storage = m_host_tensor->storage();
+        mgb::DeviceTensorStorage device_storage;
+        device_storage.reset(
+                host_storage.comp_node(), host_storage.size(),
+                host_storage.raw_storage());
+        m_dev_tensor->only_reset_raw_storage(device_storage);
     }
 }
......
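
This is the zero-copy fix itself: rather than wrapping the host storage in a DeviceTensorStorage proxy via make_proxy(), the code now builds a device storage that explicitly shares the host buffer's ref-counted raw storage (same comp node, same size, same bytes) before seating it into m_dev_tensor. A rough model of that sharing follows, using plain std::shared_ptr in place of MegBrain's storage classes; it is an illustration of the ownership pattern, not the real implementation:

    #include <cassert>
    #include <cstddef>
    #include <memory>

    // Minimal stand-in for a tensor storage holding a ref-counted raw buffer,
    // loosely modelling DeviceTensorStorage::reset(comp_node, size, raw_storage).
    struct Storage {
        std::shared_ptr<void> raw;  // analogue of raw_storage()
        std::size_t size = 0;
    };

    int main() {
        // Host-side buffer.
        Storage host;
        host.raw = std::shared_ptr<void>(::operator new(64),
                                         [](void* p) { ::operator delete(p); });
        host.size = 64;

        // Device-side storage re-seated onto the *same* buffer: zero copy, and
        // the shared_ptr refcount keeps the memory alive for both tensors.
        Storage device;
        device.raw = host.raw;
        device.size = host.size;

        assert(device.raw.get() == host.raw.get());
        return 0;
    }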