Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a694fb33
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a694fb33
编写于
5月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(serialization): implement the new serialization format
GitOrigin-RevId: 00f87f7ccdae7d313d8a97108c3ae712b1c27da3
上级
ca4a5da0
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
1210 addition
and
52 deletion
+1210
-52
src/serialization/impl/batched_device_value_loader.cpp
src/serialization/impl/batched_device_value_loader.cpp
+1
-2
src/serialization/impl/serializer.cpp
src/serialization/impl/serializer.cpp
+17
-0
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+2
-49
src/serialization/impl/serializer_oss_common.cpp
src/serialization/impl/serializer_oss_common.cpp
+39
-0
src/serialization/impl/serializer_oss_common.h
src/serialization/impl/serializer_oss_common.h
+32
-0
src/serialization/impl/serializer_oss_v2.cpp
src/serialization/impl/serializer_oss_v2.cpp
+847
-0
src/serialization/include/megbrain/serialization/batched_device_value_loader.h
...lude/megbrain/serialization/batched_device_value_loader.h
+0
-0
src/serialization/include/megbrain/serialization/dump_format.h
...erialization/include/megbrain/serialization/dump_format.h
+1
-0
src/serialization/include/megbrain/serialization/opr_load_dump.h
...ialization/include/megbrain/serialization/opr_load_dump.h
+14
-1
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
...zation/include/megbrain/serialization/oss_opr_load_dump.h
+257
-0
未找到文件。
src/serialization/impl/batched_device_value_loader.cpp
浏览文件 @
a694fb33
#include "batched_device_value_loader.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/arith_helper.h"
namespace
mgb
{
namespace
mgb
{
...
...
src/serialization/impl/serializer.cpp
浏览文件 @
a694fb33
...
@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
...
@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
}
}
std
::
unique_ptr
<
GraphLoader
>
make_fbs_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphLoader
>
make_fbs_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
bool
is_fbs_file
(
InputFile
&
file
);
bool
is_fbs_file
(
InputFile
&
file
);
bool
is_fbs_v2_file
(
InputFile
&
file
);
bool
GraphDumper
::
should_remove_in_dump
(
cg
::
OperatorNodeBase
*
opr
)
{
bool
GraphDumper
::
should_remove_in_dump
(
cg
::
OperatorNodeBase
*
opr
)
{
#if MGB_ENABLE_GRAD
#if MGB_ENABLE_GRAD
...
@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
...
@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
case
GraphDumpFormat
::
FLATBUFFERS
:
case
GraphDumpFormat
::
FLATBUFFERS
:
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_dumper
(
std
::
move
(
file
));
return
make_fbs_dumper
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_v2_dumper
(
std
::
move
(
file
));
#endif
#endif
MGB_FALLTHRU
MGB_FALLTHRU
default:
default:
...
@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
...
@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
case
GraphDumpFormat
::
FLATBUFFERS
:
case
GraphDumpFormat
::
FLATBUFFERS
:
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_loader
(
std
::
move
(
file
));
return
make_fbs_loader
(
std
::
move
(
file
));
#endif
MGB_FALLTHRU
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_v2_loader
(
std
::
move
(
file
));
#endif
#endif
MGB_FALLTHRU
MGB_FALLTHRU
default:
default:
...
@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
...
@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
if
(
is_fbs_file
(
file
))
{
if
(
is_fbs_file
(
file
))
{
return
GraphDumpFormat
::
FLATBUFFERS
;
return
GraphDumpFormat
::
FLATBUFFERS
;
}
}
if
(
is_fbs_v2_file
(
file
))
{
return
GraphDumpFormat
::
FLATBUFFERS_V2
;
}
#endif
#endif
return
{};
return
{};
}
}
...
...
src/serialization/impl/serializer_oss.cpp
浏览文件 @
a694fb33
...
@@ -11,17 +11,16 @@
...
@@ -11,17 +11,16 @@
*/
*/
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
#include "batched_device_value_loader.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/serializer.h"
#include "
megbrain/versi
on.h"
#include "
serializer_oss_comm
on.h"
#include <flatbuffers/flatbuffers.h>
#include <flatbuffers/flatbuffers.h>
...
@@ -33,47 +32,8 @@ using namespace mgb;
...
@@ -33,47 +32,8 @@ using namespace mgb;
using
namespace
mgb
::
serialization
;
using
namespace
mgb
::
serialization
;
namespace
{
namespace
{
constexpr
uint32_t
MGB_VERSION
=
(
MGE_MAJOR
*
1000
+
MGE_MINOR
)
*
100
+
MGE_PATCH
;
constexpr
uint32_t
MGB_MAGIC
=
0x4342474D
;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr
uint32_t
MAGIC_V0
=
0x5342474D
;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool
magic_compare
=
true
;
bool
magic_compare
=
true
;
template
<
typename
T
>
bool
contains_any_in_set
(
const
SmallVector
<
T
>&
list
,
const
ThinHashSet
<
T
>&
set
)
{
for
(
const
auto
&
x
:
list
)
{
if
(
set
.
count
(
x
))
{
return
true
;
}
}
return
false
;
}
void
check_tensor_value_valid
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
)
{
bool
cond_normal
=
tensor
.
layout
().
format
.
is_default
()
&&
tensor
.
layout
().
is_physical_contiguous
();
bool
cond_lowbit
=
tensor
.
layout
().
dtype
.
is_quantized_lowbit
()
&&
tensor
.
layout
().
format
.
is_lowbit_aligned
()
&&
tensor
.
layout
().
is_contiguous
();
mgb_assert
(
cond_normal
||
cond_lowbit
,
"non-contiguous tensor: name=%s layout=%s"
,
name
.
c_str
(),
tensor
.
layout
().
to_string
().
c_str
());
if
(
tensor
.
dtype
()
==
dtype
::
Float32
())
{
auto
ptr
=
tensor
.
ptr
<
float
>
();
for
(
size_t
i
=
0
,
it
=
tensor
.
shape
().
total_nr_elems
();
i
<
it
;
++
i
)
{
if
(
!
std
::
isfinite
(
ptr
[
i
]))
{
mgb_log_warn
(
"invalid tensor value in %s: %g"
,
name
.
c_str
(),
ptr
[
i
]);
break
;
}
}
}
}
//! feature bits for backward compatibility; default value should be 0
//! feature bits for backward compatibility; default value should be 0
struct
FeatureBits64
{
struct
FeatureBits64
{
//! reserved for new fields
//! reserved for new fields
...
@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
...
@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
return
std
::
make_unique
<
GraphLoaderOSS
>
(
std
::
move
(
file
));
return
std
::
make_unique
<
GraphLoaderOSS
>
(
std
::
move
(
file
));
}
}
bool
is_fbs_file
(
InputFile
&
file
)
{
uint64_t
magic_with_reserved
=
0
;
file
.
read
(
&
magic_with_reserved
,
sizeof
(
magic_with_reserved
));
file
.
skip
(
-
sizeof
(
magic_with_reserved
));
return
(
magic_with_reserved
==
MGB_MAGIC
)
||
(
magic_with_reserved
==
MAGIC_V0
);
}
}
// namespace serialization
}
// namespace serialization
}
// namespace mgb
}
// namespace mgb
...
...
src/serialization/impl/serializer_oss_common.cpp
0 → 100644
浏览文件 @
a694fb33
#if MGB_ENABLE_FBS_SERIALIZATION
#include "serializer_oss_common.h"
namespace
mgb
{
namespace
serialization
{
/*!
 * \brief check whether \p file starts with one of the known FlatBuffers model
 *      magics (MGB_MAGIC or the legacy MAGIC_V0)
 *
 * The first 8 bytes are read and the file is then skipped back by the same
 * amount. The whole 64-bit word is compared with the 32-bit magic constants,
 * so a match also requires the rest of the word to be zero.
 *
 * \return true if the leading word equals MGB_MAGIC or MAGIC_V0
 */
bool is_fbs_file(InputFile& file) {
    uint64_t magic_with_reserved = 0;
    file.read(&magic_with_reserved, sizeof(magic_with_reserved));
    // rewind with an explicitly signed offset instead of negating the unsigned
    // result of sizeof and relying on the wrap-around conversion
    file.skip(-static_cast<int64_t>(sizeof(magic_with_reserved)));
    return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
/*!
 * \brief sanity check for a host tensor
 *
 * Asserts that the layout is either default-format and physically contiguous,
 * or a quantized low-bit dtype in a lowbit-aligned format that is contiguous.
 * For Float32 tensors the elements are scanned and a warning is logged for the
 * first non-finite value; scanning stops there.
 *
 * \param name tensor name, only used in the assertion / warning messages
 * \param tensor tensor to check
 */
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
    const auto& layout = tensor.layout();
    const bool cond_normal =
            layout.format.is_default() && layout.is_physical_contiguous();
    const bool cond_lowbit = layout.dtype.is_quantized_lowbit() &&
                             layout.format.is_lowbit_aligned() &&
                             layout.is_contiguous();
    mgb_assert(
            cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
            name.c_str(), tensor.layout().to_string().c_str());

    // only float32 values are inspected
    if (!(tensor.dtype() == dtype::Float32())) {
        return;
    }

    const auto values = tensor.ptr<float>();
    const size_t nr_elems = tensor.shape().total_nr_elems();
    for (size_t idx = 0; idx < nr_elems; ++idx) {
        if (std::isfinite(values[idx])) {
            continue;
        }
        // report the first offending element only
        mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), values[idx]);
        break;
    }
}
}
// namespace serialization
}
// namespace mgb
#endif
src/serialization/impl/serializer_oss_common.h
0 → 100644
浏览文件 @
a694fb33
#pragma once

#if MGB_ENABLE_FBS_SERIALIZATION

#include <cstdint>
#include <string>

#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"

namespace mgb {
namespace serialization {

//! MegEngine version packed into one integer:
//! (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;

//! current magic number of the FlatBuffers model format
constexpr uint32_t MGB_MAGIC = 0x4342474D;

// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic (MAGIC_V0) value and create a new magic (MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;

/*!
 * \brief assert that \p tensor has a contiguous layout and warn about the first
 *      non-finite float32 value
 * \param name tensor name used in the diagnostic messages
 */
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor);

/*!
 * \brief whether at least one element of \p list is contained in \p set
 * \return true on the first element found in \p set; false if none is found
 *      (including when \p list is empty)
 */
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
    for (const auto& x : list) {
        if (set.count(x)) {
            return true;
        }
    }
    return false;
}

}  // namespace serialization
}  // namespace mgb

#endif
src/serialization/impl/serializer_oss_v2.cpp
0 → 100644
浏览文件 @
a694fb33
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
#include "megdnn/tensor_format.h"
#include "serializer_oss_common.h"
#include "megbrain/gopt/framework.h"
namespace
mgb
{
namespace
serialization
{
/*!
* \brief replace the the opr who has the replace_opr methord in OprLoadDumpImplV2
*/
class
PassConvertToCompatible
:
public
gopt
::
Pass
{
ThinHashMap
<
Typeinfo
*
,
thin_function
<
cg
::
OperatorNodeBase
*
(
cg
::
OperatorNodeBase
*
,
const
VarNodeArray
&
)
>>
m_opr_replace_func
;
gopt
::
VarReplaceCheckFlag
m_var_replace_check_flag
=
gopt
::
VarReplaceCheckFlag
::
CHECK_ALL
;
public:
const
char
*
name
()
const
override
{
return
"PassConvertToCompatible"
;
};
PassConvertToCompatible
&
set_var_replace_check_flag
(
gopt
::
VarReplaceCheckFlag
flag
)
{
m_var_replace_check_flag
=
flag
;
return
*
this
;
}
void
apply
(
gopt
::
OptState
&
state
)
const
override
{
state
.
set_var_replace_check_flag
(
m_var_replace_check_flag
);
auto
rewriter
=
state
.
graph
().
make_rewriter
();
auto
on_opr
=
[
this
,
&
rewriter
](
cg
::
OperatorNodeBase
*
opr
)
{
auto
it
=
m_opr_replace_func
.
find
(
opr
->
dyn_typeinfo
());
if
(
it
!=
m_opr_replace_func
.
end
())
{
VarNodeArray
new_inp
;
new_inp
.
clear
();
new_inp
.
reserve
(
opr
->
input
().
size
());
for
(
auto
i
:
opr
->
input
())
{
new_inp
.
push_back
(
rewriter
.
get_var
(
i
));
}
auto
new_opr
=
(
it
->
second
)(
opr
,
new_inp
);
auto
&&
origin_out
=
opr
->
output
(),
&&
cur_out
=
new_opr
->
output
();
for
(
size_t
i
=
0
;
i
<
std
::
min
(
origin_out
.
size
(),
cur_out
.
size
());
i
++
)
{
rewriter
.
replace_var
(
origin_out
[
i
],
cur_out
[
i
],
nullptr
);
}
}
else
{
rewriter
.
auto_replace_outputs
(
opr
);
}
};
state
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
}
static
std
::
unique_ptr
<
PassConvertToCompatible
>
make
(
const
SymbolVarArray
&
output_vars
)
{
auto
ret
=
std
::
make_unique
<
PassConvertToCompatible
>
();
// iterate oprs to init
auto
on_opr
=
[
&
](
cg
::
OperatorNodeBase
*
opr
)
{
if
(
!
GraphDumper
::
should_remove_in_dump
(
opr
))
{
auto
registry
=
OprRegistryV2
::
versioned_find_by_typeinfo
(
opr
->
dyn_typeinfo
(),
CURRENT_VERSION
);
mgb_throw_if
(
!
registry
,
cg
::
OperatorNodeExcExtraInfo
::
ExcMaker
{
opr
}.
make
<
MegBrainError
>
,
"serialization as FlatBuffers is not supported for "
"operator %s, typeinfo %p"
,
opr
->
dyn_typeinfo
()
->
name
,
opr
->
dyn_typeinfo
());
if
(
registry
->
converter
)
{
ret
->
m_opr_replace_func
[
opr
->
dyn_typeinfo
()]
=
registry
->
converter
;
}
}
};
cg
::
DepOprIter
dep_opr_iter
{
on_opr
};
for
(
auto
i
:
output_vars
)
{
dep_opr_iter
.
add
(
i
.
node
()
->
owner_opr
());
}
return
ret
;
};
};
namespace {
/*!
 * \brief map a megdnn tensor format to the matching fbs::v2::TensorFormat
 *      union tag
 *
 * Handles DEFAULT, IMAGE2D_PACK4 and LOWBITS_ALIGNED_TO_BYTE; any other type
 * raises SerializationError.
 */
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    const auto format_kind = format.type();
    if (format_kind == Type::DEFAULT) {
        return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
    }
    if (format_kind == Type::IMAGE2D_PACK4) {
        return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
    }
    if (format_kind == Type::LOWBITS_ALIGNED_TO_BYTE) {
        return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
    }
    mgb_throw(SerializationError, "invalid tensor format type in serialization.");
}
}  // namespace
/*!
 * \brief serialize \p dtype into this dumper's FlatBuffers builder
 *
 * Thin forwarder to fbs::intl::build_dtype() using m_builder.
 * \return offset of the created fbs::DType object
 */
flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}
/*!
 * \brief serialize the format-specific payload of \p format
 *
 * Creates the FlatBuffers table that corresponds to the format type and
 * returns it as an untyped union offset. DEFAULT carries no fields,
 * IMAGE2D_PACK4 stores align_axis, LOWBITS_ALIGNED_TO_BYTE stores size_nbits
 * and align_size_in_bits. Any other type raises SerializationError.
 */
flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT: {
            auto fb_default = fbs::v2::CreateDefaultTensorFormat(m_builder);
            return fb_default.Union();
        }
        case Type::IMAGE2D_PACK4: {
            const auto& image_impl =
                    format.as_impl<megdnn::Image2DPack4TensorFormat>();
            auto fb_image = fbs::v2::CreateImage2DPackedTensorFormat(
                    m_builder, image_impl.align_axis());
            return fb_image.Union();
        }
        case Type::LOWBITS_ALIGNED_TO_BYTE: {
            auto nr_bits = format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                                   .size_nbits();
            auto alignment_bits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .align_size_in_bits();
            auto fb_lowbit = fbs::v2::CreateLowbitsAlignedTensorFormat(
                    m_builder, nr_bits, alignment_bits);
            return fb_lowbit.Union();
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}
/*!
 * \brief serialize the description of a graph variable as fbs::v2::MiddleTensor
 *
 * The variable name is always stored. Shape, comp node, dtype and tensor
 * format are stored only when the var's device tensor is valid
 * (dev_tensor_valid()); otherwise a name-only entry is produced.
 *
 * The previous implementation had no else branch, so the name-only entry
 * unconditionally overwrote the fully described one and the layout
 * information was never referenced by the model.
 *
 * \param var variable to describe; its node must not be null
 * \return offset of the created MiddleTensor
 */
flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
        const SymbolVar var) {
    mgb_assert(var.node());
    auto fbname = m_builder.CreateSharedString(var.node()->name());

    if (!var.node()->dev_tensor_valid()) {
        // no layout available: record the name only
        return fbs::v2::CreateMiddleTensor(m_builder, fbname);
    }

    auto layout = var.node()->layout();
    auto fshape =
            m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);

    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder, m_builder.CreateSharedString(
                               var.node()->comp_node().to_string_logical()));

    auto fdtype = build_dtype(layout.dtype);
    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
    auto fformat = build_tensor_format(layout.format);
    return fbs::v2::CreateMiddleTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
}
/*!
 * \brief serialize one graph output as fbs::v2::OutputVar
 *
 * The entry stores the middle-tensor index of the variable together with the
 * id of the original var node. If the variable has no middle tensor of its
 * own, the variable recorded for it in m_var_remove_in_dump is used for the
 * index lookup instead.
 */
flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
        const SymbolVar var) {
    const bool has_own_tensor =
            m_var2midtensor_id.find(var.node()) != m_var2midtensor_id.end();

    auto tensor_owner = var.node();
    if (!has_own_tensor) {
        mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
        tensor_owner = m_var_remove_in_dump[var.node()];
    }

    return fbs::v2::CreateOutputVar(
            m_builder, m_var2midtensor_id.at(tensor_owner), var.node()->id());
}
/*!
 * \brief collect the operators that have to be written for \p endpoints
 *
 * Visits every operator the endpoints depend on (via cg::DepOprIter):
 *  - operators for which should_remove_in_dump() is true are not dumped; each
 *    of their outputs is mapped in m_var_remove_in_dump to the variable that
 *    stands in for it (the single input, or whatever that input is already
 *    mapped to);
 *  - every other operator is appended to m_oprs_to_dump together with its
 *    OprRegistryV2 entry. A missing entry or a missing dumper raises
 *    MegBrainError.
 */
void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();

    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // Resolve the replacement once: it does not depend on the output,
            // and copying it out first avoids reading a reference into the
            // map while inserting new keys into the same map.
            auto replacement = opr->input(0);
            auto mapped = m_var_remove_in_dump.find(replacement);
            if (mapped != m_var_remove_in_dump.end()) {
                replacement = mapped->second;
            }
            // Copy input ID to output
            for (auto out : opr->output()) {
                m_var_remove_in_dump[out] = replacement;
            }
            return;
        }

        auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                opr->dyn_typeinfo(), CURRENT_VERSION);
        if (!registry || !registry->dumper) {
            mgb_throw(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                    "serialization as FlatBuffers is not supported for "
                    "operator %s",
                    opr->dyn_typeinfo()->name);
        }
        m_oprs_to_dump.emplace_back(opr, registry);
    };

    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}
/*!
 * \brief serialize \p metadata as fbs::v2::Metadata
 *
 * Copies is_valid, graph_modified, optimize_options and user_info.
 */
flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
        const Metadata& metadata) {
    // the string is created before the table builder is constructed
    const auto fb_user_info = m_builder.CreateSharedString(metadata.user_info);

    fbs::v2::MetadataBuilder meta_table(m_builder);
    meta_table.add_is_valid(metadata.is_valid);
    meta_table.add_graph_modified(metadata.graph_modified);
    meta_table.add_optimize_options(metadata.optimize_options);
    meta_table.add_user_info(fb_user_info);
    return meta_table.Finish();
}
/*!
 * \brief serialize one operator as fbs::v2::Operator
 *
 * Steps, in order:
 *  1. translate the operator inputs into middle-tensor indices;
 *  2. reset the per-operator buffers and run the registry's dumper, which
 *     fills them (tensors, blobs, params);
 *  3. serialize comp nodes, optional name, output dtype and the outputs (each
 *     non-VOLATILE_CONTENT output gets a new middle tensor and index);
 *  4. assemble the Operator table.
 *
 * \param opr operator to dump; also counted in m_cur_rst.nr_opr
 * \param registry registry entry providing dumper, name, type_id and version
 */
flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;

    // ---- inputs: middle-tensor index of every input var
    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size() != 0) {
        std::vector<uint32_t> input_ids;
        input_ids.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            auto tensor_owner = inp;
            if (m_var2midtensor_id.find(inp) == m_var2midtensor_id.end()) {
                // the producer was removed from the dump; use its stand-in
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "when dump the model, the dependence of var is wrong.");
                tensor_owner = m_var_remove_in_dump[inp];
            }
            input_ids.emplace_back(m_var2midtensor_id.at(tensor_owner));
        }
        inputs = m_builder.CreateVector(input_ids);
    }

    // ---- operator specific payload, produced by the registered dumper
    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);

    // ---- comp nodes from the operator config
    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> fb_comp_nodes;
        for (const auto& cn : config.comp_node()) {
            auto locator = m_builder.CreateSharedString(cn.to_string_logical());
            fb_comp_nodes.emplace_back(fbs::v2::CreateCompNode(m_builder, locator));
        }
        comp_node = m_builder.CreateVector(fb_comp_nodes);
    }

    // ---- optional operator name
    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }

    auto output_dtype = build_dtype(config.output_dtype());

    // ---- outputs: every non-volatile output becomes a new middle tensor
    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size() != 0) {
        std::vector<uint32_t> output_ids;
        output_ids.reserve(m_cur_opr->output().size());
        for (auto out : m_cur_opr->output()) {
            if (out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                continue;
            }
            m_model_middle_tensors.push_back(build_middle_tensor(out));
            m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
            output_ids.emplace_back(m_var2midtensor_id.at(out));
        }
        outputs = m_builder.CreateVector(output_ids);
    }

    // ---- tensors dumped by the operator dumper
    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size() != 0) {
        tensors = m_builder.CreateVector(m_cur_opr_tensor);
    }

    //! the blobs data is used by custom data
    //! m_blobs will be filled by the Operator dumper function
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size() != 0) {
        blobs = m_builder.CreateVector(m_blobs);
    }

    // ---- params: the first one goes into param/param_type, the rest into the
    // additional_params vectors
    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    const auto param_cnt = m_cur_opr_param_type.size();
    const bool has_additional_params = param_cnt > 1;
    if (has_additional_params) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }
    auto opr_type = m_builder.CreateSharedString(registry->name);

    // ---- assemble the table; all nested objects were created above
    fbs::v2::OperatorBuilder opr_table(m_builder);
    opr_table.add_type(opr_type);
    opr_table.add_type_id(registry->type_id);
    opr_table.add_inputs(inputs);
    opr_table.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        opr_table.add_priority(opr->node_prop().attribute().priority);
    }
    opr_table.add_comp_node(comp_node);
    opr_table.add_opr_version(registry->get_version());
    opr_table.add_name(operator_name);
    opr_table.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        opr_table.add_param_type(m_cur_opr_param_type[0]);
        opr_table.add_param(m_cur_opr_param[0]);
    }
    if (has_additional_params) {
        opr_table.add_additional_params_type(additional_params_type);
        opr_table.add_additional_params(additional_params);
    }
    opr_table.add_tensors(tensors);
    opr_table.add_custom_data(blobs);

    m_cur_opr = nullptr;
    return opr_table.Finish();
}
/*!
 * \brief run PassConvertToCompatible on the graph reachable from
 *      \p output_vars
 * \return the output variables after the optimizer was applied in place, in
 *      the same order as \p output_vars
 */
SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
        const SymbolVarArray& output_vars) {
    // the optimizer works on raw var nodes
    VarNodeArray endpoint_nodes;
    for (auto& endpoint : output_vars) {
        endpoint_nodes.push_back(endpoint.node());
    }

    gopt::GraphOptimizer optimizer;
    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
    optimizer.apply_inplace(endpoint_nodes);

    // wrap the (possibly replaced) nodes back into symbol vars
    SymbolVarArray converted_vars;
    for (auto& node : endpoint_nodes) {
        converted_vars.push_back({node});
    }
    return converted_vars;
}
/*!
 * \brief dump the graph that produces \p output_vars to m_file in the
 *      FlatBuffers v2 format
 *
 * The operators are first converted by converter_all_opr_to_compatiable();
 * the resulting variables are the ones that get dumped. The model is written
 * as one size-prefixed FlatBuffers buffer.
 *
 * \param output_vars graph outputs; must not be empty (SerializationError)
 * \param config dump options, stored in m_config for the duration of the dump
 * \param metadata model metadata written into the model
 * \return statistics of the dump: number of operators, input names (sorted),
 *      output names and total bytes written
 */
GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");

    auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars);

    // remember where we started so that tot_bytes can be computed at the end
    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    // reset all state left over from a previous dump
    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;

    // process output vars: reject volatile vars and, when names are kept,
    // duplicated names among distinct output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // Dump metadata
    auto fbmeta = build_metadata(metadata);

    // Dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // Dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
    // middle tensors were collected while the operators were built
    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

    // assemble the root table; every nested object has been created above
    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_oprs(fb_oprs);
    model.add_middle_tensors(fb_mid_tensor);
    model.add_output_vars_idx(fb_output_vars);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());

    // Write the serialized fbs::v2::Model buffer
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        // outputs are reported by name, or as "unnamed<index>" when names are
        // not kept
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
void
GraphDumperOSSV2
::
dump_tensor
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
,
TensorWriteMethod
method
)
{
using
namespace
flatbuffers
;
using
Meth
=
TensorWriteMethod
;
mgb_assert
(
(
method
==
Meth
::
VALUE_ANONYMOUS
)
^
(
!
name
.
empty
()),
"name must be non-empty for non Meth::VALUE_ANONYMOUS tensors"
);
bool
has_value
=
method
!=
Meth
::
META_INPUT
;
bool
should_keep_name
=
true
;
switch
(
method
)
{
case
Meth
::
VALUE_ANONYMOUS
:
should_keep_name
=
false
;
break
;
case
Meth
::
VALUE_SHARED
:
should_keep_name
=
m_config
.
keep_param_name
;
++
m_nr_shared_tensor
;
if
(
m_config
.
keep_param_name
)
{
mgb_assert
(
m_used_param_names
.
insert
(
name
).
second
,
"duplicated VALUE_SHARED tensor name: %s"
,
name
.
c_str
());
m_cur_rst
.
params
.
emplace_back
(
name
);
}
break
;
case
Meth
::
META_INPUT
:
case
Meth
::
VALUE_INPUT
:
mgb_assert
(
!
name
.
empty
(),
"empty input tensor name"
);
mgb_assert
(
m_used_input_names
.
insert
(
name
).
second
,
"duplicated input tensor name: %s"
,
name
.
c_str
());
m_cur_rst
.
inputs
.
emplace_back
(
name
);
break
;
}
auto
&
layout
=
tensor
.
layout
();
flatbuffers
::
Offset
<
flatbuffers
::
Vector
<
uint8_t
>>
data
;
if
(
has_value
)
{
check_tensor_value_valid
(
name
,
tensor
);
auto
&&
dumper
=
m_config
.
tensor_value_dumper
;
if
(
dumper
)
{
mgb_log_warn
(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value dumper"
);
}
data
=
m_builder
.
CreateVector
(
reinterpret_cast
<
uint8_t
*>
(
tensor
.
raw_ptr
()),
layout
.
span
().
high_byte
);
m_cur_rst
.
tensor_value_bytes
+=
layout
.
span
().
high_byte
;
}
auto
fbname
=
should_keep_name
?
m_builder
.
CreateSharedString
(
name
)
:
0
;
auto
fshape
=
m_builder
.
CreateVectorScalarCast
<
uint32_t
>
(
layout
.
shape
,
layout
.
ndim
);
auto
fcomp_node
=
fbs
::
v2
::
CreateCompNode
(
m_builder
,
m_builder
.
CreateSharedString
(
tensor
.
comp_node
().
to_string_logical
()));
auto
fdtype
=
build_dtype
(
layout
.
dtype
);
auto
fformat_type
=
get_flatbuffer_tensor_format_type
(
layout
.
format
);
auto
fformat
=
build_tensor_format
(
layout
.
format
);
auto
serialized_tensor
=
fbs
::
v2
::
CreateTensor
(
m_builder
,
fbname
,
fshape
,
fcomp_node
,
fdtype
,
fformat_type
,
fformat
,
data
);
m_cur_opr_tensor
.
emplace_back
(
serialized_tensor
);
}
/*!
 * \brief store \p size bytes starting at \p data as an fbs::v2::Blob of the
 *      operator currently being dumped (appended to m_blobs)
 */
void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
    const auto* bytes = static_cast<const uint8_t*>(data);
    auto fb_bytes = m_builder.CreateVector(bytes, size);
    m_blobs.emplace_back(fbs::v2::CreateBlob(m_builder, fb_bytes));
}
// ----------------------------- Loader --------------------------------------
/*!
 * \brief turn a serialized comp node into a CompNode
 *
 * The logical locator string is parsed, passed through the comp_node_mapper
 * of the current load config, and then loaded.
 *
 * \param comp_node serialized comp node; must not be null
 * \return the loaded comp node, or a default-constructed CompNode if no
 *      logical locator is stored
 */
CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
        const fbs::v2::CompNode* comp_node) {
    mgb_assert(comp_node);
    const auto* locator_str = comp_node->logical_locator();
    if (!locator_str) {
        return {};
    }
    auto locator = CompNode::Locator::parse(locator_str->str());
    // the mapper receives the parsed locator before it is loaded
    m_loader->m_cur_load_config->comp_node_mapper(locator);
    return CompNode::load(locator);
}
/*!
 * \brief rebuild a megdnn tensor format from its serialized union value
 *
 * \param fformat_type union tag selecting how \p fformat is interpreted
 * \param fformat pointer to the serialized format table
 * \param comp_node comp node whose MegDNN handle is used for the
 *      Image2DPack4 format
 * \return the tensor format; unknown tags raise SerializationError
 */
TensorFormat load_tensor_format(
        const fbs::v2::TensorFormat fformat_type, const void* fformat,
        const CompNode& comp_node) {
    if (fformat_type == fbs::v2::TensorFormat_DefaultTensorFormat) {
        return megdnn::DefaultTensorFormat::make();
    }
    if (fformat_type == fbs::v2::TensorFormat_Image2DPackedTensorFormat) {
        const auto* image_table =
                static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
        auto&& comp_node_env = CompNodeEnv::from_comp_node(comp_node);
        auto dnn_handle = MegDNNHandle::get(comp_node_env).handle();
        return megdnn::Image2DPack4TensorFormat::make(
                image_table->align_axis(), dnn_handle);
    }
    if (fformat_type == fbs::v2::TensorFormat_LowbitsAlignedTensorFormat) {
        const auto* lowbit_table =
                static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
        return megdnn::LowbitsAlignedToBytesTensorFormat::make(
                lowbit_table->size_nbits());
    }
    mgb_throw(SerializationError, "invalid tensor format type in serialization.");
}
/*!
 * \brief rebuild the layout of a serialized tensor
 *
 * Shape, dtype and format are taken from \p tensor when present; strides are
 * then initialized as contiguous.
 *
 * The number of dimensions comes from the model file, so it is validated
 * against the capacity of the fixed-size shape array before the shape is
 * copied; a larger value raises SerializationError instead of writing past
 * the array.
 *
 * \param tensor serialized tensor
 * \param comp_node comp node used when the tensor format has to be created
 */
TensorLayout load_tensor_layout(
        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
    TensorLayout layout;
    if (tensor->shape()) {
        const size_t ndim = tensor->shape()->size();
        mgb_throw_if(
                ndim > static_cast<size_t>(TensorLayout::MAX_NDIM),
                SerializationError,
                "invalid tensor in model: ndim %zu exceeds the maximum of %zu", ndim,
                static_cast<size_t>(TensorLayout::MAX_NDIM));
        layout.ndim = ndim;
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    if (tensor->format() && tensor->format_type()) {
        layout.format =
                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
    }
    layout.init_contiguous_stride();
    return layout;
}
//! the opr loader should make sure the exist of tensors and the number of
//! tensor, here just assert it.
std
::
shared_ptr
<
HostTensorND
>
GraphLoaderOSSV2
::
OprLoadContextImpl
::
load_tensor
()
{
mgb_assert
(
m_current_opr
->
tensors
()
&&
m_cur_opr_tensor_cnt
<
m_current_opr
->
tensors
()
->
size
());
auto
tensor
=
m_current_opr
->
tensors
()
->
Get
(
m_cur_opr_tensor_cnt
++
);
auto
comp_node
=
load_comp_node
(
tensor
->
comp_node
());
auto
layout
=
load_tensor_layout
(
tensor
,
comp_node
);
auto
ret
=
std
::
make_shared
<
HostTensorND
>
(
comp_node
,
layout
);
auto
&&
loader
=
m_loader
->
m_cur_load_config
->
tensor_value_loader
;
if
(
tensor
->
data
()
&&
tensor
->
data
()
->
size
()
>
0
)
{
if
(
loader
)
{
mgb_log_warn
(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value loader"
);
}
memcpy
(
ret
->
raw_ptr
(),
tensor
->
data
()
->
data
(),
tensor
->
data
()
->
size
());
}
if
(
tensor
->
name
())
{
m_tensor_map
[
tensor
->
name
()
->
str
()]
=
ret
;
}
if
(
auto
&&
mod
=
m_loader
->
m_cur_load_config
->
tensor_modifier
)
{
bool
has_value
=
false
;
if
(
tensor
&&
tensor
->
data
())
{
has_value
=
tensor
->
data
()
->
size
()
!=
0
;
}
mod
(
tensor
->
name
()
?
tensor
->
name
()
->
str
()
:
""
,
has_value
,
*
ret
);
}
return
ret
;
}
std
::
shared_ptr
<
DeviceTensorND
>
GraphLoaderOSSV2
::
OprLoadContextImpl
::
load_tensor_shared
()
{
mgb_assert
(
m_current_opr
->
tensors
()
&&
m_cur_opr_tensor_cnt
<
m_current_opr
->
tensors
()
->
size
());
auto
tensor
=
m_current_opr
->
tensors
()
->
Get
(
m_cur_opr_tensor_cnt
++
);
auto
comp_node
=
load_comp_node
(
tensor
->
comp_node
());
auto
layout
=
load_tensor_layout
(
tensor
,
comp_node
);
mgb_assert
(
tensor
->
data
());
auto
&&
shared_pair
=
m_loader
->
m_shared_tensor_map
.
at
(
m_cur_shared_tensor_idx
++
);
auto
&&
shared_tensor_ref
=
shared_pair
.
second
[
comp_node
.
mem_node
()];
if
(
shared_tensor_ref
)
{
if
(
shared_tensor_ref
->
comp_node
()
==
comp_node
)
return
shared_tensor_ref
;
// same mem node but different comp node, change comp node and share
// value
auto
ret
=
std
::
make_shared
<
DeviceTensorND
>
(
*
shared_tensor_ref
);
ret
->
comp_node
(
comp_node
);
return
ret
;
}
if
(
tensor
->
name
())
{
shared_pair
.
first
=
tensor
->
name
()
->
str
();
}
if
(
comp_node
.
mem_node
()
==
CompNode
::
default_cpu
().
mem_node
())
{
// directly forward CPU memory
HostTensorND
hv
{
comp_node
};
if
(
tensor
->
data
()
&&
tensor
->
data
()
->
size
()
>
0
)
{
hv
.
dtype
(
layout
.
dtype
).
resize
(
layout
);
memcpy
(
hv
.
raw_ptr
(),
tensor
->
data
()
->
data
(),
tensor
->
data
()
->
size
());
}
shared_tensor_ref
=
std
::
make_shared
<
DeviceTensorND
>
();
*
shared_tensor_ref
=
DeviceTensorND
::
make_proxy
(
hv
);
}
else
{
// use lazy load for non-CPU devices
HostTensorND
hv
{
CompNode
::
default_cpu
()};
if
(
tensor
->
data
()
&&
tensor
->
data
()
->
size
()
>
0
)
{
hv
.
dtype
(
layout
.
dtype
).
resize
(
layout
);
memcpy
(
hv
.
raw_ptr
(),
tensor
->
data
()
->
data
(),
tensor
->
data
()
->
size
());
}
shared_tensor_ref
=
m_device_value_loader
.
make
(
comp_node
,
std
::
move
(
hv
));
}
return
shared_tensor_ref
;
}
/*!
 * \brief convert the metadata table of the model into a Metadata object
 *
 * \return a default-constructed Metadata when the model carries no metadata
 *      table; otherwise the validity / graph_modified flags are copied, and
 *      user_info / optimize_options are filled in (together with their
 *      has_user_info / optimized_for_inference markers) when present.
 */
Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
    Metadata meta;
    const auto* fb_meta = m_loader->m_model->metadata();
    if (!fb_meta) {
        return meta;
    }

    meta.is_valid = fb_meta->is_valid();
    meta.graph_modified = fb_meta->graph_modified();

    if (const auto* info = fb_meta->user_info()) {
        meta.user_info = info->str();
        meta.has_user_info = true;
    }

    // a zero / absent optimize_options value leaves the defaults untouched
    if (fb_meta->optimize_options()) {
        meta.optimize_options = fb_meta->optimize_options();
        meta.optimized_for_inference = true;
    }
    return meta;
}
/*!
 * \brief load one serialized operator and insert it into the graph
 *
 * Steps:
 *  1. reset the per-operator tensor / blob / param cursors;
 *  2. build the OperatorNodeConfig (output dtype, name, comp nodes);
 *  3. look up the loader registered for (type_id, opr_version);
 *  4. resolve the inputs through the id -> VarNode table;
 *  5. invoke the registered loader and check the resulting operator type;
 *  6. append every non-volatile output var to the id -> VarNode table and
 *     restore its name from the middle tensor table when available;
 *  7. restore the operator priority.
 *
 * \param fbopr serialized operator (must not be null)
 *
 * Throws SerializationError when no loader is registered for the operator.
 */
void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            // a null entry is kept as a default-constructed comp node
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }

    //! the opr version must exist in the model
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
    // the type name is only used for diagnostics, so tolerate its absence
    const char* opr_type_name = fbopr->type() ? fbopr->type()->c_str() : "(unknown)";
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s id is %zu, use python env "
            "config.dump_registered_oprs() to get a dict that maps from "
            "opr id to opr name",
            opr_type_name, static_cast<size_t>(type_id));

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    // The message arguments must stay valid when opr or registry->type is null.
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s",
            opr ? opr->dyn_typeinfo()->name : "(null)",
            registry->type ? registry->type->name : "(null)");

    // record output vars; read output names
    size_t out_idx = 0;
    for (auto ovar : accessor.output()) {
        if (ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            continue;
        }
        m_id2varnode.push_back(ovar);
        if (fbopr->outputs()) {
            mgb_assert(
                    out_idx < fbopr->outputs()->size(),
                    "opr has more non-volatile outputs than the %zu recorded "
                    "in the model",
                    static_cast<size_t>(fbopr->outputs()->size()));
            auto id = fbopr->outputs()->Get(out_idx);
            mgb_assert(
                    m_id2varnode.size() - 1 == id,
                    "id2var is %zu, fbs get id is %u\n", m_id2varnode.size() - 1,
                    static_cast<unsigned>(id));
            // the middle tensor table is indexed by the var id, not by the
            // position of the output inside this operator
            if (id < m_middle_tensors.size()) {
                const auto* mid = m_middle_tensors[id];
                if (mid && mid->name()) {
                    ovar->name(mid->name()->str());
                }
            }
        }
        out_idx++;
    }

    opr->node_prop().attribute().priority = fbopr->priority();
}
/*!
 * \brief load every operator of the model and assemble the LoadResult
 *
 * Operators are loaded in their serialized order, then the batched device
 * value loader is applied, and finally the output vars listed by the model
 * are exposed by name, by original id and by position.
 */
GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load
        // it tries to restore the same graph as it was dumped
        // see test TestSerializer2.LOGEXP for example
        GraphLoader::ScopedGraphOptDisabler opt_disabler(m_graph);
        for (flatbuffers::uoffset_t opr_idx = 0; opr_idx < oprs->size(); ++opr_idx) {
            m_current_opr = oprs->Get(opr_idx);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading device values
    m_device_value_loader.apply();

    LoadResult result;
    result.graph = m_graph;
    result.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_model->output_vars_idx();
    result.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t out_idx = 0; out_idx < outputs->size(); out_idx++) {
        auto fb_out = outputs->Get(out_idx);
        auto var = m_id2varnode.at(fb_out->compact_id());
        result.output_var_map[var->name()] = var;
        result.output_var_map_id[fb_out->original_id()] = var;
        result.output_var_list[out_idx] = var;
    }

    // every slot of the shared tensor map must have been consumed
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return result;
}
/*!
 * \brief collect pointers to all middle tensor tables of the model
 *
 * The pointers are appended to m_middle_tensors in their serialized order.
 * Nothing happens when the model has no middle tensor vector.
 */
void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
    const auto* fb_middle_tensors = m_loader->m_model->middle_tensors();
    if (!fb_middle_tensors) {
        return;
    }
    for (const auto* mid : *fb_middle_tensors) {
        m_middle_tensors.push_back(mid);
    }
}
/*!
 * \brief load a graph from the input file
 *
 * The file is expected to hold a 32-bit size followed by that many bytes of
 * flatbuffer model data. The buffer is checked for the model file identifier
 * and run through the flatbuffers verifier before it is interpreted.
 *
 * \param config load configuration; its address is remembered in
 *      m_cur_load_config
 * \param rewind whether to rewind the input file before reading
 * \return the loaded graph, output vars, tensor map and metadata
 *
 * Throws SerializationError when the identifier check or the verification
 * fails.
 */
GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    // Read fbs::Graph
    uint32_t model_size = 0;
    m_file->read(&model_size, sizeof(model_size));
    m_model_buf = m_file->read_shared(model_size);

    mgb_throw_if(
            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
            "invalid fbs model");
    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_model = fbs::v2::GetModel(m_model_buf.data());
    const auto model_version = m_model->mge_version();
    m_mgb_version = model_version;
    if (model_version > MGB_VERSION) {
        mgb_log_warn(
                "loading model from future runtime: version=%u "
                "model_version=%u",
                MGB_VERSION, model_version);
    }

    // the shared tensor map is sized on first load; on later loads it must
    // match the model being loaded
    if (!m_shared_tensor_map.empty()) {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    } else {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, model_version};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;
    m_model_loaded = true;
    result.graph_compile_ahead();
    return result;
}
/*!
 * \brief create a graph dumper of type GraphDumperOSSV2
 * \param file output file; ownership is passed to the dumper
 */
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) {
    std::unique_ptr<GraphDumper> dumper =
            std::make_unique<GraphDumperOSSV2>(std::move(file));
    return dumper;
}
/*!
 * \brief create a graph loader of type GraphLoaderOSSV2
 * \param file input file; ownership is passed to the loader
 */
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
    std::unique_ptr<GraphLoader> loader =
            std::make_unique<GraphLoaderOSSV2>(std::move(file));
    return loader;
}
/*!
 * \brief check whether \p file carries the fbs v2 model identifier
 *
 * Reads a fixed-size prefix from the current position and then moves the file
 * position back by the same amount, so the caller can go on reading from where
 * it was.
 *
 * \param file input file positioned at the start of the serialized model
 * \return true if the bytes following the leading 32-bit size field carry the
 *      v2 model file identifier
 */
bool is_fbs_v2_file(InputFile& file) {
    constexpr size_t identifier_length = 25;
    char identifier[identifier_length];
    file.read(identifier, identifier_length);
    // negate as a signed value: unary minus on a size_t wraps around, which
    // turns into a large positive offset where size_t is narrower than 64 bits
    file.skip(-static_cast<int64_t>(identifier_length));
    //! skip the size in prefix of the file
    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}
}
// namespace serialization
}
// namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/serialization/i
mpl
/batched_device_value_loader.h
→
src/serialization/i
nclude/megbrain/serialization
/batched_device_value_loader.h
浏览文件 @
a694fb33
文件已移动
src/serialization/include/megbrain/serialization/dump_format.h
浏览文件 @
a694fb33
...
@@ -5,6 +5,7 @@ namespace serialization {
...
@@ -5,6 +5,7 @@ namespace serialization {
enum
class
GraphDumpFormat
{
enum
class
GraphDumpFormat
{
FLATBUFFERS
,
FLATBUFFERS
,
FLATBUFFERS_V2
,
};
};
}
// namespace serialization
}
// namespace serialization
...
...
src/serialization/include/megbrain/serialization/opr_load_dump.h
浏览文件 @
a694fb33
...
@@ -20,8 +20,12 @@ class FlatBufferBuilder;
...
@@ -20,8 +20,12 @@ class FlatBufferBuilder;
}
// namespace flatbuffers
}
// namespace flatbuffers
namespace
mgb
{
namespace
mgb
{
namespace
serialization
{
constexpr
uint8_t
CURRENT_VERSION
=
2u
;
constexpr
uint8_t
BEGIN_VERSION
=
0u
;
constexpr
uint8_t
VERSION_1
=
1u
;
constexpr
uint8_t
VERSION_2
=
2u
;
namespace
serialization
{
namespace
fbs
{
namespace
fbs
{
template
<
typename
T
>
template
<
typename
T
>
struct
OperatorParamTraits
;
struct
OperatorParamTraits
;
...
@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
...
@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
friend
class
OprLoadContextRawPOD
;
friend
class
OprLoadContextRawPOD
;
friend
class
OprLoadContextFlatBuffers
;
friend
class
OprLoadContextFlatBuffers
;
protected:
virtual
~
OprLoadContext
()
=
default
;
public:
public:
//! get current computing graph
//! get current computing graph
virtual
ComputingGraph
&
graph
()
=
0
;
virtual
ComputingGraph
&
graph
()
=
0
;
...
@@ -224,6 +231,12 @@ public:
...
@@ -224,6 +231,12 @@ public:
*/
*/
virtual
SharedBuffer
load_shared_buf_with_len
()
=
0
;
virtual
SharedBuffer
load_shared_buf_with_len
()
=
0
;
/*!
* \brief get the serialization data of the current opr
*
*/
virtual
const
void
*
get_current_opr_data
()
{
return
nullptr
;
};
/*!
/*!
* \brief read a param and check that tag matches
* \brief read a param and check that tag matches
*/
*/
...
...
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
0 → 100644
浏览文件 @
a694fb33
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
//! cast an OprLoadContext reference \p cvt to the fbs v2 load context; the
//! macro argument is what gets cast, whatever the caller's variable is named
#define CAST_TO_FBS_V2_CTX(cvt) static_cast<GraphLoaderOSSV2::OprLoadContextImpl&>(cvt)
namespace
mgb
{
namespace
serialization
{
/*!
 * \brief graph dumper reporting GraphDumpFormat::FLATBUFFERS_V2
 *
 * Besides the public GraphDumper interface it implements the
 * OprDumpContextFlatBuffers callbacks (privately inherited).
 */
class GraphDumperOSSV2 final : public GraphDumper, private OprDumpContextFlatBuffers {
private:
    //! output file, owned by the dumper
    const std::unique_ptr<OutputFile> m_file;
    //! flatbuffer builder exposed through builder()
    flatbuffers::FlatBufferBuilder m_builder;
    //! config returned by config()
    DumpConfig m_config;
    DumpResult m_cur_rst;
    size_t m_nr_shared_tensor;

    std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistryV2*>> m_oprs_to_dump;
    ThinHashMap<VarNode*, VarNode*> m_var_remove_in_dump;

    //! set of output vars specified by user
    ThinHashSet<VarNode*> m_output_vars;
    std::unordered_set<std::string> m_used_input_names;
    std::unordered_set<std::string> m_used_param_names;

    //! current opr to be dumped
    cg::OperatorNodeBase* m_cur_opr = nullptr;
    // Will be filled in dump_tensor
    std::vector<flatbuffers::Offset<fbs::v2::Tensor>> m_cur_opr_tensor;

    std::vector<flatbuffers::Offset<fbs::v2::Blob>> m_blobs;
    //! param types and param offsets, appended pairwise by append_param()
    std::vector<fbs::v2::OperatorParam> m_cur_opr_param_type;
    std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

    std::vector<flatbuffers::Offset<fbs::v2::MiddleTensor>> m_model_middle_tensors;
    ThinHashMap<VarNode*, size_t> m_var2midtensor_id;

    SymbolVarArray converter_all_opr_to_compatiable(const SymbolVarArray& output_vars);
    void init_oprs_to_dump(const SymbolVarArray& endpoints);

    flatbuffers::Offset<fbs::v2::Metadata> build_metadata(const Metadata& metadata);
    flatbuffers::Offset<fbs::v2::Operator> build_single_opr(
            cg::OperatorNodeBase* opr, const OprRegistryV2* registry);
    flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);

public:
    //! take over \p file as the dump destination
    GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}

    DumpResult dump(
            const SymbolVarArray& output_vars, const DumpConfig& config = {},
            const Metadata& metadata = {}) override;

    const GraphDumpConfig& config() const override { return m_config; }

    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
            TensorWriteMethod method) override;

    /*!
     * \brief record one param of the current operator
     * \param type param type tag; must not be OperatorParam_NONE
     * \param value raw flatbuffer offset of the already built param
     */
    void append_param(uint32_t type, uint32_t value) override {
        static_assert(
                std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
                "append_param depends on uoffset_t being uint32_t");
        static_assert(
                std::is_standard_layout<flatbuffers::Offset<void>>::value,
                "append_param depends on flatbuffers::Offset having "
                "standard memory layout");
        mgb_assert(type != fbs::v2::OperatorParam_NONE);
        m_cur_opr_param_type.emplace_back(static_cast<fbs::v2::OperatorParam>(type));
        m_cur_opr_param.emplace_back(value);
    }

    flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }

    void dump_buf_with_len(const void* data, uint32_t size) override;

    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }

    flatbuffers::Offset<fbs::v2::MiddleTensor> build_middle_tensor(const SymbolVar var);
    flatbuffers::Offset<fbs::v2::OutputVar> build_output_var(const SymbolVar var);
    flatbuffers::Offset<void> build_tensor_format(const TensorLayout::Format& format);

    //! set the operator that subsequent dump callbacks refer to
    void set_current_opr(cg::OperatorNodeBase* cur_opr) { m_cur_opr = cur_opr; }
};
// ----------------------------- Loader --------------------------------------
/*!
 * \brief graph loader reporting GraphDumpFormat::FLATBUFFERS_V2
 *
 * The per-operator loading logic lives in the nested OprLoadContextImpl,
 * which is a friend and accesses the private state below.
 */
class GraphLoaderOSSV2 final : public GraphLoader {
    //! config of the load() call in progress (address of the caller's object)
    const LoadConfig* m_cur_load_config = nullptr;
    std::unique_ptr<InputFile> m_file;
    //! buffer holding the serialized model
    SharedBuffer m_model_buf{{}, 0};
    //! root of the parsed model; stays null until load() assigns it
    const fbs::v2::Model* m_model = nullptr;
    SharedTensorIDMap m_shared_tensor_map;
    uint32_t m_mgb_version = 0;
    bool m_model_loaded = false;

    void verify();

public:
    class OprLoadContextImpl;
    friend class OprLoadContextImpl;

    //! take over \p input_file as the model source
    GraphLoaderOSSV2(std::unique_ptr<InputFile> input_file)
            : m_file{std::move(input_file)} {}

    //! install \p file as the new input file and return the previous one
    std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
        file.swap(m_file);
        return file;
    }

    LoadResult load(const LoadConfig& config, bool rewind) override;

    //! shared tensor map; asserts that a graph has been loaded
    const SharedTensorIDMap& shared_tensor_id_map() const override {
        mgb_assert(m_model_loaded, "graph not loaded yet");
        return m_shared_tensor_map;
    }

    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
};
/*!
 * \brief load context handed to the operator loaders of the fbs v2 format
 *
 * The constructor registers the context in the user data of the computing
 * graph and the destructor removes it again. The m_cur_opr_* cursors track how
 * many tensors / blobs / params of the current operator have been consumed.
 */
class GraphLoaderOSSV2::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
    GraphLoaderOSSV2* const m_loader;
    size_t m_cur_shared_tensor_idx = 0;
    std::shared_ptr<ComputingGraph> m_graph;
    LoadResult::TensorMap m_tensor_map;
    //! maps the id of a var (its position here) to the loaded VarNode
    VarNodeArray m_id2varnode;
    std::vector<const fbs::v2::MiddleTensor*> m_middle_tensors;
    BatchedDeviceValueLoader m_device_value_loader;
    //! operator currently being loaded; null until the first one is set
    const fbs::v2::Operator* m_current_opr = nullptr;
    size_t m_cur_opr_tensor_cnt = 0;
    size_t m_cur_opr_blob_cnt = 0;
    size_t m_cur_opr_param_cnt = 0;

public:
    ComputingGraph& graph() override { return *m_graph; }

    const GraphLoadConfig& config() const override {
        return *m_loader->m_cur_load_config;
    }

    std::shared_ptr<HostTensorND> load_tensor() override;

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;

    void load_single_opr(const fbs::v2::Operator* opr);

    /*!
     * \param loader owning loader; its current load config must be set
     * \param version forwarded to OprLoadContextFlatBuffers
     *
     * Uses the graph of the load config, or creates a new one when the config
     * has none, and registers this context in the graph's user data.
     */
    OprLoadContextImpl(GraphLoaderOSSV2* loader, uint32_t version)
            : OprLoadContextFlatBuffers(version), m_loader{loader} {
        m_graph = loader->m_cur_load_config->comp_graph;
        if (!m_graph) {
            m_graph = ComputingGraph::make();
        }
        // shared_ptr built from an empty owner: it points at this object
        // without taking ownership of it
        auto maker = [this]() {
            return std::shared_ptr<OprLoadContext>{
                    std::shared_ptr<OprLoadContext>{}, this};
        };
        auto got = m_graph->options().user_data.get_user_data_or_create<OprLoadContext>(
                maker);
        mgb_assert(got == this);
    }

    ~OprLoadContextImpl() noexcept {
        auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
        mgb_assert(nr == 1);
    }

    Metadata load_metadata();
    LoadResult load_oprs();
    CompNode load_comp_node(const fbs::v2::CompNode* comp_node);

    void load_middle_tensor();

    /*!
     * \brief get the first param of the current operator
     * \param enumv expected param type tag
     *
     * Throws SerializationError when the type does not match, or when called
     * again for the same operator (use read_param(index) for further params).
     */
    const void* get_next_param(uint32_t enumv) override {
        auto type = static_cast<fbs::v2::OperatorParam>(enumv);
        if (m_cur_opr_param_cnt == 0) {
            m_cur_opr_param_cnt++;
            if (m_current_opr->param_type() == type) {
                return m_current_opr->param();
            } else {
                mgb_throw(
                        SerializationError,
                        "The param type is not match when load the opr.");
            }
        }
        mgb_throw(
                SerializationError,
                "When load multi param in one Operator, please use read_param(index) "
                "interface. ");
    }

    //! copy the next custom data blob of the current operator into a string
    std::string load_buf_with_len() override {
        mgb_assert(
                m_current_opr->custom_data() &&
                m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
        auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto data = blob->data()->data();
        return {reinterpret_cast<const char*>(data), blob->data()->size()};
    }

    //! copy the next custom data blob of the current operator into a newly
    //! allocated shared buffer
    SharedBuffer load_shared_buf_with_len() override {
        mgb_assert(
                m_current_opr->custom_data() &&
                m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
        auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto size = blob->data()->size();
        std::shared_ptr<uint8_t> shptr{
                new uint8_t[size], [](uint8_t* p) { delete[] p; }};
        memcpy(shptr.get(), blob->data()->data(), size);
        return {std::move(shptr), size};
    }

    const void* get_current_opr_data() override {
        return reinterpret_cast<const void*>(m_current_opr);
    }

    /*!
     * \brief read the param at position \p index of the current operator
     *
     * Index 0 is the main param; index k > 0 is entry k - 1 of the additional
     * params.
     *
     * \tparam T param type to produce
     * \param index non-negative param position
     * \return the converted param, or a value-initialized T (after a warning)
     *      when the model has no additional param at that position
     */
    template <class T>
    T read_param(int index) {
        using SourceType = typename fbs::ParamConverter<T>::FlatBufferType;
        auto enumv = fbs::OperatorParamTraits<SourceType>::enum_value;
        auto type = static_cast<fbs::v2::OperatorParam>(enumv);
        mgb_assert(index >= 0, "Load param error, the param index is negative.");
        if (index == 0) {
            mgb_assert(
                    m_current_opr->param_type() == type,
                    "Load param error, the param type is not right.");
            return fbs::ParamConverter<T>::to_param(
                    static_cast<const SourceType*>(m_current_opr->param()));
        }

        int addition_index = index - 1;
        // both vectors are optional in the model, so either may be null
        const auto* extra_params = m_current_opr->additional_params();
        const auto* extra_types = m_current_opr->additional_params_type();
        const int nr_extra = extra_params ? static_cast<int>(extra_params->size()) : 0;
        if (addition_index >= nr_extra) {
            mgb_log_warn(
                    "Model has no addition param of index %d, just construct a "
                    "default one.",
                    addition_index);
            // actually hand back the default the warning promises, instead of
            // flowing off the end of a non-void function
            return T();
        }
        mgb_assert(
                extra_types &&
                        addition_index < static_cast<int>(extra_types->size()) &&
                        extra_types->Get(addition_index) == type,
                "Load param error, the addition param type is not right.");
        return fbs::ParamConverter<T>::to_param(
                static_cast<const SourceType*>(extra_params->Get(addition_index)));
    }
};
}
// namespace serialization
}
// namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录