magicwindyyd / mindspore, forked from MindSpore / mindspore (in sync with the fork source)
Commit e702e0bc
Authored May 20, 2020 by yujianfeng

Add tuple_getitem check for outputs of bn

Parent: 817b0e4a
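The new pass and the added guards all revolve around one fact: BatchNorm and BatchNormGrad are multi-output nodes, so their results are consumed through TupleGetItem nodes, and a pass that inspects which output indices are used must not assume that every consumer is one. Below is a minimal, self-contained sketch of that check in plain C++; it is an illustration only, not MindSpore's IR API, and the names Kind, Node, and CollectUsedIndices are invented for the example.

// A minimal sketch of the guard this commit adds, with invented names: skip any
// consumer that is not a TupleGetItem before trying to read an output index from it.
#include <cstddef>
#include <iostream>
#include <vector>

enum class Kind { kTupleGetItem, kOther };

struct Node {
  Kind kind;
  std::size_t output_index;  // meaningful only when kind == kTupleGetItem
};

// Collect the output indices that are actually read; non-TupleGetItem consumers are
// ignored instead of being cast blindly, which is the behaviour the commit adds.
std::vector<std::size_t> CollectUsedIndices(const std::vector<Node> &users) {
  std::vector<std::size_t> used;
  for (const auto &user : users) {
    if (user.kind != Kind::kTupleGetItem) {
      continue;  // not a tuple_getitem: it carries no output index, so skip it
    }
    used.push_back(user.output_index);
  }
  return used;
}

int main() {
  std::vector<Node> users = {{Kind::kTupleGetItem, 0}, {Kind::kOther, 0}, {Kind::kTupleGetItem, 3}};
  for (std::size_t i : CollectUsedIndices(users)) {
    std::cout << i << '\n';  // prints 0 and 3; the kOther consumer is ignored
  }
  return 0;
}

In the actual diffs below, the same idea appears as the added guard if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { continue; } placed before the cast to CNodePtr and the read of kInputNodeOutputIndexInTupleGetItem.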
Showing 4 changed files with 24 additions and 7 deletions.
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc               +2 -0
mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc  +3 -0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc         +3 -0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc       +16 -7
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc

@@ -81,6 +81,7 @@
 #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
 #include "pre_activate/ascend/ir_fission/addn_fission.h"
 #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
+#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
 #include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "debug/anf_ir_dump.h"
@@ -116,6 +117,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
 }
 }  // namespace
mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc

@@ -34,6 +34,9 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   for (const auto &node_index : manager->node_users()[node]) {
     AnfNodePtr output = node_index.first;
     MS_EXCEPTION_IF_NULL(output);
+    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      continue;
+    }
     auto tuple_getiterm_cnode = output->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
     auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc

@@ -274,6 +274,9 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   MS_EXCEPTION_IF_NULL(manager);
   for (const auto &output : bn_outputs) {
     MS_EXCEPTION_IF_NULL(output);
+    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      continue;
+    }
     auto tuple_getitem_cnode = output->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
     AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc

@@ -32,7 +32,21 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
   std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
   return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
 }
+
+void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node,
+                      std::vector<AnfNodePtr> *new_outputs) {
+  MS_EXCEPTION_IF_NULL(old_cnode);
+  MS_EXCEPTION_IF_NULL(new_node);
+  MS_EXCEPTION_IF_NULL(new_outputs);
+  auto node_to_output = old_cnode->input(kAccumIndex + 1);
+  MS_EXCEPTION_IF_NULL(node_to_output);
+  AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  new_node->set_abstract(abstract_tuple);
+  // Create Output
+  CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs);
+}
 }  // namespace
 const BaseRef MomentumLossscaleFusion::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   VarPtr X0 = std::make_shared<Var>();

@@ -80,15 +94,10 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
   input_names_value[3] = "x1";
   input_names_value.emplace_back("x2");
   AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
-  auto node_to_output = cnode->input(kAccumIndex + 1);
-  MS_EXCEPTION_IF_NULL(node_to_output);
-  AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()};
-  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
-  new_node->set_abstract(abstract_tuple);
   new_node->set_scope(node->scope());
-  // Create Output
+  // Create Outputs
   std::vector<AnfNodePtr> new_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs);
+  AddInputToOutput(func_graph, cnode, new_node, &new_outputs);
   if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
     MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
   }
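The momentum_lossscale_fusion.cc change is a companion refactor: the output construction that Process used to do inline now lives in the AddInputToOutput helper, which, as far as the diff shows, exposes the original node's accum input as an additional output of the fused node before the caller verifies the final output count. A rough standalone model of that flow follows; BuildFusedOutputs, AppendInputAsOutput, and kExpectedOutputNum are invented names for illustration, not MindSpore's API.

// Rough sketch of the refactored output construction, with invented names.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

constexpr std::size_t kExpectedOutputNum = 2;  // stand-in for kFusedMulApplyMomentumOutputNum

// Append the original node's accum input as an extra output of the fused node,
// mirroring what AddInputToOutput achieves via the abstract tuple in the real pass.
void AppendInputAsOutput(const std::string &accum_input, std::vector<std::string> *outputs) {
  if (outputs == nullptr) {
    throw std::invalid_argument("outputs must not be null");
  }
  outputs->push_back(accum_input);
}

std::vector<std::string> BuildFusedOutputs(const std::string &fused_result, const std::string &accum_input) {
  std::vector<std::string> outputs{fused_result};
  AppendInputAsOutput(accum_input, &outputs);
  if (outputs.size() != kExpectedOutputNum) {  // same final check Process keeps in the diff
    throw std::runtime_error("failed to create outputs of fused node");
  }
  return outputs;
}

int main() {
  for (const auto &out : BuildFusedOutputs("fused_mul_apply_momentum", "accum")) {
    std::cout << out << '\n';  // fused result first, then the accum input as a second output
  }
  return 0;
}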