magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit c7bda536
Authored on Jun 19, 2020 by huanghui

Fix the ConfusionSoftmaxGrad fusion pass for the case where ReduceSum's keep_dims attr is set to False (the pass now skips fusion instead of producing a wrong result)

Parent: 14241786
Showing 3 changed files with 26 additions and 38 deletions (+26 −38)
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc    +19 −34
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h     +6 −3
tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py   +1 −1
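Context for the fix: the pass matches the unfused subgraph Sub(input0, ReduceSum(Mul(input1, input0))) and replaces it with a single ConfusionSoftmaxGrad node, and the new check only allows the rewrite when the matched ReduceSum has keep_dims=True, since the fused computation assumes the reduced axis is kept. A minimal NumPy sketch (illustrative only, not from the repository; shapes and axis are arbitrary) of why keep_dims changes the result of the matched pattern:

import numpy as np

# Stand-ins for the two pattern inputs (hypothetical shapes).
input0 = np.random.rand(2, 3).astype(np.float32)
input1 = np.random.rand(2, 3).astype(np.float32)

# Unfused pattern with keep_dims=True: the reduction keeps shape (2, 1) and
# broadcasts back against input0 row by row.
keep = input0 - np.sum(input1 * input0, axis=-1, keepdims=True)  # shape (2, 3)

# With keep_dims=False the reduction collapses to shape (2,), which cannot be
# subtracted from a (2, 3) tensor at all in this case.
try:
    input0 - np.sum(input1 * input0, axis=-1, keepdims=False)
except ValueError as err:
    print("keep_dims=False breaks the pattern:", err)

# With a square shape the subtraction still broadcasts, but along the wrong
# axis, so it silently computes something different from the keep_dims form.
x0 = np.random.rand(3, 3).astype(np.float32)
x1 = np.random.rand(3, 3).astype(np.float32)
good = x0 - np.sum(x1 * x0, axis=-1, keepdims=True)   # (3, 1) broadcast, per row
bad = x0 - np.sum(x1 * x0, axis=-1, keepdims=False)   # (3,) broadcast, per column
print(np.allclose(good, bad))  # False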
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc

@@ -25,29 +25,8 @@
 namespace mindspore {
 namespace opt {
-namespace {
-void SetAttrsForFusionNode(const AnfNodePtr &sub_anf, const AnfNodePtr &fusion_node) {
-  MS_EXCEPTION_IF_NULL(sub_anf);
-  MS_EXCEPTION_IF_NULL(fusion_node);
-  auto sub = sub_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sub);
-  if (sub->size() != kSubInputNum) {
-    MS_LOG(EXCEPTION) << "Sub's size is not equal with 3";
-  }
-  auto reduce_sum_anf = sub->input(2);
-  MS_EXCEPTION_IF_NULL(reduce_sum_anf);
-  auto reduce_sum = reduce_sum_anf->cast<CNodePtr>();
-  if (reduce_sum == nullptr) {
-    MS_LOG(EXCEPTION) << "Sub's second input is not a cnode";
-  }
-  AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node);
-  AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node);
-}
-}  // namespace
-
 const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
-  return VectorRef(
-    {prim::kPrimSub, input0_, VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimMul, input1_, input0_})})});
+  return VectorRef(
+    {prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})});
 }

 const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,

@@ -55,22 +34,28 @@ const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, co
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(equiv);
-  auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
-  auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
-  MS_EXCEPTION_IF_NULL(input0);
-  MS_EXCEPTION_IF_NULL(input1);
+  AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_);
+  AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_);
+  AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_);
+  if (sum_anf == nullptr || !sum_anf->isa<CNode>()) {
+    MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!";
+    return nullptr;
+  }
+  if (!GetBoolAttr(sum_anf, kAttrKeepDims)) {
+    MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true if do fusion. Otherwise the calculation will be wrong";
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kConfusionSoftmaxGradOpName);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input0, input1};
-  auto confusion_softmax_grad = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(confusion_softmax_grad);
-  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, confusion_softmax_grad.get());
-  confusion_softmax_grad->set_scope(node->scope());
-  SetAttrsForFusionNode(node, confusion_softmax_grad);
-  return confusion_softmax_grad;
+  auto fusion_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_abstract(node->abstract());
+  fusion_node->set_scope(node->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node);
+  return fusion_node;
 }
 }  // namespace opt
 }  // namespace mindspore
mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h

@@ -24,9 +24,11 @@ namespace opt {
 class ConfusionSoftmaxGradRule : public PatternProcessPass {
  public:
   explicit ConfusionSoftmaxGradRule(bool multigraph = true)
-      : PatternProcessPass("confusion_softmax_grad_rule", multigraph),
-        input0_(std::make_shared<Var>()), input1_(std::make_shared<Var>()) {}
+      : PatternProcessPass("confusion_softmax_grad_rule", multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
+  }
   ~ConfusionSoftmaxGradRule() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

@@ -34,6 +36,7 @@ class ConfusionSoftmaxGradRule : public PatternProcessPass {
  private:
   VarPtr input0_;
   VarPtr input1_;
+  VarPtr reduce_sum_;
 };
 }  // namespace opt
 }  // namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/confusion_softmax_grad_rule.py

@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P

 mul = P.Mul()
-reduce_sum = P.ReduceSum()
+reduce_sum = P.ReduceSum(keep_dims=True)
 sub = P.Sub()
 confusion_softmax_grad = Primitive('ConfusionSoftmaxGrad')
 make_tuple = Primitive('make_tuple')
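For reference, the updated primitive from this test file can also be exercised eagerly. A rough, standalone sketch (assuming an installed MindSpore and PyNative mode; the tensors and axis are made up) that builds the unfused pattern the pass looks for, using the same operator instances the test defines:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

mul = P.Mul()
reduce_sum = P.ReduceSum(keep_dims=True)  # same setting as the updated test
sub = P.Sub()

input0 = Tensor(np.random.rand(2, 3).astype(np.float32))
input1 = Tensor(np.random.rand(2, 3).astype(np.float32))

# Unfused form of the matched pattern: Sub(input0, ReduceSum(Mul(input1, input0), axis)).
# With keep_dims=True the reduced tensor has shape (2, 1), so the Sub broadcasts
# cleanly -- the precondition the pass now verifies before fusing.
out = sub(input0, reduce_sum(mul(input1, input0), -1))
print(out.asnumpy().shape)  # (2, 3)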