Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8f7f52ae
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8f7f52ae
编写于
12月 30, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(jit): add memfwd in jit executor opr
GitOrigin-RevId: b58860bbe87582023d96fdfe9d2cb5c6c93b8731
上级
dfb2b2ce
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
66 addition
and
29 deletion
+66
-29
dnn/src/common/pooling.cpp
dnn/src/common/pooling.cpp
+1
-1
src/jit/impl/executor_opr.cpp
src/jit/impl/executor_opr.cpp
+7
-0
src/jit/include/megbrain/jit/executor_opr.h
src/jit/include/megbrain/jit/executor_opr.h
+5
-1
src/jit/test/fusion.cpp
src/jit/test/fusion.cpp
+6
-0
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+1
-26
src/opr/impl/internal/identical_fwd.cpp
src/opr/impl/internal/identical_fwd.cpp
+31
-0
src/opr/include/megbrain/opr/basic_arith.h
src/opr/include/megbrain/opr/basic_arith.h
+2
-1
src/opr/include/megbrain/opr/internal/identical_fwd.h
src/opr/include/megbrain/opr/internal/identical_fwd.h
+13
-0
未找到文件。
dnn/src/common/pooling.cpp
浏览文件 @
8f7f52ae
...
...
@@ -92,7 +92,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
size_t
sw
=
this
->
param
().
stride_w
;
size_t
ph
=
this
->
param
().
pad_h
;
size_t
pw
=
this
->
param
().
pad_w
;
if
(
ph
<
fh
&&
pw
<
fw
)
{
if
(
ph
>=
fh
||
pw
>=
fw
)
{
megdnn_log_error
(
"pooling padding size (%zu %zu) should not be bigger than "
"window size (%zu %zu), it only can be used in CaffePooling"
,
...
...
src/jit/impl/executor_opr.cpp
浏览文件 @
8f7f52ae
...
...
@@ -135,6 +135,13 @@ void JITExecutor::init_output_mem_plan(bool dynamic) {
m_args
.
need_update
=
true
;
}
/*!
 * \brief plan forwarding of a writable input's memory to the output
 *
 * The actual planning is delegated to mixin_mem_plan_fwd_in2out_writable();
 * nothing is planned when any JIT feature bit is set.
 */
void JITExecutor::mem_plan_fwd_in2out_writable() {
    //! memory forwarding is currently only supported for elemwise fusion
    if (m_feature_bits == JITFeatureBits::NONE) {
        mixin_mem_plan_fwd_in2out_writable(*this);
    }
}
SymbolVar
JITExecutor
::
make
(
const
InternalGraphPtr
&
internal_graph
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
...
...
src/jit/include/megbrain/jit/executor_opr.h
浏览文件 @
8f7f52ae
...
...
@@ -13,6 +13,7 @@
#include "megbrain/graph/operator_node.h"
#include "megbrain/jit/internal_graph.h"
#include "megbrain/opr/internal/identical_fwd.h"
#if MGB_JIT
...
...
@@ -31,7 +32,8 @@ class Compiler;
* JITExecutor generates runtime Args for this specific inputs, and calls
* methods in Compiler to get the Executable object for actual computing.
*/
MGB_DEFINE_OPR_CLASS
(
JITExecutor
,
cg
::
SingleCNOperatorNodeBase
)
// {
MGB_DEFINE_OPR_CLASS
(
JITExecutor
,
cg
::
SingleCNOperatorNodeBase
,
opr
::
mixin
::
FwdIn2OutWritableHelper
)
// {
using
ModeTrait
=
megdnn
::
Elemwise
::
ModeTrait
;
InternalGraphPtr
m_internal_graph
;
...
...
@@ -57,6 +59,8 @@ public:
void
init_output_mem_plan
(
bool
dynamic
)
override
;
void
mem_plan_fwd_in2out_writable
()
override
;
//! read-only access to the internal graph pointed to by m_internal_graph
const
InternalGraph
&
internal_graph
()
const
{
return
*
m_internal_graph
;
}
const
InternalGraphPtr
internal_graph_ptr
()
const
{
...
...
src/jit/test/fusion.cpp
浏览文件 @
8f7f52ae
...
...
@@ -137,6 +137,12 @@ void run<basic>(Backend backend, CompNode cn) {
// only one broadcast is allowed in JIT fusion
ASSERT_EQ
(
1u
,
jits
[
0
]
->
input
().
size
());
ASSERT_EQ
(
4u
,
jits
[
1
]
->
input
().
size
());
//! check memfwd
ASSERT_EQ
(
prev_dev_ptr
(
jits
[
0
]
->
input
(
0
)),
prev_dev_ptr
(
jits
[
0
]
->
output
(
0
)));
ASSERT_EQ
(
prev_dev_ptr
(
jits
[
1
]
->
input
(
0
)),
prev_dev_ptr
(
jits
[
1
]
->
output
(
0
)));
}
template
<
>
...
...
src/opr/impl/basic_arith.cpp
浏览文件 @
8f7f52ae
...
...
@@ -338,32 +338,7 @@ void Elemwise::broadcast_collective_collapse(
}
/*!
 * \brief plan forwarding of a writable input's memory to the output
 *
 * All of the work is delegated to
 * FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable(), which
 * detects conflicting inputs pairwise and only forwards a contiguous input
 * whose shape and dtype both match the output.
 *
 * No second inline copy of that logic is kept here: such a copy would need a
 * fixed-size conflict table (limiting the number of inputs) and, lacking the
 * dtype comparison, would register forwarding for inputs that the mixin
 * deliberately rejects.
 */
void Elemwise::mem_plan_fwd_in2out_writable() {
    mixin_mem_plan_fwd_in2out_writable(*this);
}
void
Elemwise
::
scn_do_execute
()
{
...
...
src/opr/impl/internal/identical_fwd.cpp
浏览文件 @
8f7f52ae
...
...
@@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
valid_out
->
add_rt_force_dynamic_mem_alloc_imply_chain
(
opr
.
input
(
0
));
}
/* ===================== FwdIn2OutWritableHelper ===================== */
/*!
 * Mark every input of \p opr that may be forwarded to output(0) as writable.
 *
 * An input is eligible when it does not conflict with any other input, has
 * the same shape and the same dtype enum as the output, and has a contiguous
 * layout. Two inputs conflict when their memory plans overlap, or when the
 * plans are identical and at least one of the two layouts is not contiguous.
 */
void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable(
        OperatorNodeBase& opr) {
    using Type = cg::MemPlanIntersectionType;

    auto&& inputs = opr.input();
    const auto nr_input = inputs.size();

    // conflicted[k] becomes true once input k clashes with some other input
    std::vector<bool> conflicted(nr_input, false);
    for (size_t a = 0; a < nr_input; ++a) {
        for (size_t b = a + 1; b < nr_input; ++b) {
            auto kind =
                    cg::get_mem_plan_intersection_type(inputs[a], inputs[b]);

            bool clash = false;
            if (kind == Type::OVERLAP) {
                clash = true;
            } else if (kind == Type::IDENTICAL) {
                // identical plans are only tolerated when both layouts are
                // contiguous
                clash = !(inputs[a]->layout().is_contiguous() &&
                          inputs[b]->layout().is_contiguous());
            }

            if (!clash)
                continue;
            conflicted[a] = true;
            conflicted[b] = true;
        }
    }

    auto out = opr.output(0);
    for (size_t k = 0; k < nr_input; ++k) {
        auto in = inputs[k];
        if (conflicted[k])
            continue;
        // equal shape means no broadcast
        if (!out->shape().eq_shape(in->shape()))
            continue;
        if (out->dtype().enumv() != in->dtype().enumv())
            continue;
        if (!in->layout().is_contiguous())
            continue;
        out->set_fwd_in2out_writable(in);
    }
}
/* ===================== ReadonlyFwdHelper ===================== */
void
ReadonlyFwdHelper
::
mixin_rofwd_init_mem_plan
(
OperatorNodeBase
&
opr
)
{
...
...
src/opr/include/megbrain/opr/basic_arith.h
浏览文件 @
8f7f52ae
...
...
@@ -58,7 +58,8 @@ namespace intl {
* The operands are broadcasted automatically on dimensions of shape one to
* match shapes of each other; it works like broadcasting in numpy.
*/
MGB_DEFINE_OPR_CLASS
(
Elemwise
,
intl
::
ElemwiseBase
)
// {
MGB_DEFINE_OPR_CLASS
(
Elemwise
,
intl
::
ElemwiseBase
,
mixin
::
FwdIn2OutWritableHelper
)
// {
using
ModeTrait
=
megdnn
::
Elemwise
::
ModeTrait
;
public:
...
...
src/opr/include/megbrain/opr/internal/identical_fwd.h
浏览文件 @
8f7f52ae
...
...
@@ -19,6 +19,19 @@ namespace opr {
namespace
mixin
{
/*!
 * \brief mixin for operators that essentially work by forwarding an input to
 *      the output
 */
class
FwdIn2OutWritableHelper
:
public
cg
::
OperatorNodeMixinBase
{
protected:
/*!
 * \brief call this function from mem_plan_fwd_in2out_writable();
 * it checks whether the inputs conflict with each other in order to
 * decide whether an input can be forwarded to the output.
 *
 * \param opr the operator whose inputs and first output are examined
 */
void
mixin_mem_plan_fwd_in2out_writable
(
OperatorNodeBase
&
opr
);
};
//! for internal use by DynamicOutputIfInputDynamic
void
init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o
(
OperatorNodeBase
&
opr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录