Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
accb2d8d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
accb2d8d
编写于
9月 29, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/serialize): fix flatbuffer compatibility issues
GitOrigin-RevId: e4771d6bc43a987a7fe725b5949b77da8769815d
上级
5b1383e0
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
120 addition
and
20 deletion
+120
-20
dnn/scripts/gen_flatbuffers_converter.py
dnn/scripts/gen_flatbuffers_converter.py
+0
-3
dnn/scripts/gen_flatbuffers_schema.py
dnn/scripts/gen_flatbuffers_schema.py
+0
-3
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+0
-3
src/opr/impl/basic_arith.sereg.h
src/opr/impl/basic_arith.sereg.h
+43
-1
src/opr/test/blas.cpp
src/opr/test/blas.cpp
+36
-1
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+40
-8
src/serialization/include/megbrain/serialization/sereg.h
src/serialization/include/megbrain/serialization/sereg.h
+1
-1
未找到文件。
dnn/scripts/gen_flatbuffers_converter.py
浏览文件 @
accb2d8d
...
...
@@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase):
self
.
_last_param
=
p
self
.
_param_fields
=
[]
self
.
_fb_fields
=
[
"builder"
]
if
p
.
is_legacy
:
self
.
_skip_current_param
=
True
return
self
.
_write
(
"template<>
\n
struct ParamConverter<megdnn::param::%s> {"
,
p
.
name
,
indent
=
1
)
self
.
_write
(
"using MegDNNType = megdnn::param::%s;"
,
p
.
name
)
...
...
dnn/scripts/gen_flatbuffers_schema.py
浏览文件 @
accb2d8d
...
...
@@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase):
def
_on_param_begin
(
self
,
p
):
self
.
_last_param
=
p
self
.
_cur_const_val
=
{}
if
p
.
is_legacy
:
self
.
_skip_current_param
=
True
return
self
.
_write_doc
(
p
.
name
)
self
.
_write
(
"table %s {"
,
p
.
name
,
indent
=
1
)
...
...
dnn/scripts/gen_tablegen.py
浏览文件 @
accb2d8d
...
...
@@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase):
def
_on_param_begin
(
self
,
p
):
self
.
_last_param
=
p
if
p
.
is_legacy
:
self
.
_skip_current_param
=
True
return
self
.
_packed
=
True
self
.
_current_tparams
=
[]
self
.
_const
=
set
()
...
...
src/opr/impl/basic_arith.sereg.h
浏览文件 @
accb2d8d
...
...
@@ -62,6 +62,37 @@ struct PersistentAddUpdateParam {
}
// namespace opr_add_update
// Old SerializedDType used in MegBrain 7.22.0 - 7.23.1
// Should be kept as-is even if there are new dtypes.
struct
SerializedDTypeV1
{
static
constexpr
uint32_t
TAG
=
megdnn
::
param
::
FakeSerializedDType
::
TAG
;
DTypeEnum
enumv
;
union
{
megdnn
::
DTypeParam
<
dtype
::
Quantized8Asymm
>
Quantized8Asymm
;
megdnn
::
DTypeParam
<
dtype
::
QuantizedS8
>
QuantizedS8
;
megdnn
::
DTypeParam
<
dtype
::
QuantizedS32
>
QuantizedS32
;
}
param
;
operator
DType
()
const
{
switch
(
enumv
)
{
#define cb(_dt) \
case DTypeEnum::_dt: \
return DType::from_enum(enumv);
MEGDNN_FOREACH_DTYPE_NAME
(
cb
)
#undef cb
case
DTypeEnum
::
Quantized8Asymm
:
return
dtype
::
Quantized8Asymm
{
param
.
Quantized8Asymm
};
case
DTypeEnum
::
QuantizedS8
:
return
dtype
::
QuantizedS8
{
param
.
QuantizedS8
};
case
DTypeEnum
::
QuantizedS32
:
return
dtype
::
QuantizedS32
{
param
.
QuantizedS32
};
default:
mgb_assert
(
false
,
"unexpected old serialized dtype: invalid enumv %d"
,
static_cast
<
uint32_t
>
(
enumv
));
}
}
};
template
<
>
struct
OprPersistentParam
<
opr
::
AddUpdate
>
{
using
Param
=
opr_add_update
::
PersistentAddUpdateParam
;
...
...
@@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> {
return
fbs
::
intl
::
build_dtype
(
builder
,
dtype
);
}
};
}
// namespace fbs
template
<
>
struct
ParamConverter
<
SerializedDTypeV1
>
{
using
FlatBufferType
=
SerializedDTypeV1
;
static
SerializedDTypeV1
to_param
(
const
FlatBufferType
*
fb
)
{
mgb_assert
(
false
,
"You are calling SerializedDTypeV1 in flatbuffer, you should not call "
"here, this code is just to avoid compiling errors, but not be used in "
"flatbuffer."
);
}
};
};
// namespace fbs
#endif
template
<
>
...
...
src/opr/test/blas.cpp
浏览文件 @
accb2d8d
...
...
@@ -16,6 +16,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
...
...
@@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) {
}
#endif
#if MGB_ENABLE_FBS_SERIALIZATION
TEST
(
TestOprDNN
,
MatrixMulSerialization
)
{
using
namespace
serialization
;
auto
fname
=
output_file
(
"MatrixMulSerializationTest"
);
auto
dump
=
[
&
]()
{
opr
::
MatrixMul
::
Param
param
;
auto
cn
=
CompNode
::
load
(
"cpu0"
);
auto
graph
=
ComputingGraph
::
make
();
HostTensorND
a_host
{
cn
,
{
24
,
24
},
dtype
::
Float32
()};
HostTensorND
b_host
{
cn
,
{
24
,
24
},
dtype
::
Float32
()};
auto
a
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
a_host
);
auto
b
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
b_host
);
auto
opr
=
opr
::
MatrixMul
::
make
(
a
,
b
,
param
,
{});
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
dumper
->
dump
({
opr
});
ASSERT_EQ
(
rst
.
outputs
.
size
(),
1u
);
};
auto
load
=
[
&
]()
{
auto
loader
=
GraphLoader
::
make
(
InputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
loader
->
load
();
ASSERT_EQ
(
rst
.
output_var_list
.
size
(),
1u
);
auto
opr
=
rst
.
output_var_list
[
0
].
node
()
->
owner_opr
();
ASSERT_TRUE
(
opr
->
same_type
<
opr
::
MatrixMul
>
());
};
dump
();
load
();
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
//
//
\ No newline at end of file
src/serialization/impl/serializer_oss.cpp
浏览文件 @
accb2d8d
...
...
@@ -47,7 +47,13 @@ namespace {
constexpr
uint32_t
MGB_VERSION
=
(
MGE_MAJOR
*
1000
+
MGE_MINOR
)
*
100
+
MGE_PATCH
;
constexpr
uint32_t
MGB_MAGIC
=
0x5342474D
;
constexpr
uint32_t
MGB_MAGIC
=
0x4342474D
;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr
uint32_t
MAGIC_V0
=
0x5342474D
;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool
magic_compare
=
true
;
template
<
typename
T
>
bool
contains_any_in_set
(
const
SmallVector
<
T
>&
list
,
const
ThinHashSet
<
T
>&
set
)
{
...
...
@@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso
}
}
//! feature bits for backward compatibility; default value should be 0
struct
FeatureBits64
{
//! reserved for new fields
uint64_t
:
64
;
static
void
write
(
OutputFile
&
fout
)
{
static_assert
(
sizeof
(
FeatureBits64
)
==
8
,
"bad feature bits"
);
FeatureBits64
fb64
;
memset
(
&
fb64
,
0
,
sizeof
(
fb64
));
fout
.
write
(
&
fb64
,
8
);
}
};
}
// namespace
namespace
mgb
{
...
...
@@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
fbs
::
OperatorBuilder
builder
(
m_builder
);
builder
.
add_type_id
(
registry
->
unversioned
_type_id
);
builder
.
add_type_id
(
registry
->
persist
_type_id
);
builder
.
add_inputs
(
inputs
);
if
(
m_config
.
keep_opr_priority
)
{
builder
.
add_priority
(
opr
->
node_prop
().
attribute
().
priority
);
...
...
@@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
uint32_t
magic
=
MGB_MAGIC
;
m_file
->
write
(
&
magic
,
sizeof
(
magic
));
// write FeatureBits
FeatureBits64
::
write
(
*
m_file
);
// Padding
uint32_t
reserved
=
0
;
m_file
->
write
(
&
reserved
,
sizeof
(
reserved
));
...
...
@@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) {
class
GraphLoaderOSS
final
:
public
GraphLoader
{
const
LoadConfig
*
m_cur_load_config
=
nullptr
;
std
::
unique_ptr
<
InputFile
>
m_file
;
FeatureBits64
m_feature_bits
;
SharedBuffer
m_graph_buf
{{},
0
};
const
fbs
::
Graph
*
m_graph
;
SharedTensorIDMap
m_shared_tensor_map
;
...
...
@@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb
}
config
.
comp_node_arr
(
comp_node_arr
);
}
auto
registry
=
OprRegistry
::
find_by_unversioned_id
(
fbopr
->
type_id
());
const
OprRegistry
*
registry
;
if
(
magic_compare
)
{
registry
=
OprRegistry
::
find_by_id
(
fbopr
->
type_id
());
}
else
{
registry
=
OprRegistry
::
find_by_unversioned_id
(
fbopr
->
type_id
());
}
mgb_throw_if
(
!
registry
,
SerializationError
,
"failed to find opr with type %s, use python env "
...
...
@@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
uint32_t
magic
;
m_file
->
read
(
&
magic
,
sizeof
(
magic
));
mgb_throw_if
(
magic
!=
MGB_MAGIC
,
SerializationError
,
"wrong magic: wanted %#08x, actual %#08x (not a invalid fbs "
(
magic
!=
MGB_MAGIC
)
&&
(
magic
!=
MAGIC_V0
)
,
SerializationError
,
"wrong magic: wanted %#08x
or %#08x
, actual %#08x (not a invalid fbs "
"model?)"
,
MGB_MAGIC
,
magic
);
MGB_MAGIC
,
MAGIC_V0
,
magic
);
if
(
magic
==
MGB_MAGIC
)
{
// read FeatureBits
magic_compare
=
true
;
m_file
->
read
(
&
m_feature_bits
,
8
);
}
else
{
magic_compare
=
false
;
}
m_file
->
skip
(
4
);
uint64_t
offset_to_fbs
;
...
...
@@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) {
uint64_t
magic_with_reserved
=
0
;
file
.
read
(
&
magic_with_reserved
,
sizeof
(
magic_with_reserved
));
file
.
skip
(
-
sizeof
(
magic_with_reserved
));
return
magic_with_reserved
==
MGB_MAGIC
;
return
(
magic_with_reserved
==
MGB_MAGIC
)
||
(
magic_with_reserved
==
MAGIC_V0
)
;
}
}
// namespace serialization
...
...
src/serialization/include/megbrain/serialization/sereg.h
浏览文件 @
accb2d8d
...
...
@@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
static ser::OprWithOutputAccessor compat_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
auto&& ctx_ = static_cast<ser::OprLoadContext
RawPOD&>(ctx);
\
auto&& ctx_ = static_cast<ser::OprLoadContext
&>(ctx);
\
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \
} \
static void entry() { \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录