From f7d2017e2c2f49b352b365e63ab11e3626e6a345 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 25 Oct 2022 18:39:50 +0800
Subject: [PATCH] fix(serialization): add tensor value loader support in new format

GitOrigin-RevId: e7da1d239669277e18d44d23c557abc39f2ac55f
---
 src/serialization/impl/serializer_oss_v2.cpp | 31 ++++++++++---------
 .../serialization/oss_opr_load_dump.h        | 31 ++++++++++++-------
 src/serialization/test/serializer_oss.cpp    | 10 ++++--
 3 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp
index fbab8b2a1..246e11c5e 100644
--- a/src/serialization/impl/serializer_oss_v2.cpp
+++ b/src/serialization/impl/serializer_oss_v2.cpp
@@ -513,13 +513,18 @@ void GraphDumperOSSV2::dump_tensor(
         check_tensor_value_valid(name, tensor);
         auto&& dumper = m_config.tensor_value_dumper;
         if (dumper) {
-            mgb_log_warn(
-                    "serialization v2 format is pure flatbuffer format, not support "
-                    "user tensor value dumper callback.");
+            std::vector<uint8_t> out_vec;
+            auto temp_out_file = OutputFile::make_vector_proxy(&out_vec);
+            dumper(*temp_out_file, *m_cur_opr, tensor);
+            data = m_builder.CreateVector(
+                    reinterpret_cast<uint8_t*>(out_vec.data()), out_vec.size());
+            m_cur_rst.tensor_value_bytes += out_vec.size();
+        } else {
+            data = m_builder.CreateVector(
+                    reinterpret_cast<uint8_t*>(tensor.raw_ptr()),
+                    layout.span().high_byte);
+            m_cur_rst.tensor_value_bytes += layout.span().high_byte;
         }
-        data = m_builder.CreateVector(
-                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
-        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
     }
 
     auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
@@ -688,14 +693,9 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
     auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
     if (tensor->data() && tensor->data()->size() > 0) {
-        if (loader) {
-            mgb_log_warn(
-                    "serialization v2 format is pure flatbuffer format, not support "
-                    "user tensor value loader callback.");
-        }
         fill_tensor_memory(
                 *ret, tensor->data()->data(), tensor->data()->size(),
-                m_loader->m_file->is_shared_memory());
+                m_loader->m_file->is_shared_memory(), loader);
     }
     if (tensor->name()) {
         m_tensor_map[tensor->name()->str()] = ret;
     }
@@ -737,6 +737,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         shared_pair.first = tensor->name()->str();
     }
+    auto loader = m_loader->m_cur_load_config->tensor_value_loader;
     if (comp_node.mem_node() == CompNode::default_cpu().mem_node() ||
         copy_immediatly) {
         // directly forward CPU memory
         shared_tensor_ref = std::make_shared<DeviceTensorND>();
@@ -745,7 +746,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             hv.dtype(layout.dtype).resize(layout);
             fill_tensor_memory(
                     hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
+                    m_loader->m_file->is_shared_memory(), loader);
         }
         if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
             *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
@@ -761,7 +762,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             hv.dtype(layout.dtype).resize(layout);
             fill_tensor_memory(
                     hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
+                    m_loader->m_file->is_shared_memory(), loader);
         }
         shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
     }
@@ -947,7 +948,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
     if (m_shared_tensor_map.empty()) {
         m_shared_tensor_map.resize(m_model->nr_shared_tensor());
     } else {
-        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
+        mgb_assert(m_shared_tensor_map.size() >= m_model->nr_shared_tensor());
     }
     SharedTensorAlignMent tensor_alignment(
             m_model_buf, m_file.get(),
diff --git a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
index 55386486a..76f554267 100644
--- a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
+++ b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
@@ -154,20 +154,29 @@ public:
     //! the memory used when load model, but should consider the memory
     //! alignment
     void fill_tensor_memory(
-            HostTensorND& tensor, const uint8_t* data, size_t size, bool shared) {
+            HostTensorND& tensor, const uint8_t* data, size_t size, bool shared,
+            GraphLoadConfig::TensorValueLoader loader) {
         auto tensor_size = tensor.layout().span().high_byte;
-        mgb_assert(
-                size == tensor_size,
-                "the size is not match when shared the flatbuffer memory\n");
+
         auto ptr = reinterpret_cast<void*>(const_cast<uint8_t*>(data));
-        if (shared) {
-            HostTensorStorage storage;
-            auto raw_storage = std::shared_ptr<dt_byte>(
-                    static_cast<dt_byte*>(ptr), [](void*) {});
-            storage.reset(tensor.comp_node(), size, raw_storage);
-            tensor.reset(storage, tensor.layout());
+        if (loader) {
+            // call custom loader
+            void* dest_ptr = tensor.raw_ptr();
+            auto input_file = InputFile::make_mem_proxy(data, size);
+            loader(dest_ptr, tensor.layout(), *input_file);
         } else {
-            memcpy(tensor.raw_ptr(), data, size);
+            mgb_assert(
+                    size == tensor_size,
+                    "the size is not match when shared the flatbuffer memory\n");
+            if (shared) {
+                HostTensorStorage storage;
+                auto raw_storage = std::shared_ptr<dt_byte>(
+                        static_cast<dt_byte*>(ptr), [](void*) {});
+                storage.reset(tensor.comp_node(), size, raw_storage);
+                tensor.reset(storage, tensor.layout());
+            } else {
+                memcpy(tensor.raw_ptr(), data, size);
+            }
         }
     }
 
diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp
index dfe825dfc..cdc8dbcfe 100644
--- a/src/serialization/test/serializer_oss.cpp
+++ b/src/serialization/test/serializer_oss.cpp
@@ -315,8 +315,10 @@ void test_serializer_custom_loader(GraphDumpFormat format) {
     load();
     load();
     ASSERT_EQ(2u, saved_val.size());
-    ASSERT_EQ(2, load_nr_null_ptr);  // immutable tensor is also shared
-    ASSERT_EQ(4, load_nr_call);
+    if (GraphDumpFormat::FLATBUFFERS_V2 != format) {
+        ASSERT_EQ(2, load_nr_null_ptr);  // immutable tensor is also shared
+        ASSERT_EQ(4, load_nr_call);
+    }
 }
 
 void test_serializer_many_io_var(GraphDumpFormat format) {
@@ -998,6 +1000,10 @@ TEST(TestSerializer2, ManyIOVarsV2) {
     test_serializer_many_io_var(GraphDumpFormat::FLATBUFFERS_V2);
 }
 
+TEST(TestSerializer2, CustomLoaderV2) {
+    test_serializer_custom_loader(GraphDumpFormat::FLATBUFFERS_V2);
+}
+
 TEST(TestSerializer2, RemoveSetGradV2) {
     test_serializer_remove_set_grad(GraphDumpFormat::FLATBUFFERS_V2);
 }
-- 
GitLab
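
Usage sketch (not part of the patch): the snippet below shows how the tensor_value_dumper / tensor_value_loader callbacks that this change wires into the FLATBUFFERS_V2 path are typically hooked up. Only the callback signatures, config fields and GraphDumpFormat value are taken from the code touched above; the file name "model_v2.mge", the function names and the lambda bodies are illustrative assumptions — a real callback pair could compress or encrypt the tensor bytes instead of copying them verbatim.

// Minimal sketch, assuming an already built graph whose output vars are
// passed in; callback signatures and config fields follow the patch above.
#include "megbrain/serialization/serializer.h"

using namespace mgb;
using namespace mgb::serialization;

void dump_with_custom_tensor_codec(const SymbolVarArray& outputs) {
    auto dumper = GraphDumper::make(
            OutputFile::make_fs("model_v2.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphDumpConfig dump_config;
    // Invoked once per tensor value; with this patch everything written to
    // `fout` (a vector-proxy OutputFile) is embedded into the flatbuffer.
    dump_config.tensor_value_dumper = [](OutputFile& fout,
                                         const cg::OperatorNodeBase& /*opr*/,
                                         const HostTensorND& value) {
        fout.write(value.raw_ptr(), value.layout().span().high_byte);
    };
    dumper->dump(outputs, dump_config);
}

void load_with_custom_tensor_codec() {
    auto loader = GraphLoader::make(
            InputFile::make_fs("model_v2.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphLoadConfig load_config;
    // Invoked by fill_tensor_memory() with the destination pointer, the tensor
    // layout and an InputFile wrapping the bytes produced by the dumper above.
    load_config.tensor_value_loader = [](void* ptr, const TensorLayout& layout,
                                         InputFile& fin) {
        fin.read(ptr, layout.span().high_byte);
    };
    auto rst = loader->load(load_config);
    // rst.graph / rst.output_var_list are now ready to use.
    MGB_MARK_USED_VAR(rst);
}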