Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b56cbf18
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b56cbf18
编写于
4月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!725 Fix confusionmulgrad fusion pass cannot work
Merge pull request !725 from huanghui/r0.2-fix-confusionmulgrad
上级
e86ab6ce
bfd2afc0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
35 addition
and
8 deletion
+35
-8
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
...re_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
+35
-3
tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
...tivate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
+0
-5
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc
浏览文件 @
b56cbf18
...
...
@@ -72,6 +72,38 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
}
return
mul0
;
}
bool
QuitFusion
(
const
FuncGraphPtr
&
graph
,
const
AnfNodePtr
&
mul0_anf
,
const
AnfNodePtr
&
reduce_sum
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
mul0_anf
);
MS_EXCEPTION_IF_NULL
(
reduce_sum
);
if
(
!
mul0_anf
->
isa
<
CNode
>
())
{
return
true
;
}
auto
mul0
=
mul0_anf
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
mul0
);
// when network is _VirtualDatasetCell, quit fusion
if
(
mul0
->
fullname_with_scope
().
find
(
"network-_VirtualDatasetCell"
)
!=
std
::
string
::
npos
)
{
return
true
;
}
auto
manager
=
graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
if
(
manager
->
node_users
().
find
(
reduce_sum
)
==
manager
->
node_users
().
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"node has no output in manager"
;
}
const
AnfNodeIndexSet
&
outputs_set
=
manager
->
node_users
()[
reduce_sum
];
auto
it
=
std
::
find_if
(
outputs_set
.
begin
(),
outputs_set
.
end
(),
[
&
mul0
](
const
std
::
pair
<
AnfNodePtr
,
int
>
&
node_index
)
{
return
node_index
.
first
==
mul0
->
input
(
1
)
||
node_index
.
first
==
mul0
;
});
if
(
it
!=
outputs_set
.
end
())
{
MS_LOG
(
INFO
)
<<
"ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle"
;
return
true
;
}
return
false
;
}
}
// namespace
const
BaseRef
ConfusionMulGradFusion
::
DefinePattern
()
const
{
...
...
@@ -90,9 +122,6 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
auto
reduce_sum
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
reduce_sum
);
auto
mul1
=
reduce_sum
->
input
(
1
);
if
(
mul1
->
fullname_with_scope
().
find
(
"bert/encoder"
)
==
std
::
string
::
npos
)
{
return
nullptr
;
}
if
(
IsUsedByOthers
(
graph
,
mul1
))
{
MS_LOG
(
INFO
)
<<
"Mul1 is used by others, quit fusion!"
;
return
nullptr
;
...
...
@@ -102,6 +131,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
MS_LOG
(
INFO
)
<<
"Mul0 do not exist, quit fusion"
;
return
nullptr
;
}
if
(
QuitFusion
(
graph
,
mul0
,
node
))
{
return
nullptr
;
}
auto
fusion_node
=
CreateFusionNode
(
graph
,
reduce_sum
,
mul0
,
input3
);
std
::
vector
<
AnfNodePtr
>
fusion_node_outputs
;
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
浏览文件 @
b56cbf18
...
...
@@ -32,11 +32,6 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
TEST_F
(
TestHWOptimizeConfusionMulGradFusion
,
test_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_confusion_mul_grad_fusion"
,
"before"
);
EXPECT_NE
(
g
,
nullptr
);
auto
bert_scope
=
std
::
make_shared
<
Scope
>
(
"bert/encoder"
);
for
(
auto
node
:
TopoSort
(
g
->
get_return
()))
{
node
->
set_scope
(
bert_scope
);
}
std
::
vector
<
int
>
shp
{
1
,
1
,
1
,
1
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录