Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d4a82951
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4a82951
编写于
5月 15, 2020
作者:
H
huanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix confusionmulgrad fusion pass may create a loop
上级
92d196f0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
70 addition
and
15 deletion
+70
-15
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
...re_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
+12
-15
mindspore/ccsrc/pre_activate/common/helper.cc
mindspore/ccsrc/pre_activate/common/helper.cc
+53
-0
mindspore/ccsrc/pre_activate/common/helper.h
mindspore/ccsrc/pre_activate/common/helper.h
+3
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
d4a82951
...
...
@@ -51,6 +51,7 @@
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h"
#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include "pre_activate/pass/getitem_tuple.h"
#include "pre_activate/pass/optimize_dependence.h"
...
...
@@ -100,6 +101,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
MatmulBiasaddFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AddnFission
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
DereluFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
ConfusionMulGradFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
TransposeTransDataFusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
BatchNorm2BNInfer
>
());
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
浏览文件 @
d4a82951
...
...
@@ -73,13 +73,16 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
return
mul0
;
}
bool
QuitFusion
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
mul0_anf
,
const
AnfNodePtr
&
reduce_sum
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
bool
QuitFusion
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
mul0_anf
,
const
AnfNodePtr
&
mul1_anf
,
const
AnfNodePtr
&
reduce_sum
)
{
MS_EXCEPTION_IF_NULL
(
mul0_anf
);
MS_EXCEPTION_IF_NULL
(
mul1_anf
);
MS_EXCEPTION_IF_NULL
(
reduce_sum
);
if
(
!
mul0_anf
->
isa
<
CNode
>
())
{
if
(
!
mul0_anf
->
isa
<
CNode
>
()
||
!
mul1_anf
->
isa
<
CNode
>
()
)
{
return
true
;
}
auto
mul1
=
mul1_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
mul1
);
auto
mul0
=
mul0_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
mul0
);
...
...
@@ -88,20 +91,14 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
return
true
;
}
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
if
(
manager
->
node_users
().
find
(
reduce_sum
)
==
manager
->
node_users
().
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"node has no output in manager"
;
if
(
IsDepend
(
graph
,
mul0
->
input
(
1
),
reduce_sum
))
{
MS_LOG
(
INFO
)
<<
"mul0->input(1) depends on reduce_sum, quit fusion"
;
return
true
;
}
const
AnfNodeIndexSet
&
outputs_set
=
manager
->
node_users
()[
reduce_sum
];
auto
it
=
std
::
find_if
(
outputs_set
.
begin
(),
outputs_set
.
end
(),
[
&
mul0
](
const
std
::
pair
<
AnfNodePtr
,
int
>
&
node_index
)
{
return
node_index
.
first
==
mul0
->
input
(
1
)
||
node_index
.
first
==
mul0
;
});
if
(
it
!=
outputs_set
.
end
())
{
MS_LOG
(
INFO
)
<<
"ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"
;
if
(
IsDepend
(
graph
,
mul1
->
input
(
1
),
mul0
))
{
MS_LOG
(
INFO
)
<<
"mul1->input(1) depends on mul0, quit fusion"
;
return
true
;
}
return
false
;
}
}
// namespace
...
...
@@ -131,7 +128,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
MS_LOG
(
INFO
)
<<
"Mul0 do not exist, quit fusion"
;
return
nullptr
;
}
if
(
QuitFusion
(
graph
,
mul0
,
node
))
{
if
(
QuitFusion
(
graph
,
mul0
,
mul1
,
node
))
{
return
nullptr
;
}
...
...
mindspore/ccsrc/pre_activate/common/helper.cc
浏览文件 @
d4a82951
...
...
@@ -18,6 +18,9 @@
#include <string>
#include <unordered_set>
#include <algorithm>
#include <map>
#include <set>
#include <deque>
#include "utils/utils.h"
#include "utils/base_ref.h"
#include "session/anf_runtime_algorithm.h"
...
...
@@ -35,6 +38,56 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) {
return
result
;
}
bool
IsDepend
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node1
,
const
AnfNodePtr
&
node2
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
node1
);
MS_EXCEPTION_IF_NULL
(
node2
);
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
graph
->
get_return
());
std
::
map
<
AnfNodePtr
,
std
::
set
<
AnfNodePtr
>>
control_depend_map
;
for
(
auto
&
nd
:
node_list
)
{
if
(
AnfAlgo
::
CheckPrimitiveType
(
nd
,
prim
::
kPrimControlDepend
))
{
auto
control_depend
=
nd
->
cast
<
CNodePtr
>
();
auto
prior_node
=
control_depend
->
input
(
kControlDependPriorIndex
);
auto
behind_node
=
control_depend
->
input
(
kControlDependBehindIndex
);
auto
it
=
control_depend_map
.
find
(
behind_node
);
if
(
it
==
control_depend_map
.
end
())
{
control_depend_map
[
behind_node
]
=
std
::
set
<
AnfNodePtr
>
{
prior_node
};
}
else
{
it
->
second
.
insert
(
prior_node
);
}
}
}
FuncGraphManagerPtr
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
std
::
unordered_set
<
AnfNodePtr
>
seen_node
;
std
::
deque
<
AnfNodePtr
>
todo
{
node1
};
while
(
!
todo
.
empty
())
{
AnfNodePtr
node
=
todo
.
front
();
todo
.
pop_front
();
if
(
seen_node
.
count
(
node
)
>
0
||
!
manager
->
all_nodes
().
contains
(
node
))
{
continue
;
}
(
void
)
seen_node
.
insert
(
node
);
if
(
node
==
node2
)
{
return
true
;
}
if
(
node
->
isa
<
CNode
>
())
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
inputs
=
cnode
->
inputs
();
(
void
)
todo
.
insert
(
todo
.
end
(),
inputs
.
begin
(),
inputs
.
end
());
}
auto
it
=
control_depend_map
.
find
(
node
);
if
(
it
!=
control_depend_map
.
end
())
{
(
void
)
todo
.
insert
(
todo
.
end
(),
it
->
second
.
begin
(),
it
->
second
.
end
());
}
}
return
false
;
}
bool
UnVisited
(
const
BaseRef
&
n
)
{
if
(
utils
::
isa
<
AnfNodePtr
>
(
n
))
{
AnfNodePtr
in
=
utils
::
cast
<
AnfNodePtr
>
(
n
);
...
...
mindspore/ccsrc/pre_activate/common/helper.h
浏览文件 @
d4a82951
...
...
@@ -111,6 +111,9 @@ enum ConvBn1Output {
std
::
vector
<
int
>
Convert2Int
(
const
std
::
vector
<
size_t
>
&
v
);
// check whether node1 depends on node2 or not
bool
IsDepend
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
node1
,
const
AnfNodePtr
&
node2
);
bool
UnVisited
(
const
BaseRef
&
n
);
bool
Visited
(
const
BaseRef
&
n
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录