Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2b2c1730
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2b2c1730
编写于
8月 08, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 08, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3696 Update AdjustAllReduce to use Pattern Matcher
Merge pull request !3696 from Giancarlo/update_adjust_allreduce
上级
ea8b3c5d
e6a76e4d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
85 addition
and
87 deletion
+85
-87
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
...re/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
+30
-72
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
...ore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
+2
-8
mindspore/core/ir/pattern_matcher.h
mindspore/core/ir/pattern_matcher.h
+53
-7
未找到文件。
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
浏览文件 @
2b2c1730
...
...
@@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
AnfNodePtr
AdjustAllReduceMulAdd
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
Reset
();
// {prim::kPrimAddN, Zs}
if
(
!
IsPrimitiveCNode
(
node
,
prim
::
kPrimAddN
))
{
return
nullptr
;
}
auto
addn
=
node
->
cast
<
CNodePtr
>
();
if
(
addn
->
size
()
!=
2
)
{
return
nullptr
;
}
AnfVisitor
::
Match
(
prim
::
kPrimMakeTuple
,
{
IsNode
,
IsNode
})(
addn
->
input
(
1
));
if
(
x_
==
nullptr
||
y_
==
nullptr
||
z_
==
nullptr
||
all_reduce_fg_
==
nullptr
)
{
return
nullptr
;
}
auto
addn_maketuple
=
addn
->
input
(
1
);
auto
fg
=
all_reduce_fg_
;
// addn inputs cross the graph, make the inputs same as allreduce node.
if
(
z_
->
isa
<
CNode
>
()
&&
fg
!=
z_
->
func_graph
())
{
auto
cnode_z
=
z_
->
cast
<
CNodePtr
>
();
z_
=
NewCNode
(
cnode_z
->
inputs
(),
fg
);
}
auto
addn_op_node
=
addn
->
input
(
0
);
auto
make_tuple_op_node
=
addn
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
0
);
PatternNode
x
,
y
,
z
;
auto
all_reduce_pat
=
PPrimitive
(
prim
::
kPrimAllReduce
,
x
);
auto
mul_pat
=
PBinOperation
(
prim
::
kPrimMul
,
all_reduce_pat
,
y
,
true
);
auto
admktup_pat
=
PBinOperation
(
prim
::
kPrimMakeTuple
,
mul_pat
,
z
,
true
);
auto
addn_pat
=
PPrimitive
(
prim
::
kPrimAddN
,
admktup_pat
);
auto
adjust_lambda
=
[
&
node
,
&
x
,
&
y
,
&
z
,
&
addn_pat
,
&
all_reduce_pat
,
&
admktup_pat
,
&
mul_pat
,
this
]()
->
AnfNodePtr
{
auto
fg
=
all_reduce_pat
.
GetFuncGraph
();
auto
z_
=
z
.
GetNode
(
node
);
// If addn inputs cross the graph, make the inputs same as allreduce node.
if
(
z_
->
isa
<
CNode
>
()
&&
fg
!=
z_
->
func_graph
())
{
auto
cnode_z
=
z_
->
cast
<
CNodePtr
>
();
z_
=
NewCNode
(
cnode_z
->
inputs
(),
fg
);
}
AnfNodePtr
tuple
=
NewCNode
({
make_tuple_op_node
,
z_
,
x_
},
fg
);
AnfNodePtr
add
=
NewCNode
({
addn_op_node
,
tuple
},
fg
);
AnfNodePtr
all_reduce
=
NewCNode
({
all_reduce_
,
add
},
fg
);
AnfNodePtr
mul
=
NewCNode
({
mul_
,
all_reduce
,
y_
},
fg
);
ProcessDependEdge
(
fg
,
addn_maketuple
,
all_reduce
);
return
mul
;
auto
addn_cnode
=
addn_pat
.
GetOriginalNode
()
->
cast
<
CNodePtr
>
();
auto
addn_op_node
=
addn_cnode
->
input
(
0
);
auto
make_tuple_op_node
=
addn_cnode
->
input
(
1
)
->
cast
<
CNodePtr
>
()
->
input
(
0
);
auto
all_reduce_prim
=
all_reduce_pat
.
GetOriginalNode
()
->
cast
<
CNodePtr
>
()
->
input
(
0
);
mul_cnode_
=
mul_pat
.
GetOriginalNode
();
auto
mul_prim
=
mul_cnode_
->
cast
<
CNodePtr
>
()
->
input
(
0
);
auto
addn_maketuple
=
admktup_pat
.
GetOriginalNode
();
AnfNodePtr
tuple
=
NewCNode
({
make_tuple_op_node
,
z_
,
x
.
GetNode
(
node
)},
fg
);
AnfNodePtr
add
=
NewCNode
({
addn_op_node
,
tuple
},
fg
);
AnfNodePtr
all_reduce
=
NewCNode
({
all_reduce_prim
,
add
},
fg
);
AnfNodePtr
mul
=
NewCNode
({
mul_prim
,
all_reduce
,
y
.
GetNode
(
node
)},
fg
);
ProcessDependEdge
(
fg
,
addn_maketuple
,
all_reduce
);
return
mul
;
};
MATCH_REPLACE_LAMBDA
(
node
,
addn_pat
,
adjust_lambda
);
return
nullptr
;
}
void
AdjustAllReduceMulAdd
::
ProcessDependEdge
(
const
FuncGraphPtr
&
fg
,
const
AnfNodePtr
&
addn_maketuple
,
...
...
@@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN
}
}
void
AdjustAllReduceMulAdd
::
Visit
(
const
AnfNodePtr
&
node
)
{
if
(
level_
==
0
)
{
level_
=
1
;
is_reduce_match_
=
false
;
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
AnfVisitor
::
Match
(
prim
::
kPrimMul
)(
node
);
level_
=
0
;
if
(
is_reduce_match_
)
{
mul_
=
node
->
cast
<
CNodePtr
>
()
->
input
(
0
);
mul_cnode_
=
node
->
cast
<
CNodePtr
>
();
y_
=
tmp_
;
}
else
{
z_
=
node
;
}
}
if
(
level_
==
1
)
{
// {prim::kPrimAllReduce, X}
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimAllReduce
))
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
->
size
()
>
1
)
{
all_reduce_
=
cnode
->
input
(
0
);
x_
=
cnode
->
input
(
1
);
is_reduce_match_
=
true
;
all_reduce_fg_
=
cnode
->
func_graph
();
}
}
else
{
tmp_
=
node
;
}
}
}
void
AdjustAllReduceMulAdd
::
Reset
()
{
level_
=
0
;
is_reduce_match_
=
false
;
x_
=
nullptr
;
y_
=
nullptr
;
z_
=
nullptr
;
tmp_
=
nullptr
;
all_reduce_fg_
=
nullptr
;
}
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
2b2c1730
...
...
@@ -38,20 +38,14 @@ namespace irpass {
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
class
AdjustAllReduceMulAdd
:
public
AnfVisito
r
{
class
AdjustAllReduceMulAdd
:
public
OptimizerCalle
r
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
;
void
ProcessDependEdge
(
const
FuncGraphPtr
&
fg
,
const
AnfNodePtr
&
addn_maketuple
,
const
AnfNodePtr
&
new_node
);
void
Visit
(
const
AnfNodePtr
&
node
)
override
;
void
Reset
();
private:
int
level_
{
0
};
bool
is_reduce_match_
{
false
};
AnfNodePtr
x_
{
nullptr
},
y_
{
nullptr
},
z_
{
nullptr
},
tmp_
{
nullptr
};
AnfNodePtr
all_reduce_
{
nullptr
},
mul_
{
nullptr
},
mul_cnode_
{
nullptr
};
FuncGraphPtr
all_reduce_fg_
{
nullptr
};
AnfNodePtr
mul_cnode_
{
nullptr
};
};
class
ArithmeticSimplify
:
public
OptimizerCaller
{
...
...
mindspore/core/ir/pattern_matcher.h
浏览文件 @
2b2c1730
...
...
@@ -94,8 +94,8 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
~
PBinOperation
()
=
default
;
AnfNodePtr
GetNode
(
const
AnfNodePtr
&
node
)
const
{
AnfNodePtr
lhs
=
x_
.
GetNode
(
node
->
func_graph
()
);
AnfNodePtr
rhs
=
y_
.
GetNode
(
node
->
func_graph
()
);
AnfNodePtr
lhs
=
x_
.
GetNode
(
node
);
AnfNodePtr
rhs
=
y_
.
GetNode
(
node
);
AnfNodePtrList
list
=
{
NewValueNode
(
prim_
),
lhs
,
rhs
};
return
NewCNode
(
list
,
node
->
func_graph
());
}
...
...
@@ -113,25 +113,42 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
if
(
!
x_
.
TryCapture
(
inputs
[
2
])
||
!
y_
.
TryCapture
(
inputs
[
1
]))
{
return
false
;
}
captured_binop_node_
=
node
;
return
true
;
}
return
false
;
}
captured_binop_node_
=
node
;
return
true
;
}
}
return
false
;
}
/// Returns the original node captured by this Binary Operation Pattern.
/// Throws exception if a node was not captured before.
AnfNodePtr
GetOriginalNode
()
const
{
if
(
captured_binop_node_
==
nullptr
)
{
MS_EXCEPTION
(
ValueError
)
<<
"A Node wasn't captured for this Pattern before attempting to get it."
;
}
return
captured_binop_node_
;
}
void
Reset
()
const
{
x_
.
Reset
();
y_
.
Reset
();
captured_binop_node_
=
nullptr
;
}
using
Internal
=
const
PBinOperation
<
T
,
T2
>
&
;
private:
const
PrimitivePtr
prim_
;
typename
T
::
Internal
x_
;
typename
T2
::
Internal
y_
;
bool
is_commutative_
{
false
};
mutable
AnfNodePtr
captured_binop_node_
{
nullptr
};
};
///
...
...
@@ -265,10 +282,11 @@ class PCNode : public PBase<PCNode<TArgs...> > {
return
*
this
;
}
using
Internal
=
const
PCNode
<
TArgs
...
>
&
;
void
Reset
()
const
{
tuple_utils
::
PTupleResetCapture
reset
;
tuple_utils
::
apply_func_tuple
(
&
reset
,
args_
);
has_min_extra_nodes_
=
false
;
extra_nodes_
.
clear
();
}
...
...
@@ -316,6 +334,9 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
AnfNodePtrList
tokens
(
inputs
.
begin
()
+
1
,
inputs
.
end
());
tuple_utils
::
PTupleCapture
capture_func
(
tokens
);
tuple_utils
::
apply_func_tuple
(
&
capture_func
,
args_
);
if
(
capture_func
.
captured_
)
{
captured_prim_node_
=
node
;
}
return
capture_func
.
captured_
;
}
return
false
;
...
...
@@ -329,9 +350,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
tuple_utils
::
apply_func_tuple
(
&
capture_func
,
args_
);
// If it could capture the initial set of nodes specified in the Pattern
// and there are enough extra inputs to add
if
(
capture_func
.
captured_
&&
inputs
.
size
()
>
pattern_arg_len
+
1
)
{
extra_nodes_
.
insert
(
extra_nodes_
.
end
(),
inputs
.
begin
()
+
1
+
pattern_arg_len
,
inputs
.
end
());
return
true
;
if
(
capture_func
.
captured_
)
{
captured_prim_node_
=
node
;
if
(
inputs
.
size
()
>
pattern_arg_len
+
1
)
{
extra_nodes_
.
insert
(
extra_nodes_
.
end
(),
inputs
.
begin
()
+
1
+
pattern_arg_len
,
inputs
.
end
());
}
}
return
capture_func
.
captured_
;
}
...
...
@@ -349,19 +372,42 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
return
*
this
;
}
/// Returns the FuncGraph of the original node captured by this Primitive Pattern.
/// Throws exception if a node was not captured before.
FuncGraphPtr
GetFuncGraph
()
const
{
if
(
captured_prim_node_
==
nullptr
)
{
MS_EXCEPTION
(
ValueError
)
<<
"A Node wasn't captured for this Pattern before attempting to get its FuncGraph."
;
}
return
captured_prim_node_
->
func_graph
();
}
/// Returns the original node captured by this Primitive Pattern.
/// Throws exception if a node was not captured before.
AnfNodePtr
GetOriginalNode
()
const
{
if
(
captured_prim_node_
==
nullptr
)
{
MS_EXCEPTION
(
ValueError
)
<<
"A Node wasn't captured for this Pattern before attempting to get it."
;
}
return
captured_prim_node_
;
}
void
Reset
()
const
{
tuple_utils
::
PTupleResetCapture
reset
;
tuple_utils
::
apply_func_tuple
(
&
reset
,
args_
);
has_min_extra_nodes_
=
false
;
extra_nodes_
.
clear
();
captured_prim_node_
=
nullptr
;
}
using
Internal
=
const
PPrimitive
<
TArgs
...
>
&
;
private:
const
PrimitivePtr
prim_
;
std
::
tuple
<
typename
TArgs
::
Internal
...
>
args_
;
mutable
AnfNodePtrList
extra_nodes_
;
mutable
bool
has_min_extra_nodes_
{
false
};
mutable
size_t
min_extra_nodes_
{
0
};
mutable
AnfNodePtr
captured_prim_node_
{
nullptr
};
};
///
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录