Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
03d0cc02
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
03d0cc02
编写于
10月 20, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(gopt): add remove redundant copy pass
GitOrigin-RevId: f616a76d2968dc44bb16840715d45f44117a7ba1
上级
545567d3
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
147 addition
and
6 deletion
+147
-6
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+1
-0
src/gopt/impl/misc.cpp
src/gopt/impl/misc.cpp
+63
-0
src/gopt/include/megbrain/gopt/misc.h
src/gopt/include/megbrain/gopt/misc.h
+10
-0
src/gopt/test/helper.h
src/gopt/test/helper.h
+7
-6
src/gopt/test/misc.cpp
src/gopt/test/misc.cpp
+66
-0
未找到文件。
src/gopt/impl/framework.cpp
浏览文件 @
03d0cc02
...
...
@@ -649,6 +649,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_pass
<
ReorderArithChainPass
>
(
cv_type
);
add_pass
<
FinalArithTransformPass
>
();
add_pass
<
RemoveRedundantTypeCvtPass
>
();
add_pass
<
RemoveRedundantCopyPass
>
();
#if MGB_JIT
bool
need_jit
=
false
;
...
...
src/gopt/impl/misc.cpp
浏览文件 @
03d0cc02
...
...
@@ -682,6 +682,69 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
MIDOUT_E
}
/* ======================= RemoveRedundantCopyPass ====================== */
const
char
*
RemoveRedundantCopyPass
::
name
()
const
{
return
"remove_redundant_copy"
;
}
bool
RemoveRedundantCopyPass
::
should_remove
(
const
CompNode
&
A
,
const
CompNode
&
B
)
{
//! if A and B has the same memnode and cpu <-> atlas/cpu <-> cuda, as only
//! these two compnode support crosscncopy
if
(
A
.
mem_node
()
==
B
.
mem_node
()
||
((
A
.
device_type
()
==
CompNode
::
DeviceType
::
CPU
||
A
.
device_type
()
==
CompNode
::
DeviceType
::
MULTITHREAD
)
&&
(
B
.
device_type
()
==
CompNode
::
DeviceType
::
ATLAS
||
B
.
device_type
()
==
CompNode
::
DeviceType
::
CUDA
))
||
((
B
.
device_type
()
==
CompNode
::
DeviceType
::
CPU
||
B
.
device_type
()
==
CompNode
::
DeviceType
::
MULTITHREAD
)
&&
(
A
.
device_type
()
==
CompNode
::
DeviceType
::
ATLAS
||
A
.
device_type
()
==
CompNode
::
DeviceType
::
CUDA
)))
{
return
true
;
}
else
{
return
false
;
}
}
void
RemoveRedundantCopyPass
::
apply
(
OptState
&
opt
)
const
{
MIDOUT_B
(
"RemoveRedundantCopyPass::apply"
)
auto
rewriter
=
opt
.
graph
().
make_rewriter
();
auto
on_opr
=
[
&
](
OperatorNodeBase
*
opr
)
{
if
(
auto
copy0
=
try_cast_as_op
<
opr
::
Copy
>
(
opr
))
{
auto
inp0
=
rewriter
.
get_var
(
copy0
->
input
(
0
));
if
(
auto
copy1
=
try_cast_as_op
<
opr
::
Copy
>
(
inp0
))
{
auto
inp1
=
copy1
->
input
(
0
);
if
(
should_remove
(
inp1
->
comp_node
(),
copy0
->
output
(
0
)
->
comp_node
()))
{
mgb_assert
(
!
rewriter
.
has_manual_replace
(
inp1
));
if
(
inp1
->
comp_node
()
==
copy0
->
output
(
0
)
->
comp_node
())
{
rewriter
.
replace_var
(
copy0
->
output
(
0
),
inp1
,
mgb_cstr_log
(
"copy(copy(a0, a1), a0) -> "
"a0"
));
return
;
}
else
{
auto
fold
=
opr
::
Copy
::
make
(
inp1
,
copy0
->
output
(
0
)
->
comp_node
());
rewriter
.
replace_var
(
copy0
->
output
(
0
),
fold
.
node
(),
mgb_cstr_log
(
"copy(copy(a0, a1), a2) -> "
"copy(a0, a2)"
));
return
;
}
}
}
}
rewriter
.
auto_replace_outputs
(
opr
);
};
opt
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
MIDOUT_E
}
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
...
...
src/gopt/include/megbrain/gopt/misc.h
浏览文件 @
03d0cc02
...
...
@@ -85,6 +85,16 @@ namespace gopt {
void
apply
(
OptState
&
opt
)
const
override
;
};
class
RemoveRedundantCopyPass
final
:
public
Pass
{
private:
//! Remove the copy chain of form cpu -> cpu -> cpu,
//! cpu -> gpu -> cpu
static
bool
should_remove
(
const
CompNode
&
A
,
const
CompNode
&
B
);
public:
const
char
*
name
()
const
override
;
void
apply
(
OptState
&
opt
)
const
override
;
};
//! remove execution mask for const PPVs in conditional execution
class
CondExecConstPredicateFolding
final
:
public
Pass
{
public:
...
...
src/gopt/test/helper.h
浏览文件 @
03d0cc02
...
...
@@ -26,14 +26,16 @@ namespace mgb {
HostTensorGenerator
<>
gen
;
std
::
shared_ptr
<
ComputingGraph
>
graph
=
ComputingGraph
::
make
();
SymbolVar
mkvar
(
const
char
*
name
,
const
TensorShape
&
shp
=
{
1
})
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
)).
rename
(
name
);
SymbolVar
mkvar
(
const
char
*
name
,
const
TensorShape
&
shp
=
{
1
},
CompNode
cn
=
CompNode
::
load
(
"xpu0"
))
{
return
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
gen
(
shp
),
cn
)
.
rename
(
name
);
}
SymbolVar
mkcvar
(
const
char
*
name
,
const
TensorShape
&
shp
=
{
1
})
{
SymbolVar
mkcvar
(
const
char
*
name
,
const
TensorShape
&
shp
=
{
1
},
CompNode
cn
=
CompNode
::
load
(
"xpu0"
))
{
return
opr
::
SharedDeviceTensor
::
make
(
*
graph
,
*
gen
(
shp
)).
rename
(
name
);
*
graph
,
*
gen
(
shp
)
,
cn
).
rename
(
name
);
}
template
<
typename
...
Args
>
...
...
@@ -73,4 +75,3 @@ namespace mgb {
TEST_F(TestGopt##pass, name)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/test/misc.cpp
浏览文件 @
03d0cc02
...
...
@@ -16,6 +16,7 @@
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
...
...
@@ -411,6 +412,71 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
check
(
x_q8_q8
,
x_q8_fp32_q8_
);
}
TEST_PASS
(
RemoveRedundantCopyPass
,
Basic
)
{
auto
x
=
mkvar
(
"x"
,
{
2
,
3
,
3
},
CompNode
::
load
(
"cpu0"
));
{
auto
x_cpu1
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu1"
));
auto
x_cpu0
=
opr
::
Copy
::
make
(
x_cpu1
,
CompNode
::
load
(
"cpu0"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_cpu0
,
CompNode
::
load
(
"cpu2"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu2"
));
check
(
x
,
x_cpu0
);
check
(
x_expected
,
x_cpu2
);
}
{
auto
x_cpu1
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu1"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_cpu1
,
CompNode
::
load
(
"cpu2"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_cpu2
,
CompNode
::
load
(
"cpu3"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu3"
));
check
(
x_expected
,
x_cpu3
);
}
{
auto
x_cpu1
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:1"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_cpu1
,
CompNode
::
load
(
"cpu0:2"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_cpu2
,
CompNode
::
load
(
"cpu0:3"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:3"
));
check
(
x_expected
,
x_cpu3
);
}
{
auto
x_cpu1
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:1"
));
auto
x_mt
=
opr
::
Copy
::
make
(
x_cpu1
,
CompNode
::
load
(
"multithread8:0"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_mt
,
CompNode
::
load
(
"cpu0:3"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:3"
));
check
(
x_expected
,
x_cpu3
);
}
#if MGB_ATLAS
{
auto
x_atlas0
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"atlas0"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_atlas0
,
CompNode
::
load
(
"cpu0:2"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_cpu2
,
CompNode
::
load
(
"cpu0:3"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:3"
));
check
(
x_expected
,
x_cpu3
);
}
#endif
#if MGB_CUDA
{
auto
x_cuda0
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"gpu0"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_cuda0
,
CompNode
::
load
(
"cpu0:2"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_cpu2
,
CompNode
::
load
(
"cpu0:3"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"cpu0:3"
));
check
(
x_expected
,
x_cpu3
);
}
{
auto
x_mt
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"multithread8:0"
));
auto
x_cpu2
=
opr
::
Copy
::
make
(
x_mt
,
CompNode
::
load
(
"gpu0:1"
));
auto
x_cpu3
=
opr
::
Copy
::
make
(
x_cpu2
,
CompNode
::
load
(
"multithread8:0"
));
auto
x_expected
=
opr
::
Copy
::
make
(
x
,
CompNode
::
load
(
"multithread8:0"
));
check
(
x_expected
,
x_cpu3
);
}
#endif
}
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "../../opr-mm/test/mock_client.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录