magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 81833943
Authored on Aug 17, 2020 by mindspore-ci-bot
Committed by Gitee on Aug 17, 2020
!2969 Eliminate AllReduce when the input is a constant in Graph mode
Merge pull request !2969 from BowenK/fix_reduce_all
Parents: 60551b1f, f3a9fbdd
Showing 4 changed files with 57 additions and 0 deletions (+57 −0)
mindspore/ccsrc/frontend/optimizer/irpass.cc (+2 −0)
mindspore/ccsrc/frontend/optimizer/irpass.h (+1 −0)
mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h (+53 −0)
mindspore/ccsrc/pipeline/jit/pass.cc (+1 −0)
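
Background on the change: in AUTO_PARALLEL or SEMI_AUTO_PARALLEL Graph mode every device feeds the same constant into AllReduce, so the result of the collective is known at compile time and the communication op can be folded away. Below is a minimal numeric sketch of the identity the pass relies on; the fold_all_reduce helper and the group size of 8 are illustrative assumptions, not MindSpore code.

    #include <cstdint>
    #include <iostream>
    #include <string>

    // Toy model: every rank in a group of num_devices contributes the same constant.
    // A "sum" AllReduce then folds to constant * num_devices; "max" or "min" over
    // identical values folds to the constant itself.
    int64_t fold_all_reduce(const std::string &op, int64_t constant, uint32_t num_devices) {
      if (op == "sum") {
        return constant * static_cast<int64_t>(num_devices);
      }
      return constant;
    }

    int main() {
      std::cout << fold_all_reduce("sum", 3, 8) << std::endl;  // prints 24
      std::cout << fold_all_reduce("max", 3, 8) << std::endl;  // prints 3
      return 0;
    }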
mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -83,6 +83,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   reset_defer_inline_ =
     MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
   depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
+  all_reduce_const_elim_ =
+    MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);

   // Env Item Eliminate
   env_get_item_eliminate_ =
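
For context, MakeSubstitution pairs the AllReduceConstElim callable with a pass name ("reduce_all_const_elim") and a predicate that limits it to prim::kPrimAllReduce nodes. A rough standalone sketch of that registration shape follows; ToySubstitution and ToyNode are invented analogies for illustration, not the MindSpore API.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <optional>
    #include <string>

    // Hypothetical stand-in for an IR node: a primitive name plus one scalar operand.
    struct ToyNode {
      std::string prim;
      double operand;
    };

    // A substitution bundles a name, a predicate selecting candidate nodes,
    // and a rewrite that may or may not produce a replacement.
    struct ToySubstitution {
      std::string name;
      std::function<bool(const ToyNode &)> matches;
      std::function<std::optional<ToyNode>(const ToyNode &)> rewrite;
    };

    int main() {
      const uint32_t num_devices = 8;  // assumed group size, for illustration only
      // Analogue of registering "reduce_all_const_elim" for AllReduce nodes only.
      ToySubstitution all_reduce_const_elim{
          "reduce_all_const_elim",
          [](const ToyNode &n) { return n.prim == "AllReduce"; },
          [num_devices](const ToyNode &n) -> std::optional<ToyNode> {
            return ToyNode{"Constant", n.operand * num_devices};
          }};

      ToyNode node{"AllReduce", 3.0};
      if (all_reduce_const_elim.matches(node)) {
        auto replaced = all_reduce_const_elim.rewrite(node);
        std::cout << replaced->prim << " " << replaced->operand << std::endl;  // Constant 24
      }
      return 0;
    }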
mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -50,6 +50,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr check_bprop_eliminate_;
   SubstitutionPtr reset_defer_inline_;
   SubstitutionPtr depend_value_elim_;
+  SubstitutionPtr all_reduce_const_elim_;

   // Env Item Eliminate
   SubstitutionPtr env_get_item_eliminate_;
mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
@@ -29,6 +29,8 @@
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/irpass/prim_eliminate.h"
 #include "frontend/optimizer/optimizer.h"
+#include "utils/comm_manager.h"
+#include "frontend/parallel/context.h"

 namespace mindspore {
 namespace opt {
@@ -203,6 +205,57 @@ class DependValueElim : public OptimizerCaller {
     return nullptr;
   }
 };
+
+class AllReduceConstElim : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    PatternNode<AnfNodePtr> x;
+    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
+    // If AllReduce takes a constant value as input and values across devices are all the same (ensured by parallel mode)
+    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
+        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
+         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
+      auto cur_func_graph = pattern.GetFuncGraph();
+      // If the reduce operation is sum, multiply the constant by the number of devices; otherwise just return the constant
+      auto prim_cnode = pattern.GetOriginalNode();
+      MS_EXCEPTION_IF_NULL(prim_cnode);
+      auto primitive = GetCNodePrimitive(prim_cnode);
+      auto reduce_op = primitive->GetAttr("op");
+      auto group = primitive->GetAttr("group")->ToString();
+      // For a sum operation, multiply the constant tensor by the number of devices
+      if (reduce_op->ToString() == "sum") {
+        unsigned int num_of_devices;
+        // Get the number of devices
+        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
+          MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]";
+        }
+        // Multiply the constant by the number of devices, then return
+        std::vector<AnfNodePtr> mul_inputs;
+        auto constant_node = x.GetNode(node);
+        MS_EXCEPTION_IF_NULL(constant_node);
+        auto constant_value_node = constant_node->cast<ValueNodePtr>();
+        MS_EXCEPTION_IF_NULL(constant_value_node);
+        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
+          MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " +
+                                 constant_value_node->value()->ToString();
+        }
+        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
+        auto tensor_dtype = constant_tensor->Dtype();
+        auto num_of_device_node = NewValueNode(std::make_shared<tensor::Tensor>((int64_t)num_of_devices, tensor_dtype));
+        // Build the multiply node
+        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
+        MS_EXCEPTION_IF_NULL(mul_prim);
+        mul_inputs.push_back(NewValueNode(mul_prim));
+        mul_inputs.push_back(constant_node);
+        mul_inputs.push_back(num_of_device_node);
+        return cur_func_graph->NewCNode(mul_inputs);
+      } else {
+        return x.GetNode(node);
+      }
+    }
+    return nullptr;
+  }
+};
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
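
The class above captures the single AllReduce operand with PatternNode/PPrimitive, fires only under AUTO_PARALLEL or SEMI_AUTO_PARALLEL (where every rank is guaranteed to hold the same constant), and rewrites a "sum" reduction into tensor_mul(constant, group_size) while returning the constant unchanged for other reduce ops. Below is a compressed, self-contained sketch of that decision flow on plain scalars; ToyTensor, rank_size, and eliminate_const_all_reduce are assumptions for illustration, not the MindSpore IR types.

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <stdexcept>
    #include <string>

    // Toy scalar "tensor": just a value plus a dtype tag.
    struct ToyTensor {
      double value;
      std::string dtype;
    };

    // Mirrors the branches of AllReduceConstElim::operator() on a scalar constant:
    // do nothing outside (semi-)auto-parallel; for "sum" scale the constant by the
    // group size; for any other reduce op return the constant unchanged.
    std::optional<ToyTensor> eliminate_const_all_reduce(const ToyTensor &constant,
                                                        const std::string &reduce_op,
                                                        uint32_t rank_size,
                                                        bool auto_parallel) {
      if (!auto_parallel) {
        return std::nullopt;  // pattern does not fire; the AllReduce node is kept
      }
      if (rank_size == 0) {
        throw std::runtime_error("failed to get rank size for the group");
      }
      if (reduce_op == "sum") {
        // Same dtype as the constant, scaled by the number of devices.
        return ToyTensor{constant.value * rank_size, constant.dtype};
      }
      return constant;  // e.g. max/min over identical values
    }

    int main() {
      ToyTensor c{3.0, "float32"};
      auto folded = eliminate_const_all_reduce(c, "sum", 8, /*auto_parallel=*/true);
      if (folded) {
        std::cout << folded->value << " " << folded->dtype << std::endl;  // 24 float32
      }
      return 0;
    }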
mindspore/ccsrc/pipeline/jit/pass.cc
@@ -133,6 +133,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.incorporate_env_getitem_switch_,
     irpass.new_env_get_item_,
     irpass.depend_value_elim_,
+    irpass.all_reduce_const_elim_,
   });
   opt::OptPassConfig a_after_grad = opt::OptPassConfig({
     irpass.inline_without_move_,
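
With this one-line change the new substitution joins the opt_a group right after depend_value_elim_, so it runs with the other simplifications returned by GetOptPassesA. A minimal sketch of applying an ordered group of rewrites to a node follows; apply_group, ToyNode, and the assumed 8-device group are illustrative only, not the MindSpore pipeline.

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Toy node plus a rewrite signature; a rewrite returns true if it changed the node.
    struct ToyNode {
      std::string prim;
      double operand;
    };
    using ToyRewrite = std::function<bool(ToyNode &)>;

    // Apply the rewrites of a pass group in order and report whether any of them fired.
    bool apply_group(const std::vector<ToyRewrite> &group, ToyNode &node) {
      bool changed = false;
      for (const auto &rewrite : group) {
        changed = rewrite(node) || changed;
      }
      return changed;
    }

    int main() {
      std::vector<ToyRewrite> opt_a = {
          // depend_value_elim_ analogue: no-op for this node shape.
          [](ToyNode &) { return false; },
          // all_reduce_const_elim_ analogue: fold AllReduce of a constant (8 devices assumed).
          [](ToyNode &n) {
            if (n.prim != "AllReduce") return false;
            n = ToyNode{"Constant", n.operand * 8};
            return true;
          },
      };

      ToyNode node{"AllReduce", 3.0};
      apply_group(opt_a, node);
      std::cout << node.prim << " " << node.operand << std::endl;  // Constant 24
      return 0;
    }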