Commit f7d2017e authored by Megvii Engine Team

fix(serialization): add tensor value loader support in new format

GitOrigin-RevId: e7da1d239669277e18d44d23c557abc39f2ac55f
Parent da7f250c
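This commit makes the OSS V2 (FLATBUFFERS_V2) serializer honor the user-supplied tensor value callbacks instead of only warning that they are unsupported. A minimal round-trip sketch with identity callbacks follows; it assumes the GraphDumper/GraphLoader factory overloads that take a GraphDumpFormat and the current OutputFile/InputFile helpers, and the header path and "model.mge" are placeholders:

#include "megbrain/serialization/serializer.h"  // header path assumed

using namespace mgb;
using namespace mgb::serialization;

// Dump a graph and load it back with identity tensor value callbacks;
// a real callback could compress or re-encode the bytes instead.
void roundtrip_with_custom_tensor_value(const SymbolVarArray& outputs) {
    auto dumper = GraphDumper::make(
            OutputFile::make_fs("model.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphDumpConfig dump_config;
    dump_config.tensor_value_dumper = [](OutputFile& fout,
                                         const cg::OperatorNodeBase&,
                                         const HostTensorND& tensor) {
        // write the raw tensor bytes unchanged
        fout.write(tensor.raw_ptr(), tensor.layout().span().high_byte);
    };
    dumper->dump(outputs, dump_config);

    auto loader = GraphLoader::make(
            InputFile::make_fs("model.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphLoadConfig load_config;
    load_config.tensor_value_loader = [](void* ptr, const TensorLayout& layout,
                                         InputFile& fin) {
        // must leave exactly layout.span().high_byte bytes in `ptr`
        fin.read(ptr, layout.span().high_byte);
    };
    auto loaded = loader->load(load_config);
    (void)loaded;
}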
@@ -513,13 +513,18 @@ void GraphDumperOSSV2::dump_tensor(
check_tensor_value_valid(name, tensor);
auto&& dumper = m_config.tensor_value_dumper;
if (dumper) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value dumper callback.");
std::vector<uint8_t> out_vec;
auto temp_out_file = OutputFile::make_vector_proxy(&out_vec);
dumper(*temp_out_file, *m_cur_opr, tensor);
data = m_builder.CreateVector(
reinterpret_cast<uint8_t*>(out_vec.data()), out_vec.size());
m_cur_rst.tensor_value_bytes += out_vec.size();
} else {
data = m_builder.CreateVector(
reinterpret_cast<uint8_t*>(tensor.raw_ptr()),
layout.span().high_byte);
m_cur_rst.tensor_value_bytes += layout.span().high_byte;
}
data = m_builder.CreateVector(
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
m_cur_rst.tensor_value_bytes += layout.span().high_byte;
}
auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
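With the change above, everything the callback writes into the vector-proxy OutputFile is collected in out_vec and embedded verbatim as the tensor's value bytes, so a custom dumper may store any byte layout it likes. A purely illustrative sketch that prepends a size-plus-checksum header (the matching loader is sketched after the fill_tensor_memory hunk further down); the helper name and config usage are assumptions, not part of this commit:

#include "megbrain/serialization/serializer.h"  // header path assumed
#include <cstdint>

// Build a dump config whose tensor_value_dumper prepends a size + additive
// checksum header before the raw bytes; everything written to `fout` ends up
// in `out_vec` above and is stored in the flatbuffer unchanged.
mgb::serialization::GraphDumpConfig make_checksum_dump_config() {
    mgb::serialization::GraphDumpConfig config;
    config.tensor_value_dumper = [](mgb::serialization::OutputFile& fout,
                                    const mgb::cg::OperatorNodeBase&,
                                    const mgb::HostTensorND& tensor) {
        auto* bytes = reinterpret_cast<const uint8_t*>(tensor.raw_ptr());
        uint64_t size = tensor.layout().span().high_byte;
        uint64_t checksum = 0;
        for (uint64_t i = 0; i < size; ++i)
            checksum += bytes[i];  // simple additive checksum, illustration only
        fout.write(&size, sizeof(size));
        fout.write(&checksum, sizeof(checksum));
        fout.write(bytes, size);
    };
    return config;
}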
@@ -688,14 +693,9 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
if (tensor->data() && tensor->data()->size() > 0) {
if (loader) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value loader callback.");
}
fill_tensor_memory(
*ret, tensor->data()->data(), tensor->data()->size(),
m_loader->m_file->is_shared_memory());
m_loader->m_file->is_shared_memory(), loader);
}
if (tensor->name()) {
m_tensor_map[tensor->name()->str()] = ret;
@@ -737,6 +737,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
shared_pair.first = tensor->name()->str();
}
auto loader = m_loader->m_cur_load_config->tensor_value_loader;
if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
// directly forward CPU memory
shared_tensor_ref = std::make_shared<DeviceTensorND>();
@@ -745,7 +746,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
hv.dtype(layout.dtype).resize(layout);
fill_tensor_memory(
hv, tensor->data()->data(), tensor->data()->size(),
m_loader->m_file->is_shared_memory());
m_loader->m_file->is_shared_memory(), loader);
}
if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
*shared_tensor_ref = DeviceTensorND::make_proxy(hv);
@@ -761,7 +762,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
hv.dtype(layout.dtype).resize(layout);
fill_tensor_memory(
hv, tensor->data()->data(), tensor->data()->size(),
m_loader->m_file->is_shared_memory());
m_loader->m_file->is_shared_memory(), loader);
}
shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
}
@@ -947,7 +948,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
if (m_shared_tensor_map.empty()) {
m_shared_tensor_map.resize(m_model->nr_shared_tensor());
} else {
mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
mgb_assert(m_shared_tensor_map.size() >= m_model->nr_shared_tensor());
}
SharedTensorAlignMent tensor_alignment(
m_model_buf, m_file.get(),
......
@@ -154,20 +154,29 @@ public:
//! the memory used when load model, but should consider the memory
//! alignment
void fill_tensor_memory(
HostTensorND& tensor, const uint8_t* data, size_t size, bool shared) {
HostTensorND& tensor, const uint8_t* data, size_t size, bool shared,
GraphLoadConfig::TensorValueLoader loader) {
auto tensor_size = tensor.layout().span().high_byte;
mgb_assert(
size == tensor_size,
"the size is not match when shared the flatbuffer memory\n");
auto ptr = reinterpret_cast<void*>(const_cast<uint8_t*>(data));
if (shared) {
HostTensorStorage storage;
auto raw_storage = std::shared_ptr<mgb::dt_byte>(
static_cast<mgb::dt_byte*>(ptr), [](void*) {});
storage.reset(tensor.comp_node(), size, raw_storage);
tensor.reset(storage, tensor.layout());
if (loader) {
// call custom loader
void* dest_ptr = tensor.raw_ptr();
auto input_file = InputFile::make_mem_proxy(data, size);
loader(dest_ptr, tensor.layout(), *input_file);
} else {
memcpy(tensor.raw_ptr(), data, size);
mgb_assert(
size == tensor_size,
"the size is not match when shared the flatbuffer memory\n");
if (shared) {
HostTensorStorage storage;
auto raw_storage = std::shared_ptr<mgb::dt_byte>(
static_cast<mgb::dt_byte*>(ptr), [](void*) {});
storage.reset(tensor.comp_node(), size, raw_storage);
tensor.reset(storage, tensor.layout());
} else {
memcpy(tensor.raw_ptr(), data, size);
}
}
}
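On the load side the callback now receives the tensor's own destination buffer, its layout, and an in-memory InputFile over exactly the bytes the dumper produced; it must leave layout.span().high_byte bytes in ptr (the size assertion is skipped on this path). A sketch of a loader matching the checksum dumper sketched earlier, assuming InputFile::read keeps its (dst, size) signature and that ptr is non-null here:

#include "megbrain/serialization/serializer.h"  // header path assumed
#include <cstdint>

// Build a load config whose tensor_value_loader undoes the checksum dumper
// sketched earlier: read the header, copy exactly the tensor's bytes into
// `ptr`, then verify the checksum.
mgb::serialization::GraphLoadConfig make_checksum_load_config() {
    mgb::serialization::GraphLoadConfig config;
    config.tensor_value_loader = [](void* ptr, const mgb::TensorLayout& layout,
                                    mgb::serialization::InputFile& fin) {
        uint64_t size = 0, checksum = 0;
        fin.read(&size, sizeof(size));
        fin.read(&checksum, sizeof(checksum));
        mgb_assert(size == layout.span().high_byte, "unexpected payload size");
        fin.read(ptr, size);
        auto* bytes = static_cast<const uint8_t*>(ptr);
        uint64_t got = 0;
        for (uint64_t i = 0; i < size; ++i)
            got += bytes[i];
        mgb_assert(got == checksum, "tensor value checksum mismatch");
    };
    return config;
}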
......
@@ -315,8 +315,10 @@ void test_serializer_custom_loader(GraphDumpFormat format) {
load();
load();
ASSERT_EQ(2u, saved_val.size());
ASSERT_EQ(2, load_nr_null_ptr); // immutable tensor is also shared
ASSERT_EQ(4, load_nr_call);
if (GraphDumpFormat::FLATBUFFERS_V2 != format) {
ASSERT_EQ(2, load_nr_null_ptr); // immutable tensor is also shared
ASSERT_EQ(4, load_nr_call);
}
}
void test_serializer_many_io_var(GraphDumpFormat format) {
@@ -998,6 +1000,10 @@ TEST(TestSerializer2, ManyIOVarsV2) {
test_serializer_many_io_var(GraphDumpFormat::FLATBUFFERS_V2);
}
TEST(TestSerializer2, CustomLoaderV2) {
test_serializer_custom_loader(GraphDumpFormat::FLATBUFFERS_V2);
}
TEST(TestSerializer2, RemoveSetGradV2) {
test_serializer_remove_set_grad(GraphDumpFormat::FLATBUFFERS_V2);
}
......