Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fe3ee3cd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fe3ee3cd
编写于
7月 20, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(opr): refactor OprArityTrait
GitOrigin-RevId: fa065cde4ea223dcc394d2a73e14d978384f85de
上级
91efd67d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
71 addition
and
98 deletion
+71
-98
src/opr/impl/dnn/convolution.cpp
src/opr/impl/dnn/convolution.cpp
+71
-98
未找到文件。
src/opr/impl/dnn/convolution.cpp
浏览文件 @
fe3ee3cd
...
...
@@ -32,6 +32,7 @@ MIDOUT_DECL(megbrain_opr_convolution)
MIDOUT_END();
#include "../internal/megdnn_opr_wrapper.inl"
#include "../internal/invoke.h"
#include <array>
#include <chrono>
...
...
@@ -109,104 +110,74 @@ struct OprAttributeTrait<opr::ConvBias> {
}
};
template
<
typename
Opr
>
constexpr
bool
opr_supports_preprocess
()
{
return
std
::
is_same
<
Opr
,
megdnn
::
ConvolutionForward
>::
value
||
std
::
is_same
<
Opr
,
megdnn
::
ConvBias
>::
value
;
}
template
<
typename
Opr
>
struct
OprArityTrait
;
#define cb(x) (x)
#define cb_ref(x) (&(x))
#define cb_dnn(x) ((x).as_megdnn())
#define APPLY(statement, ...) \
mgb::apply([&](const auto&... args) { return statement; }, \
std::tuple_cat(__VA_ARGS__))
template
<
typename
Opr
,
int
_arity_in
,
int
_arity_out
>
struct
OprArityTraitTmpl
{
static
constexpr
int
arity_in
=
_arity_in
;
static
constexpr
int
arity_out
=
_arity_out
;
static
constexpr
int
arity
=
arity_in
+
arity_out
;
using
Algorithm
=
typename
Opr
::
Algorithm
;
using
TensorLayoutArray
=
std
::
array
<
TensorLayout
,
arity
>
;
static
size_t
get_workspace_in_bytes
(
Opr
*
opr
,
Algorithm
*
algo
,
const
TensorLayoutArray
&
layouts
)
{
opr
->
execution_policy
()
=
{
algo
};
size_t
workspace_size
;
if_constexpr
<
opr_supports_preprocess
<
Opr
>
()
>
([
&
](
auto
)
{
workspace_size
=
APPLY
(
opr
->
get_workspace_in_bytes
(
args
...,
nullptr
),
layouts
);
},
/* else */
[
&
](
auto
)
{
workspace_size
=
APPLY
(
opr
->
get_workspace_in_bytes
(
args
...),
layouts
);
});
return
workspace_size
;
}
#define WS_ARG_true ,nullptr
#define WS_ARG_false
#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \
template <> \
struct OprArityTrait<_Opr> { \
static constexpr int arity_in = _in; \
static constexpr int arity_out = _out; \
static constexpr int arity = _in + _out; \
using TensorLayoutArray = std::array<TensorLayout, arity>; \
static size_t get_workspace_in_bytes( \
_Opr* opr, typename _Opr::Algorithm* algo, \
const TensorLayoutArray& layouts) { \
opr->execution_policy() = {algo}; \
return opr->get_workspace_in_bytes( \
LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \
} \
\
static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \
_Opr* opr, const TensorLayoutArray& layouts) { \
return opr->get_all_algorithms(LAYOUTS(cb)); \
} \
\
static typename _Opr::Algorithm* get_algorithm_heuristic( \
_Opr* opr, const TensorLayoutArray& layouts, \
size_t workspace_limit, bool reproducible) { \
return opr->get_algorithm_heuristic(LAYOUTS(cb), workspace_limit, \
reproducible); \
} \
\
static void exec(_Opr* opr, const DeviceTensorND* inp_val, \
const DeviceTensorND* out_val, \
megdnn::Workspace& workspace) { \
opr->exec(TENSORS(cb_dnn), workspace); \
} \
static
void
exec
(
Opr
*
opr
,
const
std
::
array
<
DeviceTensorND
,
arity_in
>&
inp_val
,
const
std
::
array
<
DeviceTensorND
,
arity_out
>&
out_val
,
megdnn
::
Workspace
&
workspace
)
{
if_constexpr
<
opr_supports_preprocess
<
Opr
>
()
>
([
&
](
auto
)
{
APPLY
(
opr
->
exec
(
args
.
as_megdnn
()...,
nullptr
,
workspace
),
inp_val
,
out_val
);
},
/* else */
[
&
](
auto
)
{
APPLY
(
opr
->
exec
(
args
.
as_megdnn
()...,
workspace
),
inp_val
,
out_val
);
});
}
};
#define INST_ARITY(_Opr, _in, _out) \
template <> \
struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
INST_ARITY
(
megdnn
::
ConvolutionBackwardData
,
2
,
1
);
INST_ARITY
(
megdnn
::
ConvolutionBackwardFilter
,
2
,
1
);
INST_ARITY
(
megdnn
::
Convolution3DForward
,
2
,
1
);
INST_ARITY
(
megdnn
::
Convolution3DBackwardData
,
2
,
1
);
INST_ARITY
(
megdnn
::
Convolution3DBackwardFilter
,
2
,
1
);
INST_ARITY
(
megdnn
::
LocalShareForward
,
2
,
1
);
INST_ARITY
(
megdnn
::
LocalShareBackwardData
,
2
,
1
);
INST_ARITY
(
megdnn
::
LocalShareBackwardFilter
,
2
,
1
);
INST_ARITY
(
megdnn
::
Convolution
,
2
,
1
);
INST_ARITY
(
megdnn
::
DeformableConvForward
,
4
,
1
);
INST_ARITY
(
megdnn
::
DeformableConvBackwardFilter
,
4
,
1
);
INST_ARITY
(
megdnn
::
BatchConvBiasForward
,
4
,
1
);
INST_ARITY
(
megdnn
::
ConvBias
,
4
,
1
);
INST_ARITY
(
megdnn
::
DeformableConvBackwardData
,
5
,
3
);
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0])
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2])
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false)
INST_ARITY_2_1
(
megdnn
::
ConvolutionBackwardData
);
INST_ARITY_2_1
(
megdnn
::
ConvolutionBackwardFilter
);
INST_ARITY_2_1
(
megdnn
::
Convolution3DForward
);
INST_ARITY_2_1
(
megdnn
::
Convolution3DBackwardData
);
INST_ARITY_2_1
(
megdnn
::
Convolution3DBackwardFilter
);
INST_ARITY_2_1
(
megdnn
::
LocalShareForward
);
INST_ARITY_2_1
(
megdnn
::
LocalShareBackwardData
);
INST_ARITY_2_1
(
megdnn
::
LocalShareBackwardFilter
);
#undef TENSORS
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr
INST_ARITY
(
megdnn
::
Convolution
,
2
,
1
,
true
);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_2_1
#define TENSORS(cb) \
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \
cb(out_val[0])
#define LAYOUTS(cb) \
cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \
cb(layouts[4])
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false)
INST_ARITY_4_1
(
megdnn
::
DeformableConvForward
);
INST_ARITY_4_1
(
megdnn
::
DeformableConvBackwardFilter
);
INST_ARITY_4_1
(
megdnn
::
BatchConvBiasForward
);
#undef TENSORS
#define TENSORS(cb) \
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \
cb(out_val[0]), nullptr
INST_ARITY
(
megdnn
::
ConvBias
,
4
,
1
,
true
);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_4_1
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), \
cb(inp_val[3]), cb(inp_val[4]), cb(out_val[0]), \
cb(out_val[1]), cb(out_val[2])
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), \
cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \
cb(layouts[6]), cb(layouts[7])
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false)
INST_ARITY_5_3
(
megdnn
::
DeformableConvBackwardData
);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_5_3
#undef cb
#undef cb_ref
#undef cb_dnn
#undef INST_ARITY
#undef WS_ARG_true
#undef WS_ARG_false
// timeout delta to be added with fastest known algorithm for new algos
constexpr
double
TIMEOUT_TOLERANCE
=
2
;
...
...
@@ -343,8 +314,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
megdnn_opr
->
param
()
=
param
.
opr_param
;
{
typename
Opr
::
Algorithm
*
algo
=
nullptr
;
for
(
auto
i
:
OprArityTrait
<
Opr
>::
get_all_algorithms
(
megdnn_opr
.
get
(),
layouts
))
{
for
(
auto
i
:
APPLY
(
megdnn_opr
->
get_all_algorithms
(
args
...),
layouts
))
{
if
(
!
strcmp
(
i
->
name
(),
param
.
algo_name
))
{
algo
=
i
;
break
;
...
...
@@ -368,7 +338,9 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
}
// allocate input and output memory
DeviceTensorND
inp_val
[
arity_in
],
out_val
[
arity_out
],
workspace
;
std
::
array
<
DeviceTensorND
,
arity_in
>
inp_val
;
std
::
array
<
DeviceTensorND
,
arity_out
>
out_val
;
DeviceTensorND
workspace
;
for
(
int
i
=
0
;
i
<
arity_in
;
++
i
)
{
inp_val
[
i
]
.
comp_node
(
cn
)
...
...
@@ -484,16 +456,17 @@ class AlgoChooser {
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
opr
->
owner_graph
(),
opr
->
comp_node
(),
opr
->
execution_policy
().
workspace_limit
);
return
OprArityTrait
<
Opr
>::
get_algorithm_heuristic
(
m_megdnn_opr
,
m_layouts
,
workspace_limit
,
reproducible
);
return
APPLY
(
m_megdnn_opr
->
get_algorithm_heuristic
(
args
...,
workspace_limit
,
reproducible
),
m_layouts
);
}
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std
::
vector
<
ImplAlgo
>
get_all_candidates
()
const
{
auto
heu
=
choose_by_heuristic
();
auto
&&
ret
=
OprArityTrait
<
Opr
>::
get_all_algorithms
(
m_megdnn_opr
,
m_layouts
);
auto
&&
ret
=
APPLY
(
m_megdnn_opr
->
get_all_algorithms
(
args
...),
m_layouts
);
bool
found
=
false
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
ret
[
i
]
==
heu
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录