MegEngine 天元 / MegEngine

Commit f7d2017e
Authored Oct 25, 2022 by Megvii Engine Team

fix(serialization): add tensor value loader support in new format

GitOrigin-RevId: e7da1d239669277e18d44d23c557abc39f2ac55f
Parent: da7f250c

Showing 3 changed files with 44 additions and 28 deletions (+44 -28)
src/serialization/impl/serializer_oss_v2.cpp                            +16 -15
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h    +20 -11
src/serialization/test/serializer_oss.cpp                                +8  -2
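With this change the custom tensor value callbacks configured on GraphDumpConfig / GraphLoadConfig are actually honored by the FLATBUFFERS_V2 serializer instead of only triggering a warning. The sketch below is not part of the commit; it shows one plausible way to wire the callbacks end to end, assuming the usual factory entry points (GraphDumper::make, GraphLoader::make, OutputFile::make_fs, InputFile::make_fs) and the callback parameter shapes implied by the call sites in the hunks that follow. The file name and the pass-through callback bodies are illustrative only.

// Hedged sketch (not from this commit); "model_v2.mge" and the helper name are
// assumptions for illustration.
#include "megbrain/serialization/serializer.h"

using namespace mgb;
using namespace mgb::serialization;

void dump_and_load_with_custom_value_codec(SymbolVar output) {
    // Dump with FLATBUFFERS_V2; the callback decides how the value bytes are
    // produced, and (after this commit) its output is embedded in the model.
    auto dumper = GraphDumper::make(
            OutputFile::make_fs("model_v2.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphDumpConfig dump_config;
    dump_config.tensor_value_dumper = [](OutputFile& fout,
                                         const cg::OperatorNodeBase& /*opr*/,
                                         const HostTensorND& value) {
        // illustrative body: store the raw bytes unchanged
        fout.write(value.raw_ptr(), value.layout().span().high_byte);
    };
    dumper->dump({output}, dump_config);

    // Load with the matching decoder; it receives the stored bytes through an
    // in-memory InputFile and fills the destination buffer.
    auto loader = GraphLoader::make(
            InputFile::make_fs("model_v2.mge"), GraphDumpFormat::FLATBUFFERS_V2);
    GraphLoadConfig load_config;
    load_config.tensor_value_loader = [](void* ptr, const TensorLayout& layout,
                                         InputFile& fin) {
        if (ptr)  // a null ptr means the value may be skipped
            fin.read(ptr, layout.span().high_byte);
    };
    auto rst = loader->load(load_config);
    MGB_MARK_USED_VAR(rst);
}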
src/serialization/impl/serializer_oss_v2.cpp
@@ -513,13 +513,18 @@ void GraphDumperOSSV2::dump_tensor(
         check_tensor_value_valid(name, tensor);
         auto&& dumper = m_config.tensor_value_dumper;
         if (dumper) {
             mgb_log_warn(
                     "serialization v2 format is pure flatbuffer format, not support "
                     "user tensor value dumper callback.");
+            std::vector<uint8_t> out_vec;
+            auto temp_out_file = OutputFile::make_vector_proxy(&out_vec);
+            dumper(*temp_out_file, *m_cur_opr, tensor);
+            data = m_builder.CreateVector(
+                    reinterpret_cast<uint8_t*>(out_vec.data()), out_vec.size());
+            m_cur_rst.tensor_value_bytes += out_vec.size();
+        } else {
+            data = m_builder.CreateVector(
+                    reinterpret_cast<uint8_t*>(tensor.raw_ptr()),
+                    layout.span().high_byte);
+            m_cur_rst.tensor_value_bytes += layout.span().high_byte;
         }
-        data = m_builder.CreateVector(
-                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
-        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
     }
     auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
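The branch added above forwards whatever the user callback writes into an in-memory OutputFile (OutputFile::make_vector_proxy) and embeds those bytes verbatim in the flatbuffer, so the callback keeps full control over the on-disk value encoding. Below is a minimal, hedged sketch of such a tensor_value_dumper, compatible with the call `dumper(*temp_out_file, *m_cur_opr, tensor)`; the statistics map and helper name are assumptions for illustration, not part of MegEngine.

#include <cstddef>
#include <map>
#include <string>

#include "megbrain/serialization/serializer.h"

inline void install_value_dumper(
        mgb::serialization::GraphDumpConfig& config,
        std::map<std::string, size_t>& bytes_per_opr) {
    config.tensor_value_dumper = [&bytes_per_opr](
                                         mgb::serialization::OutputFile& fout,
                                         const mgb::cg::OperatorNodeBase& opr,
                                         const mgb::HostTensorND& value) {
        size_t nr_bytes = value.layout().span().high_byte;
        // record how many value bytes each operator contributes
        bytes_per_opr[opr.name()] += nr_bytes;
        // keep the stored representation identical to the default raw dump
        fout.write(value.raw_ptr(), nr_bytes);
    };
}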
@@ -688,14 +693,9 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
     auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
     if (tensor->data() && tensor->data()->size() > 0) {
-        if (loader) {
-            mgb_log_warn(
-                    "serialization v2 format is pure flatbuffer format, not support "
-                    "user tensor value loader callback.");
-        }
         fill_tensor_memory(
                 *ret, tensor->data()->data(), tensor->data()->size(),
-                m_loader->m_file->is_shared_memory());
+                m_loader->m_file->is_shared_memory(), loader);
     }
     if (tensor->name()) {
         m_tensor_map[tensor->name()->str()] = ret;
@@ -737,6 +737,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         shared_pair.first = tensor->name()->str();
     }
+    auto loader = m_loader->m_cur_load_config->tensor_value_loader;
     if (comp_node.mem_node() == CompNode::default_cpu().mem_node() ||
         copy_immediatly) {
         // directly forward CPU memory
         shared_tensor_ref = std::make_shared<DeviceTensorND>();
@@ -745,7 +746,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         hv.dtype(layout.dtype).resize(layout);
         fill_tensor_memory(
                 hv, tensor->data()->data(), tensor->data()->size(),
-                m_loader->m_file->is_shared_memory());
+                m_loader->m_file->is_shared_memory(), loader);
     }
     if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
         *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
@@ -761,7 +762,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             hv.dtype(layout.dtype).resize(layout);
             fill_tensor_memory(
                     hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
+                    m_loader->m_file->is_shared_memory(), loader);
         }
         shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
     }
@@ -947,7 +948,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
     if (m_shared_tensor_map.empty()) {
         m_shared_tensor_map.resize(m_model->nr_shared_tensor());
     } else {
-        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
+        mgb_assert(m_shared_tensor_map.size() >= m_model->nr_shared_tensor());
     }
     SharedTensorAlignMent tensor_alignment(
             m_model_buf, m_file.get(),
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
@@ -154,20 +154,29 @@ public:
     //! the memory used when load model, but should consider the memory
     //! alignment
     void fill_tensor_memory(
-            HostTensorND& tensor, const uint8_t* data, size_t size, bool shared) {
-        auto tensor_size = tensor.layout().span().high_byte;
-        mgb_assert(
-                size == tensor_size,
-                "the size is not match when shared the flatbuffer memory\n");
-        auto ptr = reinterpret_cast<void*>(const_cast<uint8_t*>(data));
-        if (shared) {
-            HostTensorStorage storage;
-            auto raw_storage = std::shared_ptr<mgb::dt_byte>(
-                    static_cast<mgb::dt_byte*>(ptr), [](void*) {});
-            storage.reset(tensor.comp_node(), size, raw_storage);
-            tensor.reset(storage, tensor.layout());
-        } else {
-            memcpy(tensor.raw_ptr(), data, size);
-        }
-    }
+            HostTensorND& tensor, const uint8_t* data, size_t size, bool shared,
+            GraphLoadConfig::TensorValueLoader loader) {
+        auto tensor_size = tensor.layout().span().high_byte;
+        auto ptr = reinterpret_cast<void*>(const_cast<uint8_t*>(data));
+        if (loader) {
+            // call custom loader
+            void* dest_ptr = tensor.raw_ptr();
+            auto input_file = InputFile::make_mem_proxy(data, size);
+            loader(dest_ptr, tensor.layout(), *input_file);
+        } else {
+            mgb_assert(
+                    size == tensor_size,
+                    "the size is not match when shared the flatbuffer memory\n");
+            if (shared) {
+                HostTensorStorage storage;
+                auto raw_storage = std::shared_ptr<mgb::dt_byte>(
+                        static_cast<mgb::dt_byte*>(ptr), [](void*) {});
+                storage.reset(tensor.comp_node(), size, raw_storage);
+                tensor.reset(storage, tensor.layout());
+            } else {
+                memcpy(tensor.raw_ptr(), data, size);
+            }
+        }
+    }
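With the extra parameter, the V2 loader now hands the callback the destination pointer, the expected layout, and an InputFile wrapping the bytes stored in the flatbuffer (via InputFile::make_mem_proxy above). A minimal matching GraphLoadConfig::TensorValueLoader is sketched below; it simply copies the stored bytes, mirroring the default branch. The null-pointer check follows the convention exercised by load_nr_null_ptr in the test file below; the helper name is an assumption for illustration.

#include "megbrain/serialization/serializer.h"

inline void install_value_loader(mgb::serialization::GraphLoadConfig& config) {
    config.tensor_value_loader = [](void* ptr, const mgb::TensorLayout& layout,
                                    mgb::serialization::InputFile& fin) {
        if (ptr) {
            // plain copy: equivalent to the memcpy branch of fill_tensor_memory
            fin.read(ptr, layout.span().high_byte);
        }
        // a null ptr means the caller does not need this value filled; with the
        // per-value memory proxy used above there is nothing to skip here
    };
}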
src/serialization/test/serializer_oss.cpp
@@ -315,8 +315,10 @@ void test_serializer_custom_loader(GraphDumpFormat format) {
     load();
     load();
     ASSERT_EQ(2u, saved_val.size());
-    ASSERT_EQ(2, load_nr_null_ptr);
-    // immutable tensor is also shared
-    ASSERT_EQ(4, load_nr_call);
+    if (GraphDumpFormat::FLATBUFFERS_V2 != format) {
+        ASSERT_EQ(2, load_nr_null_ptr);
+        // immutable tensor is also shared
+        ASSERT_EQ(4, load_nr_call);
+    }
 }
 
 void test_serializer_many_io_var(GraphDumpFormat format) {
@@ -998,6 +1000,10 @@ TEST(TestSerializer2, ManyIOVarsV2) {
     test_serializer_many_io_var(GraphDumpFormat::FLATBUFFERS_V2);
 }
 
+TEST(TestSerializer2, CustomLoaderV2) {
+    test_serializer_custom_loader(GraphDumpFormat::FLATBUFFERS_V2);
+}
+
 TEST(TestSerializer2, RemoveSetGradV2) {
     test_serializer_remove_set_grad(GraphDumpFormat::FLATBUFFERS_V2);
 }