Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
20e8541b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
20e8541b
编写于
8月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): bind fallback impl on first op method call
GitOrigin-RevId: 82ae1e32052f274dea67ced95dc6ab694883425b
上级
18274e02
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
175 addition
and
107 deletion
+175
-107
imperative/src/impl/op_trait.cpp
imperative/src/impl/op_trait.cpp
+39
-34
imperative/src/impl/op_trait.h
imperative/src/impl/op_trait.h
+112
-53
imperative/src/impl/tensor_sanity_check.cpp
imperative/src/impl/tensor_sanity_check.cpp
+24
-20
未找到文件。
imperative/src/impl/op_trait.cpp
浏览文件 @
20e8541b
...
...
@@ -38,6 +38,38 @@ StaticData& static_data() {
return
data
;
}
void
OpMethFallback
::
impl
(
ApplyOnPhysicalTensor
&
func
,
op_meth_tag
::
ApplyOnPhysicalTensor
)
{
func
.
Base
::
operator
=
(
proxy_graph_detail
::
apply_on_physical_tensor
);
}
void
OpMethFallback
::
impl
(
Execute
&
func
,
op_meth_tag
::
Execute
)
{
func
.
Base
::
operator
=
(
proxy_graph_detail
::
execute
);
}
void
OpMethFallback
::
impl
(
InferOutputMemDesc
&
func
,
op_meth_tag
::
InferOutputMemDesc
)
{
func
.
Base
::
operator
=
(
proxy_graph_detail
::
infer_output_mem_desc
);
}
void
OpMethFallback
::
impl
(
InferOutputAttrsFallible
&
func
,
op_meth_tag
::
InferOutputAttrsFallible
)
{
func
.
Base
::
operator
=
(
proxy_graph_detail
::
infer_output_attrs_fallible
);
}
void
OpMethFallback
::
impl
(
GradMaker
&
func
,
op_meth_tag
::
GradMaker
)
{
func
.
Base
::
operator
=
(
proxy_graph_detail
::
make_backward_graph
);
}
void
OpMethFallback
::
impl
(
DecideDispatchMode
&
func
,
op_meth_tag
::
DecideDispatchMode
)
{
static
auto
decide_dispatch_mode
=
[](
const
OpDef
&
,
const
SmallVector
<
LogicalTensorDesc
>&
)
{
return
DispatchMode
::
KERNEL
;
};
func
.
Base
::
operator
=
(
decide_dispatch_mode
);
}
void
OpMethFallback
::
impl
(
MakeNameFunc
&
func
,
op_meth_tag
::
MakeNameFunc
)
{
static
auto
make_name
=
[](
const
OpDef
&
def
)
->
std
::
string
{
return
def
.
trait
()
->
name
;
};
func
.
Base
::
operator
=
(
make_name
);
}
}
// detail
OpTrait
::
OpTrait
(
const
char
*
name_
)
:
name
(
name_
)
{}
...
...
@@ -66,44 +98,17 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
}
}
DispatchMode
fallback_decide_dispatch_mode
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
return
KERNEL
;
}
OpTraitRegistry
&
OpTraitRegistry
::
fallback
()
{
if
(
trait
->
apply_on_var_node
)
{
// fallback to proxy graph impl
if
(
!
trait
->
apply_on_physical_tensor
)
{
trait
->
apply_on_physical_tensor
=
proxy_graph_detail
::
apply_on_physical_tensor
;
}
if
(
!
trait
->
execute
)
{
trait
->
execute
=
proxy_graph_detail
::
execute
;
}
if
(
!
trait
->
infer_output_mem_desc
)
{
trait
->
infer_output_mem_desc
=
proxy_graph_detail
::
infer_output_mem_desc
;
}
if
(
!
trait
->
infer_output_attrs_fallible
)
{
trait
->
infer_output_attrs_fallible
=
proxy_graph_detail
::
infer_output_attrs_fallible
;
}
if
(
!
trait
->
make_backward_graph
)
{
trait
->
make_backward_graph
=
proxy_graph_detail
::
make_backward_graph
;
}
}
if
(
!
trait
->
decide_dispatch_mode
)
{
trait
->
decide_dispatch_mode
=
fallback_decide_dispatch_mode
;
}
if
(
!
trait
->
make_name
)
{
static
auto
make_name
=
[](
const
OpDef
&
def
)
->
std
::
string
{
return
def
.
trait
()
->
name
;
};
trait
->
make_name
=
make_name
;
trait
->
apply_on_physical_tensor
.
allow_fallback
=
true
;
trait
->
execute
.
allow_fallback
=
true
;
trait
->
infer_output_mem_desc
.
allow_fallback
=
true
;
trait
->
infer_output_attrs_fallible
.
allow_fallback
=
true
;
trait
->
make_backward_graph
.
allow_fallback
=
true
;
}
trait
->
decide_dispatch_mode
.
allow_fallback
=
true
;
trait
->
make_name
.
allow_fallback
=
true
;
return
*
this
;
}
...
...
imperative/src/impl/op_trait.h
浏览文件 @
20e8541b
...
...
@@ -15,21 +15,10 @@
namespace
mgb
{
namespace
imperative
{
namespace
detail
{
template
<
typename
Signature
>
template
<
typename
Tag
,
typename
Signature
>
struct
OpMeth
;
template
<
typename
RType
,
typename
...
Args
>
struct
OpMeth
<
RType
(
Args
...)
>
:
public
thin_function
<
RType
(
Args
...)
>
{
using
Base
=
thin_function
<
RType
(
Args
...)
>
;
using
Base
::
Base
;
RType
operator
()(
Args
...
args
)
const
{
if
(
!
this
->
Base
::
operator
bool
())
{
mgb_throw
(
MegBrainError
,
"Not Implemented"
);
}
return
this
->
Base
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
}
};
template
<
typename
T
>
struct
ToVarNodeArray
:
std
::
false_type
{};
template
<
>
...
...
@@ -58,28 +47,95 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
};
}
// namespace detail
using
OpDefMaker
=
detail
::
OpMeth
<
decltype
(
OpDef
::
make_from_op_node
)
>
;
using
DecideDispatchMode
=
detail
::
OpMeth
<
decltype
(
OpDef
::
decide_dispatch_mode
)
>
;
using
ApplyOnPhysicalTensor
=
detail
::
OpMeth
<
decltype
(
OpDef
::
apply_on_physical_tensor
)
>
;
using
InferOutputMemDesc
=
detail
::
OpMeth
<
decltype
(
OpDef
::
infer_output_mem_desc
)
>
;
using
Execute
=
detail
::
OpMeth
<
decltype
(
OpDef
::
execute
)
>
;
using
ApplyOnDeviceTensorND
=
detail
::
OpMeth
<
decltype
(
OpDef
::
apply_on_device_tensornd
)
>
;
using
ApplyOnVarNode
=
detail
::
OpMeth
<
decltype
(
OpDef
::
apply_on_var_node
)
>
;
using
InferOutputAttrsFallible
=
detail
::
OpMeth
<
decltype
(
OpDef
::
infer_output_attrs_fallible
)
>
;
using
GradMaker
=
detail
::
OpMeth
<
decltype
(
OpDef
::
make_backward_graph
)
>
;
using
Props
=
detail
::
OpMeth
<
decltype
(
OpDef
::
props
)
>
;
using
HashFunc
=
detail
::
OpMeth
<
size_t
(
const
OpDef
&
)
>
;
using
IsSame
=
detail
::
OpMeth
<
bool
(
const
OpDef
&
,
const
OpDef
&
)
>
;
using
MakeNameFunc
=
detail
::
OpMeth
<
std
::
string
(
const
OpDef
&
)
>
;
// clang-format off
#define OpMethType(TYPE, SIG) \
namespace detail::op_meth_tag { \
struct TYPE { \
constexpr static char name[] = #TYPE; \
}; \
} \
using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG>
OpMethType
(
OpDefMaker
,
decltype
(
OpDef
::
make_from_op_node
));
OpMethType
(
DecideDispatchMode
,
decltype
(
OpDef
::
decide_dispatch_mode
));
OpMethType
(
ApplyOnPhysicalTensor
,
decltype
(
OpDef
::
apply_on_physical_tensor
));
OpMethType
(
InferOutputMemDesc
,
decltype
(
OpDef
::
infer_output_mem_desc
));
OpMethType
(
Execute
,
decltype
(
OpDef
::
execute
));
OpMethType
(
ApplyOnDeviceTensorND
,
decltype
(
OpDef
::
apply_on_device_tensornd
));
OpMethType
(
ApplyOnVarNode
,
decltype
(
OpDef
::
apply_on_var_node
));
OpMethType
(
InferOutputAttrsFallible
,
decltype
(
OpDef
::
infer_output_attrs_fallible
));
OpMethType
(
GradMaker
,
decltype
(
OpDef
::
make_backward_graph
));
OpMethType
(
Props
,
decltype
(
OpDef
::
props
));
OpMethType
(
HashFunc
,
size_t
(
const
OpDef
&
));
OpMethType
(
IsSame
,
bool
(
const
OpDef
&
,
const
OpDef
&
));
OpMethType
(
MakeNameFunc
,
std
::
string
(
const
OpDef
&
));
// clang-format on
namespace
detail
{
struct
OpMethNotImpl
{
template
<
typename
Tag
,
typename
RType
,
typename
...
Args
>
static
void
impl
(
thin_function
<
RType
(
Args
...)
>&
func
,
Tag
)
{
func
=
[](
Args
...
args
)
->
RType
{
mgb_throw
(
MegBrainError
,
"%s was not implemented yet"
,
Tag
::
name
);
};
}
};
struct
OpMethFallback
:
public
OpMethNotImpl
{
using
OpMethNotImpl
::
impl
;
static
void
impl
(
ApplyOnPhysicalTensor
&
func
,
op_meth_tag
::
ApplyOnPhysicalTensor
);
static
void
impl
(
Execute
&
func
,
op_meth_tag
::
Execute
);
static
void
impl
(
InferOutputMemDesc
&
func
,
op_meth_tag
::
InferOutputMemDesc
);
static
void
impl
(
InferOutputAttrsFallible
&
func
,
op_meth_tag
::
InferOutputAttrsFallible
);
static
void
impl
(
GradMaker
&
func
,
op_meth_tag
::
GradMaker
);
static
void
impl
(
DecideDispatchMode
&
func
,
op_meth_tag
::
DecideDispatchMode
);
static
void
impl
(
MakeNameFunc
&
func
,
op_meth_tag
::
MakeNameFunc
);
};
template
<
typename
Tag
,
typename
RType
,
typename
...
Args
>
struct
OpMeth
<
Tag
,
RType
(
Args
...)
>
:
public
thin_function
<
RType
(
Args
...)
>
{
using
Base
=
thin_function
<
RType
(
Args
...)
>
;
using
Base
::
operator
bool
;
OpMeth
()
:
Base
{},
allow_fallback
(
false
){};
explicit
OpMeth
(
const
Base
&
base
)
{
this
->
Base
::
operator
=
(
base
);
}
RType
operator
()(
Args
...
args
)
const
{
if
(
!
this
->
Base
::
operator
bool
())
{
if
(
allow_fallback
)
{
OpMethFallback
::
impl
(
*
const_cast
<
OpMeth
*>
(
this
),
Tag
{});
}
else
{
OpMethNotImpl
::
impl
(
*
const_cast
<
OpMeth
*>
(
this
),
Tag
{});
}
}
return
this
->
Base
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
}
bool
allow_fallback
=
false
;
};
}
// namespace detail
struct
OpTrait
{
const
char
*
name
;
...
...
@@ -102,28 +158,31 @@ struct OpTrait {
static
void
for_each_trait
(
thin_function
<
void
(
OpTrait
&
)
>
visitor
);
};
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
// clang-format off
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph) \
cb(props) \
cb(hash) \
cb(is_same_st) \
cb(make_backward_graph)
\
cb(props)
\
cb(hash)
\
cb(is_same_st)
\
cb(make_name)
// clang-format on
struct
OpTraitRegistry
{
OpTrait
*
trait
;
#define DECL(meth) \
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
trait->meth = f; \
return *this; \
#define DECL(meth) \
OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) { \
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, \
#meth); \
trait->meth.Base::operator=(f); \
return *this; \
}
FOR_EACH_OP_METH
(
DECL
)
#undef DECL
...
...
@@ -162,7 +221,7 @@ struct OpTraitRegistry {
}
};
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
#define OP_TRAIT_REG(name, ...) \
...
...
imperative/src/impl/tensor_sanity_check.cpp
浏览文件 @
20e8541b
...
...
@@ -80,26 +80,30 @@ void TensorSanityCheck::enable() {
OpTrait
::
for_each_trait
([
this
](
OpTrait
&
trait
)
{
auto
backup
=
std
::
make_unique
<
ApplyOnPhysicalTensor
>
(
std
::
move
(
trait
.
apply_on_physical_tensor
));
trait
.
apply_on_physical_tensor
=
[
this
,
backup
=
backup
.
get
()]
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
for
(
auto
&&
i
:
inputs
)
{
if
(
!
m_checker
->
check
(
i
))
{
mgb_throw
(
TensorChecksumCalc
::
Error
,
"tensor modified before exec %s"
,
print_op
(
def
).
c_str
());
}
}
auto
output
=
(
*
backup
)(
def
,
inputs
);
for
(
auto
&&
i
:
output
)
{
mgb_assert
(
m_checker
->
check
(
i
));
}
for
(
auto
&&
i
:
inputs
)
{
if
(
!
m_checker
->
check
(
i
))
{
mgb_throw
(
TensorChecksumCalc
::
Error
,
"tensor modified after exec %s"
,
print_op
(
def
).
c_str
());
}
}
return
output
;
};
trait
.
apply_on_physical_tensor
=
ApplyOnPhysicalTensor
(
[
this
,
backup
=
backup
.
get
()](
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
for
(
auto
&&
i
:
inputs
)
{
if
(
!
m_checker
->
check
(
i
))
{
mgb_throw
(
TensorChecksumCalc
::
Error
,
"tensor modified before exec %s"
,
print_op
(
def
).
c_str
());
}
}
auto
output
=
(
*
backup
)(
def
,
inputs
);
for
(
auto
&&
i
:
output
)
{
mgb_assert
(
m_checker
->
check
(
i
));
}
for
(
auto
&&
i
:
inputs
)
{
if
(
!
m_checker
->
check
(
i
))
{
mgb_throw
(
TensorChecksumCalc
::
Error
,
"tensor modified after exec %s"
,
print_op
(
def
).
c_str
());
}
}
return
output
;
});
m_checker
->
hook_list
.
push_back
({
&
trait
,
std
::
move
(
backup
)});
});
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录