magicwindyyd / mindspore · Commit 2215e326 (fork of MindSpore / mindspore, in sync with the fork source)
Commit 2215e326
Authored on May 26, 2020 by mindspore-ci-bot; committed by Gitee on May 26, 2020
!1419 remove old buffer fusion pass
Merge pull request !1419 from Etone.Chan/Resnet50
Parents: 728afd23, 42d724d8
Showing 10 changed files with 51 additions and 956 deletions (+51 −956)
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc                 +2   -3
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc                 +0   -800
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h                  +0   -73
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h               +7   -0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc    +5   -1
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc   +3   -1
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h              +1   -1
mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.cc                              +0   -35
mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.h                               +0   -33
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc               +33  -9
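
With the old monolithic opt::BufferFusion pass removed, fusion is now driven by the per-pattern passes (sharing one FusionIdAllocator) followed by UbPatternFusion. The minimal sketch below shows the new registration order, adapted from the updated unit tests at the end of this diff; kg stands for the kernel graph being optimized, EltwiseFusionPass is just one of the available pattern passes, and the comment on UbPatternFusion's role is an assumption based on this merge request rather than something stated in the diff itself.

  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  // Pattern passes tag candidate node sets with a fusion id ...
  pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
  // ... and UbPatternFusion then performs the actual fused-kernel replacement (assumed division of work).
  pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);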
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc  (view file @ 2215e326)
@@ -62,7 +62,6 @@
#include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
#include "pre_activate/ascend/format_type/check_consistency.h"
#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"

@@ -317,14 +316,14 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
  optimizer->AddPassManager(other_pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
  // buffer fusion
  AscendBackendUBFusionOptimization(kernel_graph);
  if (save_graphs) {
    std::string file_path =
      save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_path, kernel_graph, true);
    DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id()));
  }
  // buffer fusion
  AscendBackendUBFusionOptimization(kernel_graph);
}

void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc  (deleted, 100644 → 0; view file @ 728afd23)
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
#include <vector>
#include <tuple>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include <deque>
#include <memory>
#include <string>
#include <algorithm>
#include <iterator>
#include "kernel/kernel_fusion.h"
#include "debug/anf_ir_dump.h"
#include "session/anf_runtime_algorithm.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#include "pre_activate/common/fusion_id_allocator.h"
namespace mindspore {
namespace opt {
namespace {
const int8_t MAX_PATTERN_SIZE = 7;
const int8_t MIN_PATTERN_SIZE = 2;
const int8_t ELTWISE_INPUT_SIZE = 2;
const int8_t ELTWISE_USE = 1;
const int8_t MULTI_ELTWISE_USE = 2;
const int8_t MAX_MULTI_ELTWISE_SIZE = 4;
const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3;
constexpr auto kOpAttrFusionId = "fusion_id";

#ifdef DEBUG
std::string GetFusionTypeName(const kernel::FusionType &type) {
  switch (type) {
    case kernel::FusionType::COMMREDUCE:
      return "COMMREDUCE";
    case kernel::FusionType::SEGMENT:
      return "SEGMENT";
    case kernel::FusionType::ELEMWISE:
      return "ELEMWISE";
    case kernel::FusionType::CONVLUTION:
      return "CONVLUTION";
    case kernel::FusionType::OPAQUE:
      return "OPAQUE";
    default:
      return "OPAQUE";
  }
}

void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
  MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id;
  for (auto &node : info.input_nodes) {
    MS_LOG(INFO) << "=== Input: " << node->DebugString();
  }
  for (auto &node : info.output_nodes) {
    MS_LOG(INFO) << "=== Output: " << node->DebugString();
  }
  for (auto &node : info.compute_nodes) {
    MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-("
                 << GetFusionTypeName(AnfAlgo::GetFusionType(node)) << ")";
  }
  MS_LOG(INFO) << "=== Dump FusionScopeInfo end";
}
#endif

bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(record);
  auto user_nodes = manager->node_users()[node];
  return (AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
          AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE &&
          (user_nodes.size() <= ELTWISE_USE || record->size() == 0));
}

// Common method to check for predecessors and successors in a fusion pattern
std::tuple<bool, CNodePtr> FindPredAndSuccEltWiseNodes(const int8_t &max_size, FuncGraphManager *manager,
                                                       std::unordered_set<AnfNodePtr> *visited_set,
                                                       std::deque<AnfNodePtr> *todo,
                                                       std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(visited_set);
  MS_EXCEPTION_IF_NULL(todo);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(node);

  CNodePtr new_node = node;
  if (new_node->inputs().size() < ELTWISE_INPUT_SIZE) {
    return std::make_tuple(false, new_node);
  }
  int8_t index = 1;
  auto &users = manager->node_users();
  while (CheckEltWiseNode(manager, record, new_node)) {
    (void)record->insert(new_node);
    (void)visited_set->insert(new_node);
    (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());

    auto cnode = new_node->input(1);
    MS_EXCEPTION_IF_NULL(cnode);
    if (!cnode->isa<CNode>()) {
      return std::make_tuple(false, new_node);
    }
    new_node = cnode->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(new_node);

    if (!AnfAlgo::IsRealKernel(new_node) || new_node->inputs().size() < ELTWISE_INPUT_SIZE ||
        users[(new_node)].size() >= MULTI_ELTWISE_USE || visited_set->find(new_node) != visited_set->end()) {
      return std::make_tuple(false, new_node);
    }

    if (index >= max_size) {
      break;
    }
    index++;
  }
  return std::make_tuple(true, new_node);
}

std::tuple<bool, CNodePtr> MatchGeneralPattern(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record,
                                               std::unordered_set<AnfNodePtr> *visited_set,
                                               std::deque<AnfNodePtr> *todo, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(visited_set);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(todo);
  CNodePtr new_node = node;
  auto &users = manager->node_users();
  if (users[(new_node)].size() >= MULTI_ELTWISE_USE) {
    return std::make_tuple(false, new_node);
  }

  (void)record->insert(node);
  (void)visited_set->insert(node);
  (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());

  if (node->inputs().size() < 2) {
    return std::make_tuple(false, new_node);
  }
  // only check the first real input, will check all
  auto cnode = node->input(1);
  MS_EXCEPTION_IF_NULL(cnode);
  if (!cnode->isa<CNode>()) {
    return std::make_tuple(false, new_node);
  }
  new_node = cnode->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_node);

  if (!AnfAlgo::IsRealKernel(new_node) || users[(new_node)].size() >= MULTI_ELTWISE_USE ||
      visited_set->find(new_node) != visited_set->end()) {
    return std::make_tuple(false, new_node);
  }
  return std::make_tuple(true, new_node);
}

CNodePtr FindFusionAnfNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *visited_set,
                           std::unordered_set<AnfNodePtr> *record, std::deque<AnfNodePtr> *todo,
                           const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(visited_set);
  MS_EXCEPTION_IF_NULL(record);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(todo);
  // find fusion pattern predecessor nodes
  auto ret = FindPredAndSuccEltWiseNodes(MAX_MULTI_ELTWISE_SIZE, manager, visited_set, todo, record, node);
  auto new_node = std::get<1>(ret);
  auto node_use_size = manager->node_users()[new_node].size();
  if (!std::get<0>(ret) || (record->size() > 1 && node_use_size > 1) || record->size() >= MAX_MULTI_ELTWISE_SIZE ||
      AnfAlgo::GetKernelType(new_node) != KernelType::TBE_KERNEL) {
    return new_node;
  }

  // key of fusion precessor
  auto node_fusion_type = AnfAlgo::GetFusionType(new_node);
  switch (node_fusion_type) {
    case kernel::FusionType::COMMREDUCE:
    case kernel::FusionType::SEGMENT:
      ret = MatchGeneralPattern(manager, record, visited_set, todo, new_node);
      new_node = std::get<1>(ret);
      if (!std::get<0>(ret)) {
        return new_node;
      }
      break;
    case kernel::FusionType::ELEMWISE:
      return new_node;
    // -fallthrough to default and return
    case kernel::FusionType::CONVLUTION:
      (void)record->insert(new_node);
    default:
      (void)visited_set->insert(new_node);
      if (new_node != nullptr) {
        (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end());
      }
      return new_node;
  }
  // find fusion pattern successor nodes
  ret = FindPredAndSuccEltWiseNodes(MAX_PURE_BUFFER_SUCC_SIZE, manager, visited_set, todo, record, new_node);
  return std::get<1>(ret);
}

CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::vector<AnfNodePtr> &outputs_list,
                        const std::vector<AnfNodePtr> &anf_nodes, session::KernelGraph *kernel_graph) {
  MS_LOG(DEBUG) << "Start Create FusionOp Kernel";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::string fusion_op_name = "FusionOp";
  for (auto node : anf_nodes) {
    fusion_op_name += '_' + AnfAlgo::GetCNodeName(node);
  }
  auto fusion_op = std::make_shared<Primitive>(fusion_op_name);
  MS_EXCEPTION_IF_NULL(fusion_op);

  std::vector<std::string> input_names;
  for (uint8_t i = 0; i < inputs_list.size(); i++) {
    input_names.emplace_back("input" + std::to_string(i));
  }
  std::vector<std::string> output_names;
  for (uint8_t i = 0; i < outputs_list.size(); i++) {
    output_names.emplace_back("output" + std::to_string(i));
  }

  ValuePtr input_names_v = MakeValue(input_names);
  ValuePtr output_names_v = MakeValue(output_names);
  fusion_op->set_attr("input_names", input_names_v);
  fusion_op->set_attr("output_names", output_names_v);
  std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
  auto value_node = std::make_shared<ValueNode>(fusion_op);
  (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node);
  auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list);
  if (buffer_fusion_kernel == nullptr) {
    MS_LOG(EXCEPTION) << "New FusionOp kernel failed!";
  }
  buffer_fusion_kernel->set_scope((anf_nodes.back())->scope());

  return buffer_fusion_kernel;
}

kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list,
                                                    const std::vector<AnfNodePtr> &outputs_list) {
  MS_LOG(DEBUG) << "Start Create Kernel Info";
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  // inputs format and data type
  std::vector<std::string> inputs_format;
  std::vector<TypeId> inputs_data_type;
  for (const auto &input : inputs_list) {
    auto real_input = AnfAlgo::VisitKernel(input, 0);
    inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
    inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
  }
  // outputs format and data type
  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_data_type;
  for (const auto &output : outputs_list) {
    if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
      auto tuple_getitem = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      outputs_format.push_back(AnfAlgo::GetOutputFormat(
        tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
        tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
    } else {
      outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0));
      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
    }
  }
  builder.SetInputsFormat(inputs_format);
  builder.SetInputsDeviceType(inputs_data_type);
  builder.SetOutputsFormat(outputs_format);
  builder.SetOutputsDeviceType(outputs_data_type);
  builder.SetKernelType(KernelType::TBE_KERNEL);
  return builder.Build();
}

AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph,
                              size_t output_index) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<AnfNodePtr> tuple_getitem_inputs_list;
  auto value = std::make_shared<ValueNode>(prim::kPrimTupleGetItem);
  MS_EXCEPTION_IF_NULL(value);
  auto idx = NewValueNode(SizeToInt(output_index));
  MS_EXCEPTION_IF_NULL(idx);
  int temp = SizeToInt(output_index);
  auto imm = std::make_shared<Int32Imm>(temp);
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  tuple_getitem_inputs_list.push_back(value);
  tuple_getitem_inputs_list.push_back(buffer_fusion_kernel);
  tuple_getitem_inputs_list.push_back(idx);
  auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list);
  MS_EXCEPTION_IF_NULL(tuple_item);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)},
                                      {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)},
                                      tuple_item.get());
  return tuple_item;
}

void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
                                        int32_t fusion_id, const AnfNodePtr &output_item,
                                        const AnfNodePtr &replace_item) {
  for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) {
    auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(),
                         output_item);
    if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) {
      MS_LOG(DEBUG) << "replace input of other pattern, id = " << id;
      *itr = replace_item;
    }
  }
}

void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
                    const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
  if (buffer_fusion_info.outputs_list.size() == 1) {
    // single output
    (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
    ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
                                       buffer_fusion_kernel);
  } else {
    // multiple output
    for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
      auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
      (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
      ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
                                         tuple_item);
    }
  }
}

void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
                                   std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  auto nodes = TopoSort(kernel_graph->get_return());
  for (auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
      auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
      (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
    }
  }
}

void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
                                 std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);

  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
    auto fusion_id = buffer_fusion_info.first;
    auto fusion_info = buffer_fusion_info.second;
    for (const auto &node : fusion_info.anf_nodes) {
      auto cnode = node->cast<CNodePtr>();
      for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
        auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
        if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
            fusion_info.anf_nodes.end()) {
          if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
                        (*buffer_fusion_infos)[fusion_id].inputs_list.end(),
                        cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
            (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx));
          }
        }
      }
    }
  }
}

bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  auto getitem1 = node1->cast<CNodePtr>();
  auto getitem2 = node2->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(getitem1);
  MS_EXCEPTION_IF_NULL(getitem2);
  auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
  auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
  return output_idx1 < output_idx2;
}

void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
                                  std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
    auto fusion_id = buffer_fusion_info.first;
    auto fusion_info = buffer_fusion_info.second;
    for (const auto &node : fusion_info.anf_nodes) {
      if (AnfAlgo::GetOutputTensorNum(node) == 1) {
        for (auto use_node : manager->node_users()[node]) {
          if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) ==
              fusion_info.anf_nodes.end()) {
            (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node);
            break;
          }
        }
      } else {
        int prev_idx = 0;
        std::vector<AnfNodePtr> tuple_getitem_nodes;
        std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
                       std::back_inserter(tuple_getitem_nodes),
                       [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
        std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
        for (auto getitem : tuple_getitem_nodes) {
          auto getitem_ptr = getitem->cast<CNodePtr>();
          auto input2 = getitem_ptr->input(2);
          auto output_idx = GetValue<int>(GetValueNode(input2));
          for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) {
            auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx));
            (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node);
          }
          prev_idx = output_idx + 1;
          for (auto item_use_node : manager->node_users()[getitem]) {
            if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) ==
                fusion_info.anf_nodes.end()) {
              (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem);
              break;
            }
          }
        }
      }
    }
  }
}

void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
                         const AnfNodePtr &fusion_kernel) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
    auto output = outputs_list[idx];
    if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
      auto real_output = AnfAlgo::VisitKernel(output, 0);
      auto output_cnode = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      auto input2 = output_cnode->input(2);
      auto output_idx = GetValue<int>(GetValueNode(input2));
      session::AnfWithOutIndex out_pair(real_output.first, output_idx);
      if (kernel_graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
        session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
        kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
      }
    } else {
      session::AnfWithOutIndex out_pair(output, 0);
      if (kernel_graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
        session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
        kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
      }
    }
  }
}
}  // namespace

void BufferFusion::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
  auto id = fusion_id_allocator.AllocateFusionId();
  for (auto node : record) {
    fusion_id_allocator.SetFusionId(node, id);
  }
}

void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                                     FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto conv = cnode->input(1);
  if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
    std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
    std::unordered_set<AnfNodePtr> record{cnode, conv};
    candidate_fusion->push_back(record);
    SetRecordFusionId(record);
  }
}

void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
                                     const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto getitem = relu_input->cast<CNodePtr>();
  auto bnupdate = getitem->input(1);
  if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
    std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
    for (auto out_getitem : manager->node_users()[bnupdate]) {
      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
      auto input2 = out_getitem_ptr->input(2);
      auto output_idx = GetValue<int>(GetValueNode(input2));
      output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
    }
    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
    std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
    candidate_fusion->push_back(record);
    SetRecordFusionId(record);
  }
}

void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
                                        const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto add = relu_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(add);
  auto tuple_getitem = add->input(1);
  if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
    auto getitem = tuple_getitem->cast<CNodePtr>();
    auto bnupdate = getitem->input(1);
    if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
      std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
      for (auto out_getitem : manager->node_users()[bnupdate]) {
        auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
        auto input2 = out_getitem_ptr->input(2);
        auto output_idx = GetValue<int>(GetValueNode(input2));
        output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
      }
      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
      std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
      candidate_fusion->push_back(record);
      SetRecordFusionId(record);
    }
  }
}

void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                                          FusedNodeRecord *candidate_fusion, bool is_order) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (is_order) {
    // DepthwiseConvolution--->Elemwise
    auto depthwise_conv = cnode->input(1);
    MS_EXCEPTION_IF_NULL(depthwise_conv);
    if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
      std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
      std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
      candidate_fusion->push_back(record);
      SetRecordFusionId(record);
    }
  } else {
    // Elemwise-->DepthwiseConvolution
    auto relu = cnode->input(1);
    MS_EXCEPTION_IF_NULL(relu);
    if (cnode->isa<CNode>() &&
        (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
      std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
      std::unordered_set<AnfNodePtr> record{cnode, relu};
      candidate_fusion->push_back(record);
      SetRecordFusionId(record);
    }
  }
}

void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
                                      const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu_input].size())};
  AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
  std::unordered_set<AnfNodePtr> record{cnode, relu_input};
  candidate_fusion->push_back(record);
  SetRecordFusionId(record);
}

void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
  for (auto &node : node_list) {
    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
      MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
    } else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
               AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
      auto eltwise_input = cnode->input(1);
      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
        MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
      }
      if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
        if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
          MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
        } else if (eltwise_input->isa<CNode>() &&
                   AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
          MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
        } else if (eltwise_input->isa<CNode>() &&
                   AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
          MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
        }
      }
    } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
      MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
    }
  }
}

void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph,
                                          FusedNodeRecord *candidate_fusion) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(candidate_fusion);

  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (return_node->inputs().size() <= 1) {
    return;
  }
  std::deque<AnfNodePtr> todo;
  todo.push_back(return_node->input(1));
  std::unordered_set<AnfNodePtr> visited_set;

  while (!todo.empty()) {
    auto node = todo.front();
    MS_EXCEPTION_IF_NULL(node);
    todo.pop_front();
    std::unordered_set<AnfNodePtr> record;
    if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) {
      continue;
    }
    // Only fuse real cnode
    if (!AnfAlgo::IsRealCNodeKernel(node)) {
      auto cnode = node->cast<CNodePtr>();
      if (cnode != nullptr) {
        (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
      }
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // cnode maybe updated
    cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode);
    if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) {
      candidate_fusion->push_back(record);
      SetRecordFusionId(record);
    }
    if (record.find(cnode) == record.end()) {
      todo.push_back(cnode);
    }
    // no node matched
    if (record.size() == 0) {
      (void)visited_set.insert(node);
    }
    (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  }
}

void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
                                       std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
  GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
  GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
    buffer_fusion_info.second.kernel_build_info =
      CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
  }
}

bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  bool change = false;
  std::unordered_map<int32_t, BufferFusionInfo_t> buffer_fusion_infos;
  buffer_fusion_infos.clear();
  GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos);

  std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos;
  for (auto &buffer_fusion_info : buffer_fusion_infos) {
    mindspore::kernel::FusionScopeInfo fusion_scope_info;
    fusion_scope_info.scope_id = buffer_fusion_info.first;
    fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list;
    fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes;
    fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list;
    fusion_scope_infos.push_back(fusion_scope_info);
#ifdef DEBUG
    DumpFusionScopeInfo(fusion_scope_info);
#endif
  }
  auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos);
  std::vector<int32_t> fusion_ids;
  for (auto &buffer_fusion_info : buffer_fusion_infos) {
    MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size()
                  << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size()
                  << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size();
    fusion_ids.push_back(buffer_fusion_info.first);
  }
  // Replace fusion op from return to head
  std::sort(fusion_ids.begin(), fusion_ids.end());
  for (auto &fusion_id : fusion_ids) {
    // Get kernel mod when supporting tbe
    if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) {
      MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
      continue;
    }
    change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
  }
  MS_LOG(DEBUG) << "End Buffer Fusion";
  return change;
}

bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto return_node = kernel_graph.get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  if (return_node->inputs().size() <= 1) {
    return false;
  }
  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
  FusedNodeRecord candidate_fusion;

  MatchOpNamePattern(kernel_graph, &candidate_fusion);
  MatchFusionTypePattern(kernel_graph, &candidate_fusion);

  if (candidate_fusion.empty()) {
    return false;
  }
  MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
  return true;
}

bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
                                   int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
                                   session::KernelGraph *kernel_graph) const {
  auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
  auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
                                      buffer_fusion_info.anf_nodes, kernel_graph);
  AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get());
  // Set abstract of fusion_op node
  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  for (const auto &out_node : buffer_fusion_info.outputs_list) {
    for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) {
      types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx));
      shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx));
    }
  }
  if (types.empty() || shapes.empty()) {
    MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty";
    return false;
  }
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
  AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
  SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
  ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
  return true;
}

bool BufferFusion::Run(const FuncGraphPtr &graph) {
  bool changed = false;
  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  fusion_id_allocator.Init();
  if (MatchBufferFusionPattern(*kernel_graph)) {
    changed = FuseBufferFusionPattern(kernel_graph.get());
  }
  // clear fusion_id attr
  for (auto &node : graph->nodes()) {
    if (node != nullptr && node->isa<CNode>()) {
      AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
    }
  }
  return changed;
}
}  // namespace opt
}  // namespace mindspore
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h  (deleted, 100644 → 0; view file @ 728afd23)
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
#include "pre_activate/common/fusion_id_allocator.h"
#include "device/kernel_info.h"
#include "kernel/kernel.h"
#include "session/kernel_graph.h"
namespace mindspore {
namespace opt {
struct BufferFusionInfo_t {
  std::vector<AnfNodePtr> anf_nodes;
  std::vector<AnfNodePtr> inputs_list;
  std::vector<AnfNodePtr> outputs_list;
  kernel::KernelBuildInfoPtr kernel_build_info;
};

using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;

class BufferFusion : public Pass {
 public:
  BufferFusion() : Pass("buffer_fusion") {}
  ~BufferFusion() override = default;
  bool Run(const FuncGraphPtr &graph) override;

 private:
  void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
  void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                         FusedNodeRecord *candidate_fusion);
  void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
                         FusedNodeRecord *candidate_fusion);
  void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
                            const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
  void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                              FusedNodeRecord *candidate_fusion, bool is_order);
  void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
                          const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
  void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
  void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);

  void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
                           std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
  bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
                       const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
  bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph);
  bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;

  FusionIdAllocator fusion_id_allocator;
};
}  // namespace opt
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h  (view file @ 2215e326)
@@ -37,6 +37,13 @@ const int8_t ELTWISE_USE = 1;
const int8_t MAX_ELTWISE_SIZE = 6;

using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;

struct BufferFusionInfo_t {
  std::vector<AnfNodePtr> anf_nodes;
  std::vector<AnfNodePtr> inputs_list;
  std::vector<AnfNodePtr> outputs_list;
  kernel::KernelBuildInfoPtr kernel_build_info;
};

class FusionBasePass : public Pass {
 public:
  FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator)
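
For orientation only: the concrete pattern passes registered in the updated tests (EltwiseFusionPass, ReduceEltwiseFusionPass, MatmulEltwiseFusionPass, ...) derive from this FusionBasePass and receive the shared FusionIdAllocator through the constructor shown in the hunk above. A hypothetical subclass might be declared roughly as follows; only the constructor signature is taken from the hunk, while the overridden matching hook is an assumption for illustration.

// Hypothetical sketch, not part of this commit: a pattern pass built on FusionBasePass.
class DemoEltwiseFusionPass : public FusionBasePass {
 public:
  explicit DemoEltwiseFusionPass(FusionIdAllocatorPtr id_allocator)
      : FusionBasePass("demo_eltwise_fusion_pass", id_allocator) {}
  ~DemoEltwiseFusionPass() override = default;

 private:
  // Assumed hook: collect candidate node sets into candidate_fusion, mirroring
  // ReduceEltwiseFusionPass::MatchSingleFusionPattern shown further down in this diff.
  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
};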
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc  (view file @ 2215e326)
@@ -15,6 +15,7 @@
*/
#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
#include <vector>
#include <algorithm>
#include <unordered_set>
#include <memory>
#include <string>
@@ -51,7 +52,9 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
  if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
      AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) {
    (void)record.insert(eltwise_input);
    auto previous_eltwise_input = cnode->input(1);
    auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(previous_input_cnode);
    auto previous_eltwise_input = previous_input_cnode->input(1);
    auto previous_size = record.size();
    while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) {
      (void)record.insert(previous_eltwise_input);
@@ -71,6 +74,7 @@ void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
                                                       FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
  std::reverse(node_list.begin(), node_list.end());
  for (auto &node : node_list) {
    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc  (view file @ 2215e326)
@@ -51,7 +51,9 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
  if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
      AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) {
    (void)record.insert(eltwise_input);
    auto previous_eltwise_input = cnode->input(1);
    auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(previous_input_cnode);
    auto previous_eltwise_input = previous_input_cnode->input(1);
    auto previous_size = record.size();
    while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) {
      (void)record.insert(previous_eltwise_input);
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h  (view file @ 2215e326)
@@ -19,13 +19,13 @@
#include <unordered_set>
#include <vector>
#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
#include "pre_activate/common/fusion_id_allocator.h"
#include "device/kernel_info.h"
#include "kernel/kernel.h"
#include "session/kernel_graph.h"
#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
namespace mindspore {
namespace opt {
mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.cc  (deleted, 100644 → 0; view file @ 728afd23)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/pass/remove_nop_nodes.h"
#include "common/utils.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
const AnfNodePtr RemoveNopNodes::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
  if (node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsNopNode(node)) {
    return nullptr;
  }
  return cnode->input(1);
}
}  // namespace opt
}  // namespace mindspore
mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.h  (deleted, 100644 → 0; view file @ 728afd23)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_
#include "ir/anf.h"
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveNopNodes : public PatternProcessPass {
 public:
  explicit RemoveNopNodes(bool multigraph = true) : PatternProcessPass("remove_nop_nodes", multigraph) {}
  ~RemoveNopNodes() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc  (view file @ 2215e326)
@@ -21,7 +21,19 @@
#include "device/kernel_info.h"
#include "pre_activate/common/optimizer.h"
#include "session/anf_runtime_algorithm.h"
#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
namespace mindspore {
namespace opt {
@@ -79,10 +91,13 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
  cast->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get());

  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
  pm->AddPass(buffer_fusion_pass);
  pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
  pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -168,10 +183,13 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
  biasadd->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), biasadd.get());

  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
  pm->AddPass(buffer_fusion_pass);
  pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
  pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -255,10 +273,13 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
  biasaddgrad->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), biasaddgrad.get());

  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
  pm->AddPass(buffer_fusion_pass);
  pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
  pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -321,10 +342,13 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
  cast->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get());

  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
  fusion_id_allocator->Init();
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
  pm->AddPass(buffer_fusion_pass);
  pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
  pm->AddPass(std::make_shared<UbPatternFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);