Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4ab5f970
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4ab5f970
编写于
5月 12, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(build): fix ci error
GitOrigin-RevId: 9cbf64dda27c8d99af009c13b34deadfa9655637
上级
b9a69323
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
43 addition
and
12 deletion
+43
-12
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+2
-2
src/opr/impl/dnn/dnn.sereg.v2.h
src/opr/impl/dnn/dnn.sereg.v2.h
+2
-0
src/opr/impl/io.sereg.v2.h
src/opr/impl/io.sereg.v2.h
+1
-0
src/serialization/impl/schema_v2.fbs
src/serialization/impl/schema_v2.fbs
+8
-2
src/serialization/impl/serializer_oss_v2.cpp
src/serialization/impl/serializer_oss_v2.cpp
+28
-5
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
...zation/include/megbrain/serialization/oss_opr_load_dump.h
+2
-3
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
4ab5f970
...
...
@@ -614,8 +614,8 @@ class trace:
input_transform: a python expression to transform the input data.
Example: data / np.std(data)
dump_format: using different dump formats. the open source MegEngine
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
internal MegEngine have an other choice of internal proprietary formats
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
internal MegEngine have an other choice of internal proprietary formats
Keyword Arguments:
...
...
src/opr/impl/dnn/dnn.sereg.v2.h
浏览文件 @
4ab5f970
#pragma once
#include "megbrain/graph/symbol_var.h"
#include "megdnn/oprs/general.h"
#if MGB_ENABLE_FBS_SERIALIZATION
...
...
src/opr/impl/io.sereg.v2.h
浏览文件 @
4ab5f970
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/dnn/softmax.h"
...
...
src/serialization/impl/schema_v2.fbs
浏览文件 @
4ab5f970
...
...
@@ -5,7 +5,7 @@ include "mgb_cpp_opr.fbs";
namespace mgb.serialization.fbs.v2;
file_identifier "mg
e
2";
file_identifier "mg
v
2";
table CompNode {
logical_locator:string;
...
...
@@ -105,7 +105,7 @@ union OperatorParam {
param.OptionalAxisV1 = 54,
param.ExecutionPolicy = 55,
param.AssertEqual = 56,
param.FpgaConv
= 57,
Reserved0
= 57,
param.CollectiveComm = 58,
param.CondExecPred = 59,
param.CondExecPredLogical = 60,
...
...
@@ -197,6 +197,11 @@ table OutputVar {
original_id:uint;
}
table OutputAlias {
id:uint;
name:string;
}
table Model {
/// the megengine version when serialize the model
mge_version:uint;
...
...
@@ -213,6 +218,7 @@ table Model {
middle_tensors:[MiddleTensor];
output_vars_idx:[OutputVar];
output_alias:[OutputAlias];
nr_shared_tensor:uint;
/// the Metadata to storage the custom data or some flags
...
...
src/serialization/impl/serializer_oss_v2.cpp
浏览文件 @
4ab5f970
...
...
@@ -400,6 +400,18 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
output_vars_idx
.
push_back
(
foutput_vars_idx
);
}
auto
fb_output_vars
=
m_builder
.
CreateVector
(
output_vars_idx
);
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
v2
::
OutputAlias
>>
output_vars_alias
;
if
(
m_config
.
alias_name_map
.
size
()
>
0
)
{
for
(
auto
&&
pair
:
m_config
.
alias_name_map
)
{
std
::
string
name
;
SymbolVar
var
;
std
::
tie
(
name
,
var
)
=
pair
;
auto
fbs_name
=
m_builder
.
CreateSharedString
(
name
);
output_vars_alias
.
push_back
(
fbs
::
v2
::
CreateOutputAlias
(
m_builder
,
var
.
node
()
->
id
(),
fbs_name
));
}
}
auto
fbs_output_alias
=
m_builder
.
CreateVector
(
output_vars_alias
);
auto
fb_mid_tensor
=
m_builder
.
CreateVector
(
m_model_middle_tensors
);
fbs
::
v2
::
ModelBuilder
model
(
m_builder
);
...
...
@@ -407,6 +419,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
model
.
add_oprs
(
fb_oprs
);
model
.
add_middle_tensors
(
fb_mid_tensor
);
model
.
add_output_vars_idx
(
fb_output_vars
);
model
.
add_output_alias
(
fbs_output_alias
);
model
.
add_nr_shared_tensor
(
m_nr_shared_tensor
);
model
.
add_metadata
(
fbmeta
);
m_builder
.
FinishSizePrefixed
(
model
.
Finish
(),
fbs
::
v2
::
ModelIdentifier
());
...
...
@@ -469,7 +482,7 @@ void GraphDumperOSSV2::dump_tensor(
if
(
dumper
)
{
mgb_log_warn
(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value dumper"
);
"user tensor value dumper
callback.
"
);
}
data
=
m_builder
.
CreateVector
(
reinterpret_cast
<
uint8_t
*>
(
tensor
.
raw_ptr
()),
layout
.
span
().
high_byte
);
...
...
@@ -568,7 +581,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
if
(
loader
)
{
mgb_log_warn
(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value loader"
);
"user tensor value loader
callback.
"
);
}
memcpy
(
ret
->
raw_ptr
(),
tensor
->
data
()
->
data
(),
tensor
->
data
()
->
size
());
}
...
...
@@ -677,15 +690,14 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
//! opr version must be exist
uint8_t
opr_version
=
fbopr
->
opr_version
();
auto
type_id
=
fbopr
->
type_id
();
auto
opr_type
=
fbopr
->
type
()
->
str
();
const
OprRegistryV2
*
registry
=
OprRegistryV2
::
versioned_find_by_id
(
type_id
,
opr_version
);
mgb_throw_if
(
!
registry
,
SerializationError
,
"failed to find opr with type %s
id is %zu
, use python env "
"failed to find opr with type %s , use python env "
"config.dump_registered_oprs() to get a dict that maps from "
"opr id to opr name"
,
fbopr
->
type
()
->
str
().
c_str
()
,
type_id
);
fbopr
->
type
()
->
str
().
c_str
());
// load inputs
VarNodeArray
inputs
;
...
...
@@ -817,6 +829,17 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
auto
metadata
=
ctx
.
load_metadata
();
auto
result
=
ctx
.
load_oprs
();
result
.
metadata
=
metadata
;
if
(
m_model
->
output_alias
()
&&
m_model
->
output_alias
()
->
size
()
>
0
)
{
auto
nr_alias
=
m_model
->
output_alias
()
->
size
();
result
.
output_var_list
.
resize
(
nr_alias
);
for
(
size_t
i
=
0
;
i
<
nr_alias
;
i
++
)
{
auto
output_alias
=
m_model
->
output_alias
()
->
Get
(
i
);
std
::
string
name
=
output_alias
->
name
()
->
str
();
size_t
id
=
output_alias
->
id
();
result
.
output_var_map
[
name
]
=
result
.
output_var_map_id
[
id
];
result
.
output_var_list
[
i
]
=
result
.
output_var_map_id
[
id
];
}
}
m_model_loaded
=
true
;
result
.
graph_compile_ahead
();
return
result
;
...
...
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
浏览文件 @
4ab5f970
...
...
@@ -233,9 +233,8 @@ public:
int
addition_index
=
index
-
1
;
if
(
addition_index
>=
static_cast
<
int
>
(
m_current_opr
->
additional_params
()
->
size
()))
{
mgb_log_warn
(
"Model has no addition param of index %d, just construct a "
"default one."
,
mgb_throw
(
SerializationError
,
"Model has no addition param of index %d."
,
addition_index
);
}
else
{
mgb_assert
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录