Commit 7185961e
Authored Apr 29, 2020 by YuJianfeng

Enable BatchNorm fusion pass

Parent: 4c32d7e6

Showing 6 changed files with 292 additions and 152 deletions (+292 -152)
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+3 -3)
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc (+179 -108)
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h (+25 -16)
mindspore/nn/layer/normalization.py (+23 -20)
tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc (+54 -0)
tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py (+8 -5)
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc

@@ -19,6 +19,7 @@
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/ascend/ir_fission/bn_split.h"
 #include "pre_activate/ascend/ir_fission/bn_grad_split.h"
+#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
 #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
 #include "pre_activate/pass/communication_op_fusion.h"
@@ -87,7 +88,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
   ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
-  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
   ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>());
@@ -193,8 +193,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
-  ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
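Note on this hunk: the fusion now runs unconditionally in ir_fusion_pm rather than behind the optional-fusion flag, and BnGradSplit gives way to BatchNormGradSplit. Pass managers of this kind apply passes in registration order, which is why the insertion point before AddMemcpyAsync matters. A toy sketch of that contract (hypothetical names; this is not the MindSpore C++ API):

# Toy pass-manager sketch: passes run in the order they are added, so
# FusedBatchNormFusion now always runs before AddMemcpyAsync, independent of
# the ir_fusion_flag that gates only the optional passes.
class ToyPassManager:
    def __init__(self, name):
        self.name = name
        self.passes = []

    def add_pass(self, fn):
        self.passes.append(fn)

    def run(self, graph):
        for p in self.passes:
            graph = p(graph)  # each pass rewrites the graph and returns it
        return graph

pm = ToyPassManager("ir_fusion_pm")
pm.add_pass(lambda g: g + ["BatchNormGradSplit"])
pm.add_pass(lambda g: g + ["FusedBatchNormFusion"])
pm.add_pass(lambda g: g + ["AddMemcpyAsync"])
print(pm.run([]))  # ['BatchNormGradSplit', 'FusedBatchNormFusion', 'AddMemcpyAsync']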
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc

@@ -23,6 +23,8 @@
 namespace mindspore {
 namespace opt {
 namespace {
+constexpr size_t kReplaceOutputIndex0 = 3;
+constexpr size_t kReplaceOutputIndex1 = 4;
 bool IsC(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
     AnfNodePtr in = utils::cast<AnfNodePtr>(n);
@@ -32,52 +34,6 @@ bool IsC(const BaseRef &n) {
   return false;
 }
-AnfNodePtr GetBatchNormNode(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto depend_cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(depend_cnode);
-  CheckCNodeInputSize(depend_cnode, kDependInputNum);
-  AnfNodePtr assign_sub = depend_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(assign_sub);
-  auto assign_sub_cnode = assign_sub->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(assign_sub_cnode);
-  CheckCNodeInputSize(assign_sub_cnode, kAssignSubInputNum);
-  AnfNodePtr mul = assign_sub_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(mul);
-  auto mul_cnode = mul->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(mul_cnode);
-  CheckCNodeInputSize(mul_cnode, kMulInputNum);
-  AnfNodePtr sub = mul_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(sub);
-  auto sub_cnode = sub->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sub_cnode);
-  CheckCNodeInputSize(sub_cnode, kSubInputNum);
-  AnfNodePtr tuple_getitem = sub_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(tuple_getitem);
-  auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
-  CheckCNodeInputSize(tuple_getitem_cnode, kTupleGetitemInputNum);
-  return tuple_getitem_cnode->input(1);
-}
-bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
-  MS_EXCEPTION_IF_NULL(n1);
-  MS_EXCEPTION_IF_NULL(n2);
-  auto n1_cnode = n1->cast<CNodePtr>();
-  auto n2_cnode = n2->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(n1_cnode);
-  MS_EXCEPTION_IF_NULL(n2_cnode);
-  auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem);
-  MS_EXCEPTION_IF_NULL(index_input1);
-  auto value_node1 = index_input1->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node1);
-  auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem);
-  MS_EXCEPTION_IF_NULL(index_input2);
-  auto value_node2 = index_input2->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node2);
-  return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
-}
 void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(bn);
@@ -92,54 +48,35 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
     MS_EXCEPTION_IF_NULL(output);
     bn_outputs->push_back(output);
   }
-  sort(bn_outputs->begin(), bn_outputs->end(), CompareTupleGetitem);
 }
 }  // namespace

 const BaseRef FusedBatchNormFusion::DefinePattern() const {
-  const auto prim_batch_norm = std::make_shared<Primitive>(kBatchNormOpName);
   std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
   VarPtr index0 = std::make_shared<CondVar>(IsC);
   VarPtr index1 = std::make_shared<CondVar>(IsC);
   VarPtr index2 = std::make_shared<CondVar>(IsC);
-  VectorRef batch_norm = VectorRef({prim_batch_norm, data_input_var0_, data_input_var1_, data_input_var2_, Xs});
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
   VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
   VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
   VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
-  VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input_var0_, tuple_getitem1});
-  VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input_var1_, tuple_getitem2});
-  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input_var0_});
-  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input_var1_});
-  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input_var0_, mul0});
-  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input_var1_, mul1});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
   VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
   return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
 }
-abstract::AbstractTuplePtr FusedBatchNormFusion::CreateAbstractOfFusedBatchNorm(const EquivPtr &equiv,
-                                                                                const AnfNodePtr &bn) const {
-  MS_EXCEPTION_IF_NULL(equiv);
-  MS_EXCEPTION_IF_NULL(bn);
-  auto variable_input0 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var0_]);
-  MS_EXCEPTION_IF_NULL(variable_input0);
-  auto variable_input1 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var1_]);
-  MS_EXCEPTION_IF_NULL(variable_input1);
-  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
-  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
-  if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
-    MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
-                      << bn_abstract_tuple->elements().size();
-  }
-  AbstractBasePtrList fused_bn_abstract_list{bn_abstract_tuple->elements()[0], variable_input0->abstract(),
-                                             variable_input1->abstract(), bn_abstract_tuple->elements()[3],
-                                             bn_abstract_tuple->elements()[4]};
-  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(fused_bn_abstract_list);
-  return abstract_tuple;
-}

 ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(equiv);
-  auto constant_input = utils::cast<AnfNodePtr>((*equiv)[constant_input_var0_]);
+  auto iter_constant_input0 = (*equiv).find(constant_input0_var_);
+  if (iter_constant_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched.";
+  }
+  auto constant_input = utils::cast<AnfNodePtr>(iter_constant_input0->second);
   MS_EXCEPTION_IF_NULL(constant_input);
   if (!constant_input->isa<ValueNode>()) {
     return nullptr;
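The main effect of the pattern changes above is that the BatchNorm CNode is now bound to batch_norm_var_ during matching, so the pass no longer needs the deleted GetBatchNormNode, which recovered it by walking Depend -> AssignSub -> Mul -> Sub -> tuple_getitem. A small dict-based sketch of the two lookup styles (illustrative only, not framework code):

# Old: walk five fragile hops from the matched Depend back to BatchNorm
# (the input indices mirror the deleted GetBatchNormNode).
def get_bn_old(depend):
    assign_sub = depend['inputs'][2]
    mul = assign_sub['inputs'][2]
    sub = mul['inputs'][1]
    tuple_getitem = sub['inputs'][2]
    return tuple_getitem['inputs'][1]

# New: the pattern binds the BatchNorm node itself to batch_norm_var_,
# so Process reads it straight out of the equiv map, mirroring
# (*equiv).find(batch_norm_var_).
def get_bn_new(equiv):
    if 'batch_norm_var_' not in equiv:
        raise RuntimeError('The equiv map is expected to contain the batch_norm var after matched.')
    return equiv['batch_norm_var_']

print(get_bn_new({'batch_norm_var_': 'BatchNorm_node'}))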
@@ -158,53 +95,187 @@ ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
   return MakeValue(tensor_data[0]);
 }
-const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                               const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-  // Set inputs
-  auto data_input0 = utils::cast<AnfNodePtr>((*equiv)[data_input_var0_]);
-  MS_EXCEPTION_IF_NULL(data_input0);
-  auto data_input1 = utils::cast<AnfNodePtr>((*equiv)[data_input_var1_]);
-  MS_EXCEPTION_IF_NULL(data_input1);
-  auto data_input2 = utils::cast<AnfNodePtr>((*equiv)[data_input_var2_]);
-  MS_EXCEPTION_IF_NULL(data_input2);
-  auto variable_input0 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var0_]);
-  MS_EXCEPTION_IF_NULL(variable_input0);
-  auto variable_input1 = utils::cast<AnfNodePtr>((*equiv)[variable_input_var1_]);
-  MS_EXCEPTION_IF_NULL(variable_input1);
-  std::vector<AnfNodePtr> fused_bn_inputs = {NewValueNode(prim::kPrimFusedBatchNorm), data_input0, data_input1,
-                                             data_input2, variable_input0, variable_input1};
-  auto fused_bn = func_graph->NewCNode(fused_bn_inputs);
-  fused_bn->set_scope(node->scope());
-  MS_EXCEPTION_IF_NULL(fused_bn);
-  AnfNodePtr bn = GetBatchNormNode(node);
-  fused_bn->set_abstract(CreateAbstractOfFusedBatchNorm(equiv, bn));
-  // Set attr
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, fused_bn);
-  ValuePtr factor = GetFactor(equiv);
-  if (factor == nullptr) {
-    return nullptr;
-  }
-  AnfAlgo::SetNodeAttr(kAttrMomentum, factor, fused_bn);
-  // Replace old nodes with outputs of fused_bn
-  std::vector<AnfNodePtr> fused_bn_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, fused_bn, kBnOutputNum, &fused_bn_outputs);
-  if (fused_bn_outputs.size() != kBnOutputNum) {
-    MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBnOutputNum << ", but it is "
-                      << fused_bn_outputs.size();
-  }
-  std::vector<AnfNodePtr> bn_outputs;
-  GetBNOutput(func_graph, bn, &bn_outputs);
-  if (bn_outputs.size() != kBnOutputNum) {
-    MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBnOutputNum << ", but it is "
-                      << bn_outputs.size();
-  }
-  auto manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  (void)manager->Replace(bn_outputs[3], fused_bn_outputs[3]);
-  (void)manager->Replace(bn_outputs[4], fused_bn_outputs[4]);
-  return fused_bn_outputs[0];
-}
+AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                        const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set input to create node
+  auto iter_data_input0 = (*equiv).find(data_input0_var_);
+  if (iter_data_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
+  }
+  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
+    utils::cast<AnfNodePtr>(iter_data_input0->second)};
+  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_reduce);
+  bn_training_reduce->set_scope(node->scope());
+  // Set abstract
+  auto iter_data_input1 = (*equiv).find(data_input1_var_);
+  if (iter_data_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
+  }
+  auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
+  MS_EXCEPTION_IF_NULL(data_input1);
+  auto iter_data_input2 = (*equiv).find(data_input2_var_);
+  if (iter_data_input2 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
+  }
+  auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
+  MS_EXCEPTION_IF_NULL(data_input2);
+  AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_reduce->set_abstract(abstract_tuple);
+  return bn_training_reduce;
+}
+
+void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
+                                                     const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
+                                                     std::vector<AnfNodePtr> *bn_training_update_inputs) const {
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
+  auto iter_data_input0 = (*equiv).find(data_input0_var_);
+  if (iter_data_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
+  }
+  auto iter_data_input1 = (*equiv).find(data_input1_var_);
+  if (iter_data_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
+  }
+  auto iter_data_input2 = (*equiv).find(data_input2_var_);
+  if (iter_data_input2 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
+  }
+  auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
+  if (iter_variable_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched.";
+  }
+  auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
+  if (iter_variable_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched.";
+  }
+  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
+                      << ", but it is " << bn_training_reduce_outputs.size();
+  }
+  *bn_training_update_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
+    utils::cast<AnfNodePtr>(iter_data_input0->second),
+    bn_training_reduce_outputs[0],
+    bn_training_reduce_outputs[1],
+    utils::cast<AnfNodePtr>(iter_data_input1->second),
+    utils::cast<AnfNodePtr>(iter_data_input2->second),
+    utils::cast<AnfNodePtr>(iter_variable_input0->second),
+    utils::cast<AnfNodePtr>(iter_variable_input1->second),
+  };
+}
+
+void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
+                                                           std::vector<AbstractBasePtr> *abstract_list) const {
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(bn);
+  MS_EXCEPTION_IF_NULL(abstract_list);
+  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
+  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
+  if (bn_abstract_tuple->elements().size() < kBnOutputNum) {
+    MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
+                      << bn_abstract_tuple->elements().size();
+  }
+  auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
+  if (iter_variable_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched.";
+  }
+  auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second);
+  MS_EXCEPTION_IF_NULL(variable_input0);
+  auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
+  if (iter_variable_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched.";
+  }
+  auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second);
+  MS_EXCEPTION_IF_NULL(variable_input1);
+  *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(),
+                    bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]};
+}
+
+AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
+  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
+  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set input
+  std::vector<AnfNodePtr> bn_training_update_inputs;
+  GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
+  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update);
+  // Set abstract
+  auto iter_batch_norm = (*equiv).find(batch_norm_var_);
+  if (iter_batch_norm == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
+  }
+  AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
+  MS_EXCEPTION_IF_NULL(bn);
+  AbstractBasePtrList abstract_list;
+  GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_update->set_abstract(abstract_tuple);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update);
+  ValuePtr factor = GetFactor(equiv);
+  if (factor == nullptr) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update);
+  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
+  bn_training_update->set_scope(node->scope());
+  return bn_training_update;
+}
+
+const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(node);
+  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv);
+  std::vector<AnfNodePtr> bn_training_reduce_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
+                                 &bn_training_reduce_outputs);
+  AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs);
+  if (bn_training_update == nullptr) {
+    MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString();
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> bn_training_update_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum,
+                                 &bn_training_update_outputs);
+  if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is "
+                      << bn_training_update_outputs.size();
+  }
+  // Replace old bn outputs with new outputs
+  auto iter_batch_norm = (*equiv).find(batch_norm_var_);
+  if (iter_batch_norm == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
+  }
+  AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
+  std::vector<AnfNodePtr> bn_outputs;
+  GetBNOutput(func_graph, bn, &bn_outputs);
+  if (bn_outputs.size() != kBnOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBnOutputNum << ", but it is "
+                      << bn_outputs.size();
+  }
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  for (const auto &output : bn_outputs) {
+    MS_EXCEPTION_IF_NULL(output);
+    auto tuple_getitem_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
+    AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
+    MS_EXCEPTION_IF_NULL(index_node);
+    auto value_node = index_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    int index = GetValue<int>(value_node->value());
+    if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
+      (void)manager->Replace(output, bn_training_update_outputs[index]);
+    }
+  }
+  return bn_training_update_outputs[0];
+}
 }  // namespace opt
 }  // namespace mindspore
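For orientation, the new Process pipeline is: CreateBNTrainingReduce builds a BNTrainingReduce over the data input, CreateMultipleOutputsOfAnfNode splits out its two outputs, and CreateBNTrainingUpdate consumes them together with scale, offset and the two moving statistics; afterwards, the tuple_getitem consumers of the original BatchNorm at indices 3 and 4 are rewired to the new node. A NumPy sketch of the presumed numeric semantics of the two Ascend ops follows; the exact kernel contracts are not part of this diff, so treat the formulas as an assumption:

import numpy as np

def bn_training_reduce(x):
    # Presumed semantics: per-channel sum and square-sum over N, H, W (NCHW).
    return x.sum(axis=(0, 2, 3)), (x * x).sum(axis=(0, 2, 3))

def bn_training_update(x, total, square_total, scale, offset,
                       moving_mean, moving_var, factor, eps=1e-5):
    n = x.size / x.shape[1]                  # elements per channel
    batch_mean = total / n
    batch_var = square_total / n - batch_mean * batch_mean
    c = lambda v: v.reshape(1, -1, 1, 1)     # broadcast helper for NCHW
    y = c(scale) * (x - c(batch_mean)) / np.sqrt(c(batch_var) + eps) + c(offset)
    # Moving statistics updated in place by the fused op (kAttrIsRef).
    new_mean = moving_mean - factor * (moving_mean - batch_mean)
    new_var = moving_var - factor * (moving_var - batch_var)
    # Output layout mirrors GetBNTrainingUpdateAbstractList:
    #   0: y, 1: new moving mean, 2: new moving variance, 3/4: batch stats.
    # Process replaces tuple_getitem(bn, 3) and tuple_getitem(bn, 4) with
    # outputs 3 and 4 of this node (kReplaceOutputIndex0/1).
    return y, new_mean, new_var, batch_mean, batch_var

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
total, square_total = bn_training_reduce(x)
outs = bn_training_update(x, total, square_total, np.ones(3), np.zeros(3),
                          np.zeros(3), np.ones(3), factor=0.1)
print([np.shape(o) for o in outs])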
mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h

@@ -19,6 +19,7 @@
 #include <vector>
 #include <memory>
 #include "pre_activate/common/optimizer.h"
+#include "utils/utils.h"

 namespace mindspore {
 namespace opt {

@@ -26,29 +27,37 @@ class FusedBatchNormFusion : public PatternProcessPass {
  public:
   explicit FusedBatchNormFusion(bool multigraph = true)
       : PatternProcessPass("fused_batch_norm_fusion", multigraph),
-        data_input_var0_(std::make_shared<Var>()),
-        data_input_var1_(std::make_shared<Var>()),
-        data_input_var2_(std::make_shared<Var>()),
-        variable_input_var0_(std::make_shared<Var>()),
-        variable_input_var1_(std::make_shared<Var>()),
-        constant_input_var0_(std::make_shared<Var>()),
-        constant_input_var1_(std::make_shared<Var>()) {}
+        data_input0_var_(std::make_shared<Var>()),
+        data_input1_var_(std::make_shared<Var>()),
+        data_input2_var_(std::make_shared<Var>()),
+        variable_input0_var_(std::make_shared<Var>()),
+        variable_input1_var_(std::make_shared<Var>()),
+        constant_input0_var_(std::make_shared<Var>()),
+        constant_input1_var_(std::make_shared<Var>()),
+        batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {}
   ~FusedBatchNormFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

  private:
-  abstract::AbstractTuplePtr CreateAbstractOfFusedBatchNorm(const EquivPtr &equiv, const AnfNodePtr &bn) const;
+  AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                    const EquivPtr &equiv) const;
+  void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
+                                 std::vector<AnfNodePtr> *bn_training_update_inputs) const;
+  void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
+                                       std::vector<AbstractBasePtr> *abstract_list) const;
+  AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
+                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
   ValuePtr GetFactor(const EquivPtr &equiv) const;
-  VarPtr data_input_var0_;
-  VarPtr data_input_var1_;
-  VarPtr data_input_var2_;
-  VarPtr variable_input_var0_;
-  VarPtr variable_input_var1_;
-  VarPtr constant_input_var0_;
-  VarPtr constant_input_var1_;
+  VarPtr data_input0_var_;
+  VarPtr data_input1_var_;
+  VarPtr data_input2_var_;
+  VarPtr variable_input0_var_;
+  VarPtr variable_input1_var_;
+  VarPtr constant_input0_var_;
+  VarPtr constant_input1_var_;
+  VarPtr batch_norm_var_;
 };
 }  // namespace opt
 }  // namespace mindspore
mindspore/nn/layer/normalization.py

@@ -62,6 +62,7 @@ class _BatchNorm(Cell):
         self.beta = Parameter(initializer(beta_init, num_features), name="beta", requires_grad=affine)
         self.group = check_int_positive(device_num_each_group)
+        self.is_global = False
         if self.group != 1:
             self.rank_id = get_rank()
             self.rank_size = get_group_size()

@@ -80,15 +81,18 @@ class _BatchNorm(Cell):
         self.cast = P.Cast()
         self.dtype = P.DType()
         self.reshape = P.Reshape()
+        self.is_ascend = context.get_context("device_target") == "Ascend"
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
             self.momentum = Tensor(1.0 - momentum, mstype.float32)
-            self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
         else:
             self.is_ge_backend = False
             self.momentum = 1.0 - momentum
+        if self.is_ge_backend or self.is_ascend:
+            self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
+        else:
             self.bn_train = P.FusedBatchNorm(mode=1, epsilon=self.eps, momentum=self.momentum)

@@ -140,24 +144,23 @@ class _BatchNorm(Cell):
     def construct(self, x):
         if self.training and self.use_batch_statistics:
-            if self.is_ge_backend:
-                if self.is_global:
-                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
-                    y = self._global_sync(x, axes, re_shape)
-                else:
-                    y, batch_mean, batch_var, _, _ = \
-                        self.bn_train(x, self.gamma, self.beta, None, None)
-                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
-                    temp_mean = self.mul_mean(mean_sub, self.momentum)
-                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
-                    temp_variance = self.mul_var(mean_sub2, self.momentum)
-                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+            if self.is_ge_backend and self.is_global:
+                axes, re_shape = _shape_infer(F.shape(x), self.num_features)
+                y = self._global_sync(x, axes, re_shape)
+            elif self.is_ge_backend or self.is_ascend:
+                y, batch_mean, batch_var, _, _ = \
+                    self.bn_train(x, self.gamma, self.beta, None, None)
+                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
+                temp_mean = self.mul_mean(mean_sub, self.momentum)
+                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
+                temp_variance = self.mul_var(mean_sub2, self.momentum)
+                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
             else:
                 y = self.bn_train(x, self.gamma,
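With this change, the Ascend graph path stops lowering to P.FusedBatchNorm and instead emits P.BatchNorm plus the explicit Sub/Mul/AssignSub update, which is exactly the subgraph the new backend pass re-fuses. Note that self.momentum in this class is already 1.0 - momentum, the per-step factor. A plain-Python check that the AssignSub form equals the textbook exponential moving average:

momentum = 0.9            # nn.BatchNorm(momentum=0.9): weight kept on the old value
factor = 1.0 - momentum   # what the layer stores in self.momentum and applies per step

moving, batch = 5.0, 3.0
ema = momentum * moving + (1.0 - momentum) * batch   # textbook EMA form
assign_sub = moving - factor * (moving - batch)      # Sub -> Mul -> AssignSub form
assert abs(ema - assign_sub) < 1e-12
print(ema, assign_sub)  # 4.8 4.8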
tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
0 → 100644 (new file)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"

namespace mindspore {
namespace opt {
class TestHWFusedBatchNormFusion : public BackendCommon {
 public:
  TestHWFusedBatchNormFusion() : get_py_fun_("gtest_input.pre_activate.fused_batch_norm_fusion_test", true) {}
  ~TestHWFusedBatchNormFusion() override = default;

  UT::PyFuncGraphFetcher get_py_fun_;
};

TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before");
  EXPECT_NE(g, nullptr);
  std::vector<int> shp_x{32, 64, 112, 112};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  std::vector<int> shp_y{64};
  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
  AbstractBasePtrList args_spec_list{x_abstract};
  for (size_t i = 0; i < 6; ++i) {
    args_spec_list.push_back(y_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::FusedBatchNormFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
}  // namespace opt
}  // namespace mindspore
\ No newline at end of file
tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py

@@ -24,7 +24,8 @@ make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
 depend = Primitive('depend')
 BatchNorm = P.BatchNorm()
-FusedBatchNorm = P.FusedBatchNorm()
+BNTrainingReduce = Primitive('BNTrainingReduce')
+BNTrainingUpdate = Primitive('BNTrainingUpdate')
 constant0 = Tensor(0.1, mstype.float32)
 constant1 = Tensor(0.1, mstype.float32)

@@ -40,7 +41,7 @@ class FnDict:
         return self.fnDict[name]

-def useless_test_fused_batch_norm_fusion(tag):
+def test_fused_batch_norm_fusion(tag):
     fns = FnDict()

     @fns

@@ -60,9 +61,11 @@ def useless_test_fused_batch_norm_fusion(tag):
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
-        fused_batch_norm = FusedBatchNorm(input0, input1, input2, var0, var1)
-        outputs = make_tuple(tuple_getitem(fused_batch_norm, 0), tuple_getitem(fused_batch_norm, 3),
-                             tuple_getitem(fused_batch_norm, 4))
+        bn_training_reduce = BNTrainingReduce(input0)
+        bn_training_update = BNTrainingUpdate(input0, tuple_getitem(bn_training_reduce, 0),
+                                              tuple_getitem(bn_training_reduce, 1), input1, input2, var0, var1)
+        outputs = make_tuple(tuple_getitem(bn_training_update, 0), tuple_getitem(bn_training_update, 3),
+                             tuple_getitem(bn_training_update, 4))
         output = tuple_getitem(outputs, 0)
         return make_tuple(output)
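The after graph spells out the expected result of the pass: BNTrainingReduce feeding BNTrainingUpdate, with only outputs 0, 3 and 4 consumed. The before graph is collapsed in this view; the following is a hypothetical reconstruction of what it must contain to match DefinePattern, written with plain tuples standing in for Primitive calls so it runs standalone (not the actual file contents):

def _prim(name):
    return lambda *args: (name,) + args

BatchNorm, Sub, Mul, AssignSub = map(_prim, ('BatchNorm', 'Sub', 'Mul', 'AssignSub'))
tuple_getitem, depend, make_tuple = map(_prim, ('tuple_getitem', 'depend', 'make_tuple'))
constant0 = constant1 = 0.1  # stands in for Tensor(0.1, mstype.float32)

# Hypothetical `before` body: BatchNorm plus the Sub/Mul/AssignSub
# moving-statistics updates tied to the output through a Depend chain.
def before(input0, input1, input2, input3, input4, var0, var1):
    batch_norm = BatchNorm(input0, input1, input2, input3, input4)
    sub0 = Sub(var0, tuple_getitem(batch_norm, 1))
    sub1 = Sub(var1, tuple_getitem(batch_norm, 2))
    mul0 = Mul(sub0, constant0)
    mul1 = Mul(sub1, constant1)
    depend0 = depend(tuple_getitem(batch_norm, 0), AssignSub(var0, mul0))
    depend1 = depend(depend0, AssignSub(var1, mul1))
    outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
    return tuple_getitem(outputs, 0)

print(before('x', 'scale', 'offset', 'mean', 'var', 'moving_mean', 'moving_var')[0])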