Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
270f1aa2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
270f1aa2
编写于
9月 15, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/serialization): add Accessor for OprLoader to fix BN output compatibility
GitOrigin-RevId: 3b95da02c8fa3cd2a6c6d47d7ede93b7b36aa3a7
上级
c0ccd0ea
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
98 addition
and
51 deletion
+98
-51
imperative/src/impl/ops/opr_attr.cpp
imperative/src/impl/ops/opr_attr.cpp
+2
-2
imperative/src/test/backward_graph.cpp
imperative/src/test/backward_graph.cpp
+2
-2
imperative/src/test/imperative.cpp
imperative/src/test/imperative.cpp
+1
-1
src/opr/impl/dnn/dnn.sereg.h
src/opr/impl/dnn/dnn.sereg.h
+8
-6
src/serialization/impl/opr_registry.cpp
src/serialization/impl/opr_registry.cpp
+17
-1
src/serialization/impl/opr_shallow_copy.cpp
src/serialization/impl/opr_shallow_copy.cpp
+1
-1
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+3
-2
src/serialization/include/megbrain/serialization/opr_registry.h
...rialization/include/megbrain/serialization/opr_registry.h
+21
-1
src/serialization/include/megbrain/serialization/sereg.h
src/serialization/include/megbrain/serialization/sereg.h
+43
-35
未找到文件。
imperative/src/impl/ops/opr_attr.cpp
浏览文件 @
270f1aa2
...
...
@@ -76,7 +76,7 @@ public:
}
};
cg
::
OperatorNodeBase
*
apply_on_var_node
(
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
attr
=
def
.
cast_final_safe
<
OprAttr
>
();
auto
config
=
attr
.
config
;
...
...
@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node(
auto
registry
=
serialization
::
OprRegistry
::
find_by_name
(
attr
.
type
);
mgb_assert
(
registry
,
"operator %s not found"
,
attr
.
type
.
c_str
());
OprParamsLoadContext
ctx
{
attr
.
param
,
inputs
[
0
]
->
owner_graph
()};
return
registry
->
loader
(
ctx
,
inputs
,
config
);
return
registry
->
loader
(
ctx
,
inputs
,
config
)
.
usable_output
()
;
}
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
opr
)
{
...
...
imperative/src/test/backward_graph.cpp
浏览文件 @
270f1aa2
...
...
@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) {
LogicalTensorDesc
inp
{
TensorLayout
{{
N
,
C
,
H
,
W
},
dtype
::
Float32
()},
cn
};
LogicalTensorDesc
stat
{
TensorLayout
{{
C
},
dtype
::
Float32
()},
cn
};
{
auto
op
=
OprAttr
::
make
(
"BatchNorm"
);
auto
op
=
OprAttr
::
make
(
"BatchNorm
V1
"
);
auto
&&
attr
=
op
->
cast_final_safe
<
OprAttr
>
();
Param
param
;
param
.
fwd_mode
=
Param
::
FwdMode
::
TRAINING
;
...
...
@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) {
{
false
,
false
,
false
,
false
,
false
,
true
});
}
{
auto
op
=
OprAttr
::
make
(
"BatchNorm"
);
auto
op
=
OprAttr
::
make
(
"BatchNorm
V1
"
);
auto
&&
attr
=
op
->
cast_final_safe
<
OprAttr
>
();
Param
param
;
param
.
fwd_mode
=
Param
::
FwdMode
::
TRAINING
;
...
...
imperative/src/test/imperative.cpp
浏览文件 @
270f1aa2
...
...
@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) {
}
TEST
(
TestImperative
,
BatchNorm
)
{
auto
op
=
OprAttr
::
make
(
"BatchNorm"
);
auto
op
=
OprAttr
::
make
(
"BatchNorm
V1
"
);
auto
&&
attr
=
op
->
cast_final_safe
<
OprAttr
>
();
using
Param
=
opr
::
BatchNorm
::
Param
;
Param
param
;
...
...
src/opr/impl/dnn/dnn.sereg.h
浏览文件 @
270f1aa2
...
...
@@ -16,14 +16,13 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
...
...
@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> {
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template
<
>
struct
OprMaker
<
opr
::
BatchNormBackward
,
6
>
{
using
Param
=
opr
::
BatchNormBackward
::
Param
;
...
...
@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> {
ComputingGraph
&
graph
,
const
OperatorNodeConfig
&
config
)
{
MGB_MARK_USED_VAR
(
graph
);
return
opr
::
BatchNormBackward
::
make
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
],
i
[
4
],
i
[
5
],
param
,
config
)[
0
]
return
opr
::
BatchNormBackward
::
make
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
],
i
[
4
],
i
[
5
],
param
,
config
)[
0
]
.
node
()
->
owner_opr
();
}
...
...
@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using
ConvBiasForwardV4
=
ConvBiasForward
;
MGB_SEREG_OPR
(
ConvBiasForwardV4
,
0
);
MGB_SEREG_OPR
(
BatchNorm
,
0
);
MGB_SEREG_OPR
(
BatchNormBackward
,
6
);
using
BatchNormV1
=
BatchNorm
;
using
BatchNormBackwardV1
=
BatchNormBackward
;
MGB_SEREG_OPR
(
BatchNormV1
,
0
);
MGB_SEREG_OPR
(
BatchNormBackwardV1
,
6
);
using
LocalShareForwardV1
=
LocalShareForward
;
using
LocalShareBackwardDataV1
=
LocalShareBackwardData
;
...
...
src/serialization/impl/opr_registry.cpp
浏览文件 @
270f1aa2
...
...
@@ -39,7 +39,7 @@ namespace {
return
inst
;
}
cg
::
OperatorNodeBase
*
dynamic_loader
(
OprWithOutputAccessor
dynamic_loader
(
OprLoadContext
&
ctx
,
const
cg
::
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
auto
name
=
ctx
.
load_buf_with_len
();
...
...
@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
}
#endif
namespace
{
const
VarNodeArray
&
default_accessor
(
const
VarNodeArray
&
outputs
)
{
return
outputs
;
}
}
OprWithOutputAccessor
::
OprWithOutputAccessor
(
cg
::
OperatorNodeBase
*
opr
)
:
m_opr
(
opr
){
m_accessor
=
&
default_accessor
;
};
OprWithOutputAccessor
::
OprWithOutputAccessor
(
cg
::
OperatorNodeBase
*
opr
,
Accessor
accessor
)
:
OprWithOutputAccessor
(
opr
)
{
if
(
accessor
)
{
m_accessor
=
accessor
;
}
};
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/serialization/impl/opr_shallow_copy.cpp
浏览文件 @
270f1aa2
...
...
@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
registry
->
dumper
(
dumper
,
opr
);
OprLoadContextMemory
loader
{
opr
.
owner_graph
(),
dumper
};
return
registry
->
loader
(
loader
,
inputs
,
config
);
return
registry
->
loader
(
loader
,
inputs
,
config
)
.
opr
()
;
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/serialization/impl/serializer_oss.cpp
浏览文件 @
270f1aa2
...
...
@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
}
// call loader
auto
opr
=
registry
->
loader
(
*
this
,
inputs
,
config
);
auto
accessor
=
registry
->
loader
(
*
this
,
inputs
,
config
);
auto
opr
=
accessor
.
opr
();
// check opr type; note that:
// 1. registry->type may be empty for dynamic opr loaders or legacy oprs
...
...
@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
opr
?
opr
->
dyn_typeinfo
()
->
name
:
nullptr
,
registry
->
type
->
name
);
// record output vars; read output names
size_t
i
=
0
;
for
(
auto
ovar
:
opr
->
output
())
{
for
(
auto
ovar
:
accessor
.
output
())
{
if
(
!
ovar
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
))
{
m_id2varnode
.
push_back
(
ovar
);
if
(
fbopr
->
output_name
())
{
...
...
src/serialization/include/megbrain/serialization/opr_registry.h
浏览文件 @
270f1aa2
...
...
@@ -19,16 +19,36 @@ namespace serialization {
class
OprDumpContext
;
class
OprLoadContext
;
class
OprShallowCopyContext
;
class
OprWithOutputAccessor
{
cg
::
OperatorNodeBase
*
m_opr
;
using
Accessor
=
thin_function
<
const
VarNodeArray
(
const
VarNodeArray
&
)
>
;
Accessor
m_accessor
;
public:
OprWithOutputAccessor
(
cg
::
OperatorNodeBase
*
opr
);
OprWithOutputAccessor
(
cg
::
OperatorNodeBase
*
opr
,
Accessor
accessor
);
VarNode
*
output
(
size_t
idx
)
const
{
return
output
().
at
(
idx
);
}
VarNodeArray
output
()
const
{
return
m_accessor
(
m_opr
->
output
());
}
VarNodeArray
usable_output
()
const
{
return
m_accessor
(
m_opr
->
usable_output
());
}
cg
::
OperatorNodeBase
*
opr
()
{
return
m_opr
;
}
};
//! dump opr internal params to OprDumpContext
using
OprDumper
=
thin_function
<
void
(
OprDumpContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr
)
>
;
//! load and restore operator from OprLoadContext
//! is also used by GraphLoadConfig.
using
OprLoader
=
thin_function
<
cg
::
OperatorNodeBase
*
(
OprLoadContext
&
ctx
,
const
cg
::
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
>
;
//! loader that can change opr output map for compatibility
using
OprLoaderWrapper
=
thin_function
<
OprWithOutputAccessor
(
OprLoadContext
&
ctx
,
const
cg
::
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
>
;
//! shallow copy function for a single operator
using
OprShallowCopy
=
thin_function
<
cg
::
OperatorNodeBase
*
(
const
OprShallowCopyContext
&
ctx
,
...
...
@@ -41,7 +61,7 @@ namespace serialization {
uint64_t
persist_type_id
;
std
::
string
name
;
OprDumper
dumper
;
OprLoader
loader
;
OprLoader
Wrapper
loader
;
OprShallowCopy
shallow_copy
;
//!< set to empty to use default impl
uint64_t
unversioned_type_id
;
...
...
src/serialization/include/megbrain/serialization/sereg.h
浏览文件 @
270f1aa2
...
...
@@ -167,16 +167,22 @@ namespace { \
/*!
* \brief register opr serialization methods
*/
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
struct _OprReg##_cls { \
static void entry() { \
using Impl = ::mgb::serialization::OprLoadDumpImpl< \
_cls, _arity>; \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \
} \
}; \
} \
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor( \
Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
//! use to check type is complete or not, midout need a complete type
...
...
@@ -187,33 +193,35 @@ template <class T>
struct
IsComplete
<
T
,
decltype
(
void
(
sizeof
(
T
)))
>
:
std
::
true_type
{};
//! call OprRegistry::add with only loader, used for backward compatibility
#define MGB_SEREG_OPR_COMPAT(_name, _load) \
namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
struct _OprReg##_name { \
static cg::OperatorNodeBase* compat_loader( \
serialization::OprLoadContext& ctx, \
const cg::VarNodeArray& inputs, \
const OperatorNodeConfig& config) { \
return _load( \
static_cast<serialization::OprLoadContextRawPOD&>(ctx), \
inputs, config); \
} \
static void entry() { \
::mgb::serialization::OprRegistry::add( \
{nullptr, \
MGB_HASH_STR(#_name), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
nullptr, \
compat_loader, \
{}, \
{}}); \
} \
}; \
} \
#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \
namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
namespace ser = ::mgb::serialization; \
struct _OprReg##_name { \
static ser::OprWithOutputAccessor compat_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \
_accessor); \
} \
static void entry() { \
ser::OprRegistry::add({nullptr, \
MGB_HASH_STR(#_name), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
nullptr, \
compat_loader, \
{}, \
{}}); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)
#define MGB_SEREG_OPR_COMPAT(_name, _load) \
MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr)
/*!
* \brief use \p _copy to implement shallow copy for given operator
*/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录