Commit 399200b3 authored by Megvii Engine Team

perf(serialization): optimize memory usage when loading new-format models

GitOrigin-RevId: 2b7313ebe39a7d4a44a8ab61fa0f3646fd7de566
Parent f31e52d5
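The core idea of this change: when the input file is a writable shared-memory buffer, the loaders wrap the serialized tensor bytes directly as tensor storage instead of allocating a fresh buffer and copying into it (see fill_tensor_memory in the V2 loader below), and a SharedTensorAlignMent pass afterwards moves those tensors to properly aligned addresses inside the model buffer. A minimal sketch of the zero-copy idea follows; wrap_or_copy is an illustrative helper, not part of the MegEngine API.

    #include <cstddef>
    #include <cstring>
    #include <memory>

    // Wrap `data` directly when the backing model buffer is shared and outlives
    // the tensor; otherwise fall back to an owned copy.
    std::shared_ptr<void> wrap_or_copy(const void* data, std::size_t size, bool shared) {
        if (shared) {
            // non-owning view: the model buffer itself backs the tensor, no extra allocation
            return std::shared_ptr<void>(const_cast<void*>(data), [](void*) {});
        }
        std::shared_ptr<void> owned(
                ::operator new(size), [](void* p) { ::operator delete(p); });
        std::memcpy(owned.get(), data, size);
        return owned;
    }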
@@ -25,7 +25,11 @@ class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
     std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
-    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
+    std::shared_ptr<DeviceTensorND> load_tensor_shared(
+            bool copy_immediatly = false) override {
+        (void)copy_immediatly;
+        mgb_assert(0);
+    }
     const serialization::GraphLoadConfig& config() const override { mgb_assert(0); }
...
@@ -245,9 +245,6 @@ int LITE_destroy_network(LiteNetwork network) {
     auto& global_holder = get_gloabl_network_holder();
     if (global_holder.find(network) != global_holder.end()) {
         global_holder.erase(network);
-    } else {
-        //! means the network has been destoryed
-        return -1;
     }
     LITE_CAPI_END();
 }
...
@@ -75,9 +75,6 @@ int LITE_destroy_tensor(LiteTensor tensor) {
     auto& global_holder = get_global_tensor_holder();
     if (global_holder.find(tensor) != global_holder.end()) {
         global_holder.erase(tensor);
-    } else {
-        //! return -1, means the tensor has been destroyed.
-        return -1;
     }
     LITE_CAPI_END();
 }
...
@@ -16,26 +16,8 @@ void ModelLite::create_network() {
 }
 void ModelLite::load_model() {
-    if (share_model_mem) {
-        //! WARNNING:maybe not right to share param memmory for this
-        LITE_LOG("enable share model memory");
-        FILE* fin = fopen(model_path.c_str(), "rb");
-        LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
-        fseek(fin, 0, SEEK_END);
-        size_t size = ftell(fin);
-        fseek(fin, 0, SEEK_SET);
-        void* ptr = malloc(size);
-        std::shared_ptr<void> buf{ptr, free};
-        auto nr = fread(buf.get(), 1, size, fin);
-        LITE_ASSERT(nr == size, "read model file failed");
-        fclose(fin);
-        m_network->load_model(buf.get(), size);
-    } else {
-        m_network->load_model(model_path);
-    }
+    //! lite shares model memory by default
+    m_network->load_model(model_path);
 }
 void ModelLite::run_model() {
...
@@ -128,7 +128,10 @@ std::shared_ptr<void> ModelParser::decrypt_memory(
     const uint8_t* memory_ptr = data;
     if (decryption_name == "NONE") {
         result_length = length;
-        return std::shared_ptr<void>(const_cast<uint8_t*>(memory_ptr), [](void*) {});
+        std::shared_ptr<uint8_t> shptr{
+                new uint8_t[length], [](uint8_t* p) { delete[] p; }};
+        memcpy(shptr.get(), data, length);
+        return shptr;
     }
     LITE_LOCK_GUARD(decryption_static_data().map_mutex);
     auto it = decryption_static_data().decryption_methods.find(decryption_name);
...
@@ -1032,7 +1032,7 @@ TEST(TestCapiNetWork, GlobalHolder) {
             LITE_make_network(&c_network, *default_config(), *default_network_io()));
     //! make sure destroy_network is destroyed by LITE_make_network
     LITE_destroy_network(destroy_network);
-    ASSERT_EQ(LITE_destroy_network(destroy_network), -1);
+    ASSERT_EQ(LITE_destroy_network(destroy_network), 0);
     LITE_CAPI_CHECK(LITE_destroy_network(c_network));
 }
...
@@ -328,7 +328,7 @@ TEST(TestCapiTensor, GlobalHolder) {
     LITE_make_tensor(description, &c_tensor0);
     //! make sure destroy_tensor is destroyed by LITE_make_tensor
     LITE_destroy_tensor(destroy_tensor);
-    ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1);
+    ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), 0);
     LITE_destroy_tensor(c_tensor0);
 }
...
@@ -332,6 +332,7 @@ class TensorND {
 public:
     using ChainReturnType = TensorND<TensorStorage>;
+    using Storage = TensorStorage;
     MGE_WIN_DECLSPEC_FUC TensorND();
...
@@ -443,18 +443,19 @@ void run<shape_dep_const_shape>(CompNode cn) {
     HostTensorGenerator<> gen;
     auto host_x = gen({4, 5}, cn);
     auto fname = output_file("test_comp_node_record_shape_dep_const_shape");
+    auto test = [&](serialization::GraphDumpFormat format) {
     HostTensorND y_expect;
     {
         // dump graph
         auto graph = ComputingGraph::make();
-        auto x = opr::Host2DeviceCopy::make(*graph, host_x, OperatorNodeConfig{"x"}),
+        auto x = opr::Host2DeviceCopy::make(
+                     *graph, host_x, OperatorNodeConfig{"x"}),
              y = x.flatten() +
                  opr::reduce_sum(opr::GetVarShape::make(x), x.make_scalar(1));
         graph->compile({make_callback_copy(y, y_expect)})->execute();
-        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), format);
         dumper->dump({y});
     }
@@ -462,7 +463,7 @@ void run<shape_dep_const_shape>(CompNode cn) {
     {
         GraphLoadConfig config;
         config.const_var_shape = true;
-        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), format);
         auto load_rst = loader->load(config);
         load_rst.graph->options().comp_node_seq_record_level = 2;
         load_rst.graph->options().var_sanity_check_first_run = false;
@@ -473,8 +474,10 @@ void run<shape_dep_const_shape>(CompNode cn) {
         x_inp->copy_from(*host_x);
         func->execute();
     }
     MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
+    };
+    test({});
+    test(serialization::GraphDumpFormat::FLATBUFFERS_V2);
 }
 //! single thread multi recorder run interleave
...
@@ -367,16 +367,19 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ImmutableTensor);
 class ImmutableTensor::Value {
     MGB_MUTEX m_mtx;
-    DeviceTensorND m_dev, m_static_infer;
+    std::shared_ptr<DeviceTensorND> m_dev = std::make_shared<DeviceTensorND>();
+    DeviceTensorND m_static_infer;
     std::string m_summary;
 public:
     void setup(CompNode cn, const HostTensorND& val);
-    bool initialized() const { return m_dev.shape_valid(); }
+    void setup(std::shared_ptr<DeviceTensorND> val);
+    bool initialized() const { return m_dev->shape_valid(); }
     //! value on comp node
-    const DeviceTensorND& dev() const { return m_dev; }
+    const DeviceTensorND& dev() const { return *m_dev; }
     //! get value on static infer CPU node
     DeviceTensorND& static_infer();
@@ -385,10 +388,17 @@ public:
     const std::string& summary() const { return m_summary; }
 };
+void ImmutableTensor::Value::setup(std::shared_ptr<DeviceTensorND> val) {
+    mgb_assert(val);
+    m_dev = val;
+    m_summary = ssprintf("const%s", val->shape().to_string().c_str());
+}
 void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) {
-    mgb_assert(m_dev.empty() && !m_dev.shape_valid());
-    m_dev.comp_node(cn).copy_from(val).sync();
-    mgb_assert(val.empty() == m_dev.empty());
+    mgb_assert(m_dev->empty() && !m_dev->shape_valid());
+    m_dev->comp_node(cn).copy_from(val).sync();
+    mgb_assert(val.empty() == m_dev->empty());
     auto one_elem = [](const TensorShape& shape) {
         for (size_t i = 0; i < shape.ndim; ++i) {
@@ -413,8 +423,8 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) {
 DeviceTensorND& ImmutableTensor::Value::static_infer() {
     MGB_LOCK_GUARD(m_mtx);
     if (!m_static_infer.shape_valid()) {
-        mgb_assert(m_dev.shape_valid());
-        m_static_infer.comp_node(CompNode::default_cpu()).copy_from(m_dev);
+        mgb_assert(m_dev->shape_valid());
+        m_static_infer.comp_node(CompNode::default_cpu()).copy_from(*m_dev);
     }
     return m_static_infer;
 }
@@ -588,6 +598,19 @@ SymbolVar ImmutableTensor::make(
     return make_from_value(graph, cache.get(val), {}, config);
 }
+SymbolVar ImmutableTensor::make(
+        ComputingGraph& graph, std::shared_ptr<DeviceTensorND> val,
+        const OperatorNodeConfig& config) {
+    auto cn = val->comp_node();
+    if (config.has_comp_node_set())
+        cn = config.get_single_comp_node();
+    auto value = std::make_shared<Value>();
+    value->setup(val);
+    return make_from_value(graph, *value, value, config);
+}
 SymbolVar ImmutableTensor::make(
         ComputingGraph& graph, const DTypeScalar& val,
         const OperatorNodeConfig& config) {
...
@@ -132,8 +132,10 @@ struct OprLoadDumpImpl<opr::ImmutableTensor, 0> {
             OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config) {
         mgb_assert(inputs.empty());
-        auto val = ctx.load_tensor();
-        return Opr::make(ctx.graph(), *val, config).node()->owner_opr();
+        //! ImmutableTensor may be used in shape or value inference,
+        //! so its value must be copied immediately
+        auto val = ctx.load_tensor_shared(true);
+        return Opr::make(ctx.graph(), val, config).node()->owner_opr();
     }
 };
...
@@ -32,8 +32,10 @@ struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
         auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                 fbs_ctx.get_current_opr_data());
         if (fopr->tensors() && fopr->tensors()->size() > 0) {
-            auto val = fbs_ctx.load_tensor();
-            return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
+            //! ImmutableTensor may be used in shape or value inference,
+            //! so its value must be copied immediately
+            auto val = fbs_ctx.load_tensor_shared(true);
+            return Opr::make(fbs_ctx.graph(), val, config).node()->owner_opr();
         } else {
             mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
         }
...
@@ -360,6 +360,10 @@ public:
             ComputingGraph& graph, const HostTensorND& val,
             const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            ComputingGraph& graph, std::shared_ptr<DeviceTensorND> val,
+            const OperatorNodeConfig& config = {});
     //! make from DTypeScalar; comp node must be provided in config
     MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             ComputingGraph& graph, const DTypeScalar& val,
...
@@ -138,6 +138,10 @@ public:
         mgb_assert(m_refhold && size);
     }
+    bool is_shared_memory() override { return true; }
+    bool writable() override { return m_writable; }
+    void have_modified() override { m_modified = true; }
     void rewind() override {
         if (m_modified) {
             // data has beem modified; can not read again
...
@@ -63,7 +63,11 @@ class OprLoadContextMemory final : public OprLoadContextRawPOD {
     std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
-    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
+    std::shared_ptr<DeviceTensorND> load_tensor_shared(
+            bool copy_immediatly = false) override {
+        (void)copy_immediatly;
+        mgb_assert(0);
+    }
     const GraphLoadConfig& config() const override {
         mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
...
@@ -483,7 +483,8 @@ class GraphLoaderOSS::OprLoadContextImpl final : public OprLoadContextFlatBuffer
     std::shared_ptr<HostTensorND> load_tensor() override;
-    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;
+    std::shared_ptr<DeviceTensorND> load_tensor_shared(
+            bool copy_immediatly = false) override;
     void load_single_opr(const fbs::Operator* opr);
@@ -641,8 +642,8 @@ std::shared_ptr<HostTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor()
     return ret;
 }
-std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::
-        load_tensor_shared() {
+std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared(
+        bool copy_immediatly) {
     mgb_assert(
             m_current_opr->tensors() &&
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
@@ -650,6 +651,9 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::
     auto comp_node = load_comp_node(tensor->comp_node());
     auto layout = load_tensor_layout(tensor);
     mgb_assert(tensor->data_size());
+    if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
+        m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);
+    }
     auto&& sh_reg = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
     auto&& sh_ptr_ref = sh_reg.second[comp_node.mem_node()];
     if (sh_ptr_ref) {
@@ -673,6 +677,11 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::
         load_tensor_value(&hv, layout, tensor);
         sh_ptr_ref = std::make_shared<DeviceTensorND>();
         *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
+    } else if (copy_immediatly) {
+        HostTensorND hv{CompNode::default_cpu()};
+        load_tensor_value(&hv, layout, tensor);
+        sh_ptr_ref = std::make_shared<DeviceTensorND>();
+        sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
@@ -803,7 +812,7 @@ GraphLoader::LoadResult GraphLoaderOSS::OprLoadContextImpl::load_oprs() {
         ret.output_var_map_id[out->original_id()] = var;
         ret.output_var_list[i] = var;
     }
-    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
+    mgb_assert(m_cur_shared_tensor_idx <= m_loader->m_shared_tensor_map.size());
     return ret;
 }
@@ -880,7 +889,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
     if (m_shared_tensor_map.empty()) {
         m_shared_tensor_map.resize(m_graph->nr_shared_tensor());
     } else {
-        mgb_assert(m_shared_tensor_map.size() == m_graph->nr_shared_tensor());
+        mgb_assert(m_shared_tensor_map.size() >= m_graph->nr_shared_tensor());
     }
     OprLoadContextImpl ctx{this, m_graph->mgb_version()};
...
 #if MGB_ENABLE_FBS_SERIALIZATION
+#include <map>
 #include "megbrain/comp_node_env.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/serialization/helper.h"
@@ -523,6 +524,77 @@ void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
 }
 // ----------------------------- Loader --------------------------------------
+/**
+ * SharedTensorAlignMent records all shared device tensors. Right after loading
+ * they are not aligned; once every shared device tensor has been loaded, the
+ * user-provided model memory is rewritten and each tensor is moved to an
+ * aligned address.
+ */
+class GraphLoaderOSSV2::SharedTensorAlignMent {
+public:
+    SharedTensorAlignMent(SharedBuffer buffer, InputFile* file, bool is_enabled)
+            : m_enabled(is_enabled), m_file(file), m_model_buffer(buffer){};
+
+    bool add_device_tensor(std::shared_ptr<DeviceTensorND> tensor) {
+        if (!m_enabled)
+            return false;
+        if (tensor) {
+            m_device_tensors[reinterpret_cast<intptr_t>(tensor->raw_ptr())] = tensor;
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * Move every tensor shared from m_model_buffer to an aligned address. The
+     * model buffer is modified in place, so the file cannot be used again.
+     */
+    bool reorder_and_align_tensor() {
+        if (!m_enabled)
+            return false;
+        bool modified = false;
+        intptr_t buffer_start = reinterpret_cast<intptr_t>(m_model_buffer.data());
+        intptr_t write_end = buffer_start;
+        for (auto& iter : m_device_tensors) {
+            auto& tensor = iter.second;
+            size_t tensor_size = tensor->layout().span().dist_byte();
+            size_t alignment = tensor->comp_node().get_mem_addr_alignment();
+            intptr_t tensor_start = reinterpret_cast<intptr_t>(tensor->raw_ptr());
+            intptr_t align_start = static_cast<intptr_t>(
+                    reinterpret_cast<uintptr_t>(tensor->raw_ptr()) & ~(alignment - 1));
+            if (align_start > write_end) {
+                if (tensor_start != align_start) {
+                    memmove(reinterpret_cast<void*>(align_start),
+                            reinterpret_cast<void*>(tensor_start), tensor_size);
+                    modified = true;
+                }
+                write_end = align_start + tensor_size;
+                DeviceTensorStorage storage;
+                auto raw_storage = std::shared_ptr<mgb::dt_byte>(
+                        reinterpret_cast<mgb::dt_byte*>(align_start), [](void*) {});
+                storage.reset(tensor->comp_node(), tensor_size, raw_storage);
+                tensor->reset(storage, tensor->layout());
+            } else {
+                DeviceTensorND new_tensor(tensor->comp_node());
+                new_tensor.copy_from(*tensor).sync();
+                *tensor = new_tensor;
+            }
+            if (modified) {
+                m_file->have_modified();
+            }
+        }
+        return true;
+    }
+
+private:
+    bool m_enabled = false;
+    InputFile* m_file;
+    SharedBuffer m_model_buffer;
+    std::map<intptr_t, std::shared_ptr<DeviceTensorND>> m_device_tensors;
+};
+
 CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
         const fbs::v2::CompNode* comp_node) {
     mgb_assert(comp_node);
@@ -596,7 +668,9 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
                 "serialization v2 format is pure flatbuffer format, not support "
                 "user tensor value loader callback.");
         }
-        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
+        fill_tensor_memory(
+                *ret, tensor->data()->data(), tensor->data()->size(),
+                m_loader->m_file->is_shared_memory());
     }
     if (tensor->name()) {
         m_tensor_map[tensor->name()->str()] = ret;
@@ -612,7 +686,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
 }
 std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
-        load_tensor_shared() {
+        load_tensor_shared(bool copy_immediatly) {
     mgb_assert(
             m_current_opr->tensors() &&
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
@@ -620,6 +694,9 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
     auto comp_node = load_comp_node(tensor->comp_node());
     auto layout = load_tensor_layout(tensor, comp_node);
     mgb_assert(tensor->data());
+    if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
+        m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);
+    }
     auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
     auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
     if (shared_tensor_ref) {
@@ -637,19 +714,34 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
             // directly forward CPU memory
-            shared_tensor_ref = std::make_shared<DeviceTensorND>();
             HostTensorND hv{comp_node};
             if (tensor->data() && tensor->data()->size() > 0) {
                 hv.dtype(layout.dtype).resize(layout);
-                memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
+                fill_tensor_memory(
+                        hv, tensor->data()->data(), tensor->data()->size(),
+                        m_loader->m_file->is_shared_memory());
             }
+            shared_tensor_ref = std::make_shared<DeviceTensorND>();
             *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
+            m_tensor_alignment->add_device_tensor(shared_tensor_ref);
+        } else if (copy_immediatly) {
+            HostTensorND hv{CompNode::default_cpu()};
+            shared_tensor_ref = std::make_shared<DeviceTensorND>();
+            if (tensor->data() && tensor->data()->size() > 0) {
+                hv.dtype(layout.dtype).resize(layout);
+                fill_tensor_memory(
+                        hv, tensor->data()->data(), tensor->data()->size(),
+                        m_loader->m_file->is_shared_memory());
+            }
+            shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
         } else {
             // use lazy load for non-CPU devices
             HostTensorND hv{CompNode::default_cpu()};
             if (tensor->data() && tensor->data()->size() > 0) {
                 hv.dtype(layout.dtype).resize(layout);
-                memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
+                fill_tensor_memory(
+                        hv, tensor->data()->data(), tensor->data()->size(),
+                        m_loader->m_file->is_shared_memory());
             }
             shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
         }
@@ -784,7 +876,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
         ret.output_var_map_id[out->original_id()] = var;
         ret.output_var_list[i] = var;
     }
-    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
+    mgb_assert(m_cur_shared_tensor_idx <= m_loader->m_shared_tensor_map.size());
     return ret;
 }
@@ -808,7 +900,6 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
     m_file->read(&size, sizeof(size));
     m_file->skip(-sizeof(size));
     m_model_buf = m_file->read_shared(size + sizeof(size));
     {
         flatbuffers::Verifier verifier(
                 static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
@@ -838,8 +929,10 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
     } else {
         mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
     }
-    OprLoadContextImpl ctx{this, m_model->mge_version()};
+    SharedTensorAlignMent tensor_alignment(
+            m_model_buf, m_file.get(),
+            m_file->writable() && m_file->is_shared_memory());
+    OprLoadContextImpl ctx{this, &tensor_alignment, m_model->mge_version()};
     ctx.load_middle_tensor();
     auto metadata = ctx.load_metadata();
     auto result = ctx.load_oprs();
@@ -856,6 +949,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
         }
     }
     m_model_loaded = true;
+    tensor_alignment.reorder_and_align_tensor();
     result.graph_compile_ahead();
     return result;
 }
...
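The alignment pass above relies on rounding a tensor's current address inside the shared model buffer down to the previous alignment boundary before memmove'ing it into place. A simplified sketch of that computation, assuming a power-of-two alignment and using an illustrative helper name:

    #include <cstddef>
    #include <cstdint>

    // Round `addr` down to the previous `alignment`-byte boundary
    // (power-of-two alignment assumed), as done in reorder_and_align_tensor().
    std::uintptr_t align_down(std::uintptr_t addr, std::size_t alignment) {
        return addr & ~static_cast<std::uintptr_t>(alignment - 1);
    }

    // Example: align_down(0x1007, 0x10) == 0x1000. The tensor bytes are then
    // moved from 0x1007 to 0x1000 when 0x1000 does not overlap the previous
    // tensor's end; otherwise the tensor falls back to an ordinary aligned copy.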
@@ -41,6 +41,15 @@ public:
     //! return current read offset
     virtual size_t tell() = 0;
+    //! whether this file supports sharing its memory when loading a model
+    virtual bool is_shared_memory() { return false; }
+    //! whether this file can be written to
+    virtual bool writable() { return false; }
+    //! mark that this file has been modified
+    virtual void have_modified() {}
     /*!
      * \brief read into a host tensor
      *
...
@@ -208,7 +208,8 @@ public:
      *
      * It must be dumped with TensorWriteMethod::VALUE_SHARED
      */
-    virtual std::shared_ptr<DeviceTensorND> load_tensor_shared() = 0;
+    virtual std::shared_ptr<DeviceTensorND> load_tensor_shared(
+            bool copy_immediatly = false) = 0;
     //! get associated global configuration
     virtual const GraphLoadConfig& config() const = 0;
...
@@ -104,6 +104,7 @@ class GraphLoaderOSSV2 final : public GraphLoader {
 public:
     class OprLoadContextImpl;
+    class SharedTensorAlignMent;
     friend class OprLoadContextImpl;
     GraphLoaderOSSV2(std::unique_ptr<InputFile> input_file)
@@ -136,22 +137,51 @@ class GraphLoaderOSSV2::OprLoadContextImpl final : public OprLoadContextFlatBuff
     size_t m_cur_opr_tensor_cnt;
     size_t m_cur_opr_blob_cnt;
     size_t m_cur_opr_param_cnt;
+    SharedTensorAlignMent* m_tensor_alignment;
 public:
+    friend class SharedTensorAlignMent;
     ComputingGraph& graph() override { return *m_graph; }
     const GraphLoadConfig& config() const override {
         return *m_loader->m_cur_load_config;
     }
+    //! share or copy the loaded flatbuffer memory into the CPU tensor; this
+    //! reduces the memory used while loading the model, but the memory
+    //! alignment has to be taken into account
+    void fill_tensor_memory(
+            HostTensorND& tensor, const uint8_t* data, size_t size, bool shared) {
+        auto tensor_size = tensor.layout().span().high_byte;
+        mgb_assert(
+                size == tensor_size,
+                "the size does not match when sharing the flatbuffer memory\n");
+        auto ptr = reinterpret_cast<void*>(const_cast<uint8_t*>(data));
+        if (shared) {
+            HostTensorStorage storage;
+            auto raw_storage = std::shared_ptr<mgb::dt_byte>(
+                    static_cast<mgb::dt_byte*>(ptr), [](void*) {});
+            storage.reset(tensor.comp_node(), size, raw_storage);
+            tensor.reset(storage, tensor.layout());
+        } else {
+            memcpy(tensor.raw_ptr(), data, size);
+        }
+    }
     std::shared_ptr<HostTensorND> load_tensor() override;
-    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;
+    std::shared_ptr<DeviceTensorND> load_tensor_shared(
+            bool copy_immediatly = false) override;
     void load_single_opr(const fbs::v2::Operator* opr);
-    OprLoadContextImpl(GraphLoaderOSSV2* loader, uint32_t version)
-            : OprLoadContextFlatBuffers(version), m_loader{loader} {
+    OprLoadContextImpl(
+            GraphLoaderOSSV2* loader, SharedTensorAlignMent* tensor_alignment,
+            uint32_t version)
+            : OprLoadContextFlatBuffers(version),
+              m_loader{loader},
+              m_tensor_alignment(tensor_alignment) {
         m_graph = loader->m_cur_load_config->comp_graph;
         if (!m_graph) {
             m_graph = ComputingGraph::make();
...
@@ -315,7 +315,7 @@ void test_serializer_custom_loader(GraphDumpFormat format) {
     load();
     load();
     ASSERT_EQ(2u, saved_val.size());
-    ASSERT_EQ(1, load_nr_null_ptr);  // immutable tensor is not shared
+    ASSERT_EQ(2, load_nr_null_ptr);  // immutable tensor is also shared
     ASSERT_EQ(4, load_nr_call);
 }
@@ -482,10 +482,10 @@ void test_serializer_multiple_param(GraphDumpFormat format) {
     ASSERT_THROW(loader->shared_tensor_id_map(), MegBrainError);
     loader->load();
     auto&& got = loader->shared_tensor_id_map();
-    ASSERT_EQ(values.size(), got.size());
+    ASSERT_EQ(2 * values.size(), got.size());
     for (size_t i = 0; i < values.size(); ++i) {
         ASSERT_EQ(1u, got[i].second.size());
-        auto &&vi = *values[i], &&gi = *got[i].second.begin()->second;
+        auto &&vi = *values[i], &&gi = *got[2 * i].second.begin()->second;
         ASSERT_EQ(vi.shape(), gi.shape());
         ASSERT_EQ(vi.comp_node(), gi.comp_node());
         ASSERT_EQ(vi.dtype(), gi.dtype());
@@ -565,7 +565,7 @@ void test_serializer_const_var_shape(GraphDumpFormat format) {
             }
         };
         run_and_check(config);
-        ASSERT_EQ(2, nr_tensor);
+        ASSERT_EQ(1, nr_tensor);  // immutable tensor is a shared tensor
         ASSERT_EQ(1, nr_mod);
     }
 }
@@ -823,6 +823,77 @@ void test_serializer_log_exp(GraphDumpFormat format) {
     load();
 }
+void test_serializer_memshare(GraphDumpFormat format) {
+    std::vector<uint8_t> buf;
+    HostTensorGenerator<> gen;
+    constexpr size_t SIZE = 127;
+    auto xval = gen({SIZE}, "cpu0"), bval = gen({1}, "cpu0");
+    auto dump = [&]() {
+        auto graph = ComputingGraph::make();
+        auto x0 = opr::SharedDeviceTensor::make(*graph, *xval).rename("x0");
+        auto x1 = opr::SharedDeviceTensor::make(*graph, *xval).rename("x1");
+        auto x2 = opr::SharedDeviceTensor::make(*graph, *xval).rename("x2");
+        auto x3 = opr::SharedDeviceTensor::make(*graph, *xval).rename("x3");
+        auto i4 = opr::ImmutableTensor::make(*graph, *xval).rename("i4");
+        auto i5 = opr::ImmutableTensor::make(*graph, *xval).rename("i5");
+        auto b = opr::SharedDeviceTensor::make(*graph, *bval).rename("b");
+        auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf), format);
+        dumper->dump({((x0 + x1) + b) + (x2 + x3) + i4 + i5, x0, i4});
+    };
+
+    HostTensorND expected;
+    expected.copy_from(*xval);
+    for (size_t i = 0; i < SIZE; ++i) {
+        auto&& v = expected.ptr<float>()[i];
+        v = v * 6 + bval->ptr<float>()[0];
+    }
+
+    std::vector<uint8_t> buf_al;
+    auto load = [&](bool share) {
+        std::unique_ptr<InputFile> fin;
+        if (share) {
+            buf_al.resize(buf.size());
+            memcpy(buf_al.data(), buf.data(), buf.size());
+            fin = InputFile::make_mem_proxy(
+                    std::shared_ptr<void>{std::shared_ptr<void>{}, buf_al.data()},
+                    buf.size());
+        } else {
+            fin = InputFile::make_mem_proxy(buf.data(), buf.size());
+        }
+        auto loader = GraphLoader::make(std::move(fin), format);
+        auto rst = loader->load();
+        auto x = rst.output_var_map.at("x0");
+        auto i4 = rst.output_var_map.at("i4");
+        auto&& opr = x.node()->owner_opr()->cast_final_safe<opr::SharedDeviceTensor>();
+        auto&& opr_imm =
+                i4.node()->owner_opr()->cast_final_safe<opr::ImmutableTensor>();
+        HostTensorND val;
+        auto func =
+                rst.graph_compile({make_callback_copy(rst.output_var_list[0], val)});
+        func->execute();
+        return std::make_pair(
+                val, std::vector<DeviceTensorND>{*opr.dev_data(), opr_imm.value()});
+    };
+
+    auto in_range = [](const std::vector<uint8_t>& buf, DeviceTensorND& dv) {
+        auto p0 = reinterpret_cast<uint8_t*>(dv.raw_ptr()),
+             p1 = reinterpret_cast<uint8_t*>(p0 + dv.layout().span().high_byte);
+        return buf.data() <= p0 && p1 <= buf.data() + buf.size();
+    };
+
+    for (bool share : {false, true}) {
+        buf.clear();
+        dump();
+        auto get = load(share);
+        MGB_ASSERT_TENSOR_EQ(*xval, HostTensorND{}.copy_from(get.second[0]).sync());
+        MGB_ASSERT_TENSOR_EQ(expected, get.first);
+        ASSERT_EQ(share, in_range(buf_al, get.second[0]));
+        ASSERT_EQ(share, in_range(buf_al, get.second[1]));
+    }
+}
 } // namespace
 TEST(TestSerializer2, GraphDumpLoad) {
@@ -967,6 +1038,10 @@ TEST(TestSerializer2, LOGEXPV2) {
     test_serializer_log_exp(GraphDumpFormat::FLATBUFFERS_V2);
 }
+TEST(TestSerializer2, ShareMemv2) {
+    test_serializer_memshare(GraphDumpFormat::FLATBUFFERS_V2);
+}
 TEST(TestSerializer2, TestSoftMaxLoadDump) {
     auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2);
     TensorShape shape{2, 3};
...