Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fa671d67
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fa671d67
编写于
6月 08, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb): support get opr param json from serailzed param
GitOrigin-RevId: c4cabb6f700b722e61c5944bcd16cceab13f513d
上级
06886fd1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
135 addition
and
28 deletion
+135
-28
src/plugin/impl/opr_footprint.cpp
src/plugin/impl/opr_footprint.cpp
+129
-28
src/plugin/include/megbrain/plugin/opr_footprint.h
src/plugin/include/megbrain/plugin/opr_footprint.h
+6
-0
未找到文件。
src/plugin/impl/opr_footprint.cpp
浏览文件 @
fa671d67
...
...
@@ -5,6 +5,7 @@
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/pooling.h"
...
...
@@ -13,6 +14,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/indexing_helper_sereg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
...
...
@@ -20,6 +22,7 @@
#include "megbrain/opr/standalone/nms_opr.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/opr_load_dump.h"
#if MGB_ENABLE_JSON
#include "megdnn/opr_param_json.h"
#endif
...
...
@@ -488,12 +491,24 @@ uint64_t opr_footprint_func<opr::Host2DeviceCopy>(cg::OperatorNodeBase* opr) {
template
<
class
T
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
(
cg
::
OperatorNodeBase
*
opr
);
template
<
class
T
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
(
serialization
::
OprLoadContextRawPOD
&
context
);
#define REGISTE_SERIAL_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
return opr::opr_param_to_json(context.read_param<opr::cls::Param>()); \
}
#define REGISTE_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
return opr::opr_param_to_json(opr->cast_final_safe<opr::cls>().param()); \
}
} \
REGISTE_SERIAL_PARAM_JSON_FUNC(cls)
REGISTE_PARAM_JSON_FUNC
(
Elemwise
)
REGISTE_PARAM_JSON_FUNC
(
ConvolutionForward
)
...
...
@@ -544,12 +559,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG)
REGISTE_PARAM_JSON_FUNC
(
Linspace
)
REGISTE_PARAM_JSON_FUNC
(
Eye
)
REGISTE_PARAM_JSON_FUNC
(
CvtColor
)
REGISTE_PARAM_JSON_FUNC
(
LayerNormBackward
)
REGISTE_PARAM_JSON_FUNC
(
AdaptivePoolingBackward
)
REGISTE_PARAM_JSON_FUNC
(
DropoutBackward
)
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
Dimshuffle
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
param
=
opr
->
cast_final_safe
<
opr
::
Dimshuffle
>
().
param
();
std
::
shared_ptr
<
json
::
Value
>
dimshuffle_param2json
(
const
opr
::
Dimshuffle
::
Param
&
param
)
{
auto
pattern
=
json
::
Array
::
make
();
for
(
size_t
i
=
0
;
i
<
param
.
pattern_len
;
i
++
)
pattern
->
add
(
json
::
NumberInt
::
make
(
param
.
pattern
[
i
]));
...
...
@@ -561,10 +576,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
AxisAddRemov
e
>
(
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
Dimshuffl
e
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
param
=
opr
->
cast_final_safe
<
opr
::
AxisAddRemove
>
().
param
();
auto
param
=
opr
->
cast_final_safe
<
opr
::
Dimshuffle
>
().
param
();
return
dimshuffle_param2json
(
param
);
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
<
opr
::
Dimshuffle
>
(
serialization
::
OprLoadContextRawPOD
&
context
)
{
return
dimshuffle_param2json
(
context
.
read_param
<
opr
::
Dimshuffle
::
Param
>
());
}
std
::
shared_ptr
<
json
::
Value
>
axis_add_remove_param2json
(
const
opr
::
AxisAddRemove
::
Param
&
param
)
{
auto
desc
=
json
::
Array
::
make
();
for
(
size_t
i
=
0
;
i
<
param
.
nr_desc
;
i
++
)
{
auto
axisdesc
=
param
.
desc
[
i
];
...
...
@@ -581,6 +605,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
});
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
AxisAddRemove
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
param
=
opr
->
cast_final_safe
<
opr
::
AxisAddRemove
>
().
param
();
return
axis_add_remove_param2json
(
param
);
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
<
opr
::
AxisAddRemove
>
(
serialization
::
OprLoadContextRawPOD
&
context
)
{
return
axis_add_remove_param2json
(
context
.
read_param
<
opr
::
AxisAddRemove
::
Param
>
());
}
std
::
shared_ptr
<
json
::
Value
>
indexing_param_to_json
(
const
std
::
vector
<
opr
::
indexing
::
AxisIndexer
>&
indices
)
{
auto
desc
=
json
::
Array
::
make
();
...
...
@@ -596,12 +633,29 @@ std::shared_ptr<json::Value> indexing_param_to_json(
return
desc
;
}
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \
} \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
auto indices = context.read_param<serialization::IndexDescMaskDump>(); \
auto desc = json::Array::make(); \
for (size_t i = 0; i < indices.nr_item; i++) { \
auto&& index = indices.items[i]; \
desc->add(json::Object::make({ \
{"axis", json::NumberInt::make(index.axis)}, \
{"begin", json::NumberInt::make(index.begin)}, \
{"end", json::NumberInt::make(index.end)}, \
{"step", json::NumberInt::make(index.step)}, \
{"idx", json::NumberInt::make(index.idx)}, \
})); \
} \
return desc; \
}
REGISTE_INDEXING_PARAM_JSON_FUNC
(
Subtensor
);
...
...
@@ -617,14 +671,11 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC
(
BatchedIncrMeshIndexing
);
REGISTE_INDEXING_PARAM_JSON_FUNC
(
BatchedSetMeshIndexing
);
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
Reshape
>
(
cg
::
OperatorNodeBase
*
opr
)
{
std
::
shared_ptr
<
json
::
Value
>
reshape_param2json
(
const
opr
::
Reshape
::
Param
&
param
)
{
auto
desc
=
json
::
Array
::
make
();
auto
axis_param
=
opr
->
cast_final_safe
<
opr
::
Reshape
>
().
param
();
if
(
axis_param
.
axis
!=
axis_param
.
MAX_NDIM
)
{
if
(
param
.
axis
!=
param
.
MAX_NDIM
)
{
return
json
::
Object
::
make
({
{
"axis"
,
json
::
NumberInt
::
make
(
axis_
param
.
axis
)},
{
"axis"
,
json
::
NumberInt
::
make
(
param
.
axis
)},
});
}
else
{
return
json
::
Object
::
make
();
...
...
@@ -632,13 +683,24 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
GetVarS
hape
>
(
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
Res
hape
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
axis_param
=
opr
->
cast_final_safe
<
opr
::
Reshape
>
().
param
();
return
reshape_param2json
(
axis_param
);
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
<
opr
::
Reshape
>
(
serialization
::
OprLoadContextRawPOD
&
context
)
{
return
reshape_param2json
(
context
.
read_param
<
opr
::
Reshape
::
Param
>
());
}
std
::
shared_ptr
<
json
::
Value
>
getvarshape_param2json
(
const
opr
::
GetVarShape
::
Param
&
param
)
{
auto
desc
=
json
::
Array
::
make
();
auto
axis_param
=
opr
->
cast_final_safe
<
opr
::
GetVarShape
>
().
param
();
if
(
axis_param
.
axis
!=
axis_param
.
MAX_NDIM
)
{
if
(
param
.
axis
!=
param
.
MAX_NDIM
)
{
return
json
::
Object
::
make
({
{
"axis"
,
json
::
NumberInt
::
make
(
axis_
param
.
axis
)},
{
"axis"
,
json
::
NumberInt
::
make
(
param
.
axis
)},
});
}
else
{
return
json
::
Object
::
make
();
...
...
@@ -646,15 +708,39 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
standalone
::
NMSKeep
>
(
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
GetVarShape
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
nms_param
=
opr
->
cast_final_safe
<
opr
::
standalone
::
NMSKeep
>
().
param
();
auto
axis_param
=
opr
->
cast_final_safe
<
opr
::
GetVarShape
>
().
param
();
return
getvarshape_param2json
(
axis_param
);
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
<
opr
::
GetVarShape
>
(
serialization
::
OprLoadContextRawPOD
&
context
)
{
return
getvarshape_param2json
(
context
.
read_param
<
opr
::
GetVarShape
::
Param
>
());
}
std
::
shared_ptr
<
json
::
Value
>
nmskeep_param2json
(
const
opr
::
standalone
::
NMSKeep
::
Param
&
param
)
{
return
json
::
Object
::
make
({
{
"iou_thresh"
,
json
::
Number
::
make
(
nms_
param
.
iou_thresh
)},
{
"max_output"
,
json
::
Number
::
make
(
nms_
param
.
max_output
)},
{
"iou_thresh"
,
json
::
Number
::
make
(
param
.
iou_thresh
)},
{
"max_output"
,
json
::
Number
::
make
(
param
.
max_output
)},
});
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
opr_param_json_func
<
opr
::
standalone
::
NMSKeep
>
(
cg
::
OperatorNodeBase
*
opr
)
{
auto
nms_param
=
opr
->
cast_final_safe
<
opr
::
standalone
::
NMSKeep
>
().
param
();
return
nmskeep_param2json
(
nms_param
);
}
template
<
>
std
::
shared_ptr
<
json
::
Value
>
serial_param_json_func
<
opr
::
standalone
::
NMSKeep
>
(
serialization
::
OprLoadContextRawPOD
&
context
)
{
return
nmskeep_param2json
(
context
.
read_param
<
opr
::
standalone
::
NMSKeep
::
Param
>
());
}
#endif // MGB_ENABLE_JSON
}
// namespace
...
...
@@ -675,6 +761,9 @@ void OprFootprint::add_single_param_json() {
auto
&&
record
=
m_type2param_json
.
emplace
(
OprType
::
typeinfo
(),
opr_param_json_func
<
OprType
>
);
mgb_assert
(
record
.
second
,
"duplicate opr typeinfo"
);
auto
&&
record1
=
m_type2serialparam_json
.
emplace
(
OprType
::
typeinfo
(),
serial_param_json_func
<
OprType
>
);
mgb_assert
(
record1
.
second
,
"duplicate opr typeinfo"
);
}
#endif
...
...
@@ -767,6 +856,9 @@ void OprFootprint::init_all_footprints() {
add_single_param_json
<
opr
::
Eye
>
();
add_single_param_json
<
opr
::
standalone
::
NMSKeep
>
();
add_single_param_json
<
opr
::
CvtColor
>
();
add_single_param_json
<
opr
::
LayerNormBackward
>
();
add_single_param_json
<
opr
::
AdaptivePoolingBackward
>
();
add_single_param_json
<
opr
::
DropoutBackward
>
();
#endif
}
...
...
@@ -814,6 +906,15 @@ std::shared_ptr<json::Value> OprFootprint::get_param_json(cg::OperatorNodeBase*
return
json
::
Object
::
make
();
}
std
::
shared_ptr
<
json
::
Value
>
OprFootprint
::
get_serial_param_json
(
Typeinfo
*
type
,
serialization
::
OprLoadContextRawPOD
&
context
)
{
auto
param_trait
=
m_type2serialparam_json
.
find
(
type
);
if
(
param_trait
!=
m_type2serialparam_json
.
end
())
{
return
(
param_trait
->
second
)(
context
);
}
return
json
::
Object
::
make
();
}
std
::
shared_ptr
<
json
::
Value
>
OprFootprint
::
Result
::
to_json
()
const
{
using
namespace
json
;
std
::
shared_ptr
<
Value
>
comp
;
...
...
src/plugin/include/megbrain/plugin/opr_footprint.h
浏览文件 @
fa671d67
#pragma once
#include "megbrain/graph.h"
#include "megbrain/serialization/opr_load_dump.h"
namespace
mgb
{
...
...
@@ -14,7 +15,10 @@ class OprFootprint {
#if MGB_ENABLE_JSON
using
ParamJsonTrait
=
thin_function
<
std
::
shared_ptr
<
json
::
Value
>
(
cg
::
OperatorNodeBase
*
)
>
;
using
SerialParamJsonTrait
=
thin_function
<
std
::
shared_ptr
<
json
::
Value
>
(
serialization
::
OprLoadContextRawPOD
&
)
>
;
ThinHashMap
<
Typeinfo
*
,
ParamJsonTrait
>
m_type2param_json
;
ThinHashMap
<
Typeinfo
*
,
SerialParamJsonTrait
>
m_type2serialparam_json
;
#endif
//! add single footprint calculator for associated opr type.
...
...
@@ -70,6 +74,8 @@ public:
#if MGB_ENABLE_JSON
MGE_WIN_DECLSPEC_FUC
std
::
shared_ptr
<
json
::
Value
>
get_param_json
(
cg
::
OperatorNodeBase
*
opr
);
MGE_WIN_DECLSPEC_FUC
std
::
shared_ptr
<
json
::
Value
>
get_serial_param_json
(
Typeinfo
*
type
,
serialization
::
OprLoadContextRawPOD
&
context
);
//! get opr foot print and graph exec info
//! the function will recompile graph, AsyncExecutable compiled before will
//! be invalid
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录