Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
fb5cfe31
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fb5cfe31
编写于
4月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!430 add buffer fusion bnupdate eltwise pass
Merge pull request !430 from Etone.Chan/bufferfusion
上级
e00f1736
d2727d05
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
182 addition
and
1388 deletion
+182
-1388
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
+1
-2
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
.../ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
+178
-85
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
...e/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
+3
-3
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc
...p/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc
+0
-1298
未找到文件。
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
浏览文件 @
fb5cfe31
...
...
@@ -722,8 +722,7 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode
std
::
vector
<
nlohmann
::
json
>
*
output_desc_list
)
{
auto
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
cnode
);
if
(
AnfAlgo
::
HasNodeAttr
(
kAttrOutputUsedNum
,
cnode
))
{
// wait anther pr: auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
auto
output_used_nums
=
{
SizeToInt
(
AnfAlgo
::
GetNodeAttr
<
std
::
size_t
>
(
cnode
,
kAttrOutputUsedNum
))};
auto
output_used_nums
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
int
>>
(
cnode
,
kAttrOutputUsedNum
);
MS_LOG
(
INFO
)
<<
"This node's output has been reused, node name: "
<<
cnode
->
fullname_with_scope
();
if
(
output_used_nums
.
size
()
!=
output_size
)
{
MS_LOG
(
INFO
)
<<
"Fusion error: output tenor num("
<<
output_size
<<
")"
...
...
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
浏览文件 @
fb5cfe31
...
...
@@ -17,6 +17,7 @@
#include <vector>
#include <tuple>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include <deque>
...
...
@@ -282,11 +283,17 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
// outputs format and data type
std
::
vector
<
std
::
string
>
outputs_format
;
std
::
vector
<
TypeId
>
outputs_data_type
;
for
(
size_t
index
=
0
;
index
<
outputs_list
.
size
();
++
index
)
{
for
(
size_t
idx
=
0
;
idx
<
AnfAlgo
::
GetOutputTensorNum
(
outputs_list
[
index
]);
++
idx
)
{
auto
kernel_with_index
=
AnfAlgo
::
VisitKernel
(
outputs_list
[
index
],
idx
);
outputs_format
.
push_back
(
AnfAlgo
::
GetOutputFormat
(
kernel_with_index
.
first
,
kernel_with_index
.
second
));
outputs_data_type
.
push_back
(
AnfAlgo
::
GetOutputDeviceDataType
(
kernel_with_index
.
first
,
kernel_with_index
.
second
));
for
(
const
auto
&
output
:
outputs_list
)
{
if
(
AnfAlgo
::
GetCNodeName
(
output
)
==
prim
::
kPrimTupleGetItem
->
name
())
{
auto
tuple_getitem
=
output
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
tuple_getitem
);
outputs_format
.
push_back
(
AnfAlgo
::
GetOutputFormat
(
tuple_getitem
->
input
(
1
),
IntToSize
(
GetValue
<
int
>
(
GetValueNode
(
tuple_getitem
->
input
(
2
))))));
outputs_data_type
.
push_back
(
AnfAlgo
::
GetOutputDeviceDataType
(
tuple_getitem
->
input
(
1
),
IntToSize
(
GetValue
<
int
>
(
GetValueNode
(
tuple_getitem
->
input
(
2
))))));
}
else
{
outputs_format
.
push_back
(
AnfAlgo
::
GetOutputFormat
(
output
,
0
));
outputs_data_type
.
push_back
(
AnfAlgo
::
GetOutputDeviceDataType
(
output
,
0
));
}
}
builder
.
SetInputsFormat
(
inputs_format
);
...
...
@@ -320,32 +327,35 @@ AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::K
return
tuple_item
;
}
void
ReplaceOldNode
(
const
std
::
vector
<
AnfNodePtr
>
&
outputs_list
,
const
AnfNodePtr
&
buffer_fusion_kernel
,
session
::
KernelGraph
*
kernel_graph
)
{
void
ReplaceInputNodeInOtherFusionScope
(
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
,
int32_t
fusion_id
,
const
AnfNodePtr
&
output_item
,
const
AnfNodePtr
&
replace_item
)
{
for
(
int32_t
id
=
fusion_id
+
1
;
id
<=
SizeToInt
(
buffer_fusion_infos
->
size
());
++
id
)
{
auto
itr
=
std
::
find
((
*
buffer_fusion_infos
)[
id
].
inputs_list
.
begin
(),
(
*
buffer_fusion_infos
)[
id
].
inputs_list
.
end
(),
output_item
);
if
(
itr
!=
(
*
buffer_fusion_infos
)[
id
].
inputs_list
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"replace input of other pattern, id = "
<<
id
;
*
itr
=
replace_item
;
}
}
}
void
ReplaceOldNode
(
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
,
int32_t
fusion_id
,
const
AnfNodePtr
&
buffer_fusion_kernel
,
session
::
KernelGraph
*
kernel_graph
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
manager
=
kernel_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
if
(
outputs_list
.
size
()
==
1
)
{
// single output
(
void
)
manager
->
Replace
(
outputs_list
[
0
],
buffer_fusion_kernel
);
auto
buffer_fusion_info
=
(
*
buffer_fusion_infos
)[
fusion_id
];
if
(
buffer_fusion_info
.
outputs_list
.
size
()
==
1
)
{
// single output
(
void
)
manager
->
Replace
(
buffer_fusion_info
.
outputs_list
[
0
],
buffer_fusion_kernel
);
ReplaceInputNodeInOtherFusionScope
(
buffer_fusion_infos
,
fusion_id
,
buffer_fusion_info
.
outputs_list
[
0
],
buffer_fusion_kernel
);
}
else
{
// multiple output
size_t
real_idx
=
0
;
for
(
size_t
index
=
0
;
index
<
outputs_list
.
size
();
++
index
)
{
if
(
AnfAlgo
::
GetOutputTensorNum
(
outputs_list
[
index
])
==
1
)
{
auto
tuple_item
=
CreateTupleGetItem
(
buffer_fusion_kernel
,
kernel_graph
,
real_idx
++
);
(
void
)
manager
->
Replace
(
outputs_list
[
index
],
tuple_item
);
}
else
{
std
::
vector
<
AnfNodePtr
>
make_tuple_inputs
;
AbstractBasePtrList
abstract_list
;
make_tuple_inputs
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
for
(
size_t
idx
=
0
;
idx
<
AnfAlgo
::
GetOutputTensorNum
(
outputs_list
[
index
]);
++
idx
)
{
auto
tuple_item
=
CreateTupleGetItem
(
buffer_fusion_kernel
,
kernel_graph
,
real_idx
++
);
abstract_list
.
push_back
(
tuple_item
->
abstract
());
make_tuple_inputs
.
push_back
(
tuple_item
);
}
AnfNodePtr
make_tuple
=
kernel_graph
->
NewCNode
(
make_tuple_inputs
);
make_tuple
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
abstract_list
));
(
void
)
manager
->
Replace
(
outputs_list
[
index
],
make_tuple
);
}
for
(
size_t
index
=
0
;
index
<
buffer_fusion_info
.
outputs_list
.
size
();
++
index
)
{
auto
tuple_item
=
CreateTupleGetItem
(
buffer_fusion_kernel
,
kernel_graph
,
index
);
(
void
)
manager
->
Replace
(
buffer_fusion_info
.
outputs_list
[
index
],
tuple_item
);
ReplaceInputNodeInOtherFusionScope
(
buffer_fusion_infos
,
fusion_id
,
buffer_fusion_info
.
outputs_list
[
index
],
tuple_item
);
}
}
}
...
...
@@ -406,38 +416,67 @@ void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id,
}
}
void
InsertNode
(
const
AnfNodePtr
&
node
,
std
::
vector
<
AnfNodePtr
>
*
list
)
{
MS_EXCEPTION_IF_NULL
(
list
);
if
(
std
::
find
(
list
->
begin
(),
list
->
end
(),
node
)
==
list
->
end
())
{
(
void
)
list
->
insert
(
list
->
end
(),
node
);
void
GetFusionScopeComputeNodeList
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
auto
nodes
=
TopoSort
(
kernel_graph
->
get_return
());
for
(
auto
&
node
:
nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
AnfAlgo
::
IsRealCNodeKernel
(
node
)
&&
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
node
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
node
,
kOpAttrFusionId
);
(
*
buffer_fusion_infos
)[
fusion_id
].
anf_nodes
.
push_back
(
node
);
}
}
}
void
CheckCurrentNodeIsOutput
(
const
CNodePtr
&
node
,
const
int32_t
&
cur_fusion_id
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
node
);
void
GetFusionScopeOutputNodeList
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
for
(
auto
&
input
:
node
->
inputs
())
{
MS_EXCEPTION_IF_NULL
(
input
);
if
(
AnfAlgo
::
IsRealCNodeKernel
(
input
)
&&
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
input
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
input
,
kOpAttrFusionId
);
if
(
buffer_fusion_infos
->
find
(
fusion_id
)
==
buffer_fusion_infos
->
end
())
{
BufferFusionInfo_t
buffer_fusion_info
;
(
*
buffer_fusion_infos
)[
fusion_id
]
=
buffer_fusion_info
;
}
if
(
fusion_id
!=
cur_fusion_id
)
{
InsertNode
(
input
,
&
((
*
buffer_fusion_infos
)[
fusion_id
].
outputs_list
));
}
}
else
if
(
input
->
isa
<
CNode
>
())
{
for
(
auto
&
input_in
:
input
->
cast
<
CNodePtr
>
()
->
inputs
())
{
if
(
AnfAlgo
::
IsRealCNodeKernel
(
input_in
)
&&
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
input_in
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
input_in
,
kOpAttrFusionId
);
if
(
buffer_fusion_infos
->
find
(
fusion_id
)
==
buffer_fusion_infos
->
end
())
{
BufferFusionInfo_t
buffer_fusion_info
;
(
*
buffer_fusion_infos
)[
fusion_id
]
=
buffer_fusion_info
;
auto
manager
=
kernel_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
for
(
auto
&
buffer_fusion_info
:
*
buffer_fusion_infos
)
{
auto
fusion_id
=
buffer_fusion_info
.
first
;
auto
fusion_info
=
buffer_fusion_info
.
second
;
for
(
const
auto
&
node
:
fusion_info
.
anf_nodes
)
{
if
(
AnfAlgo
::
GetOutputTensorNum
(
node
)
==
1
)
{
for
(
auto
use_node
:
manager
->
node_users
()[
node
])
{
if
(
std
::
find
(
fusion_info
.
anf_nodes
.
begin
(),
fusion_info
.
anf_nodes
.
end
(),
use_node
.
first
)
==
fusion_info
.
anf_nodes
.
end
())
{
(
*
buffer_fusion_infos
)[
fusion_id
].
outputs_list
.
push_back
(
node
);
break
;
}
}
}
else
{
int
prev_idx
=
0
;
std
::
vector
<
AnfNodePtr
>
tuple_getitem_nodes
;
std
::
transform
(
manager
->
node_users
()[
node
].
begin
(),
manager
->
node_users
()[
node
].
end
(),
std
::
back_inserter
(
tuple_getitem_nodes
),
[](
const
std
::
pair
<
AnfNodePtr
,
int
>
&
use_node
)
{
return
use_node
.
first
;
});
std
::
sort
(
tuple_getitem_nodes
.
begin
(),
tuple_getitem_nodes
.
end
(),
[](
const
AnfNodePtr
&
node1
,
const
AnfNodePtr
&
node2
)
{
auto
getitem1
=
node1
->
cast
<
CNodePtr
>
();
auto
getitem2
=
node2
->
cast
<
CNodePtr
>
();
auto
output_idx1
=
GetValue
<
int
>
(
GetValueNode
(
getitem1
->
input
(
2
)));
auto
output_idx2
=
GetValue
<
int
>
(
GetValueNode
(
getitem2
->
input
(
2
)));
return
output_idx1
<
output_idx2
;
});
for
(
auto
getitem
:
tuple_getitem_nodes
)
{
auto
getitem_ptr
=
getitem
->
cast
<
CNodePtr
>
();
auto
input2
=
getitem_ptr
->
input
(
2
);
auto
output_idx
=
GetValue
<
int
>
(
GetValueNode
(
input2
));
for
(
int
stub_idx
=
prev_idx
;
stub_idx
<
output_idx
;
++
stub_idx
)
{
auto
stub_node
=
CreateTupleGetItem
(
node
,
kernel_graph
,
IntToSize
(
stub_idx
));
(
*
buffer_fusion_infos
)[
fusion_id
].
outputs_list
.
push_back
(
stub_node
);
}
if
(
fusion_id
!=
cur_fusion_id
)
{
InsertNode
(
input_in
,
&
((
*
buffer_fusion_infos
)[
fusion_id
].
outputs_list
));
prev_idx
=
output_idx
+
1
;
for
(
auto
item_use_node
:
manager
->
node_users
()[
getitem
])
{
if
(
std
::
find
(
fusion_info
.
anf_nodes
.
begin
(),
fusion_info
.
anf_nodes
.
end
(),
item_use_node
.
first
)
==
fusion_info
.
anf_nodes
.
end
())
{
(
*
buffer_fusion_infos
)[
fusion_id
].
outputs_list
.
push_back
(
getitem
);
break
;
}
}
}
}
...
...
@@ -445,15 +484,72 @@ void CheckCurrentNodeIsOutput(const CNodePtr &node, const int32_t &cur_fusion_id
}
}
void
GetFusionScopeNodeList
(
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
auto
nodes
=
TopoSort
(
kernel_graph
.
get_return
());
for
(
auto
&
node
:
nodes
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
AnfAlgo
::
IsRealCNodeKernel
(
node
)
&&
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
node
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
node
,
kOpAttrFusionId
);
(
*
buffer_fusion_infos
)[
fusion_id
].
anf_nodes
.
push_back
(
node
);
void
MatchConvBnreduce
(
const
CNodePtr
&
cnode
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
conv
=
cnode
->
input
(
1
);
if
(
conv
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
conv
)
==
prim
::
kPrimConv2D
->
name
())
{
std
::
vector
<
int
>
output_used_num
{
SizeToInt
(
manager
->
node_users
()[
conv
].
size
())};
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
conv
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
conv
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
());
}
}
void
MatchBnupdateRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
getitem
=
relu_input
->
cast
<
CNodePtr
>
();
auto
bnupdate
=
getitem
->
input
(
1
);
if
(
bnupdate
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
bnupdate
)
==
kBNTrainingUpdateOpName
)
{
std
::
vector
<
int
>
output_used_num
(
AnfAlgo
::
GetOutputTensorNum
(
bnupdate
),
0
);
for
(
auto
out_getitem
:
manager
->
node_users
()[
bnupdate
])
{
auto
out_getitem_ptr
=
out_getitem
.
first
->
cast
<
CNodePtr
>
();
auto
input2
=
out_getitem_ptr
->
input
(
2
);
auto
output_idx
=
GetValue
<
int
>
(
GetValueNode
(
input2
));
output_used_num
[
output_idx
]
=
SizeToInt
(
manager
->
node_users
()[
out_getitem
.
first
].
size
());
}
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
bnupdate
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
bnupdate
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
());
}
}
void
MatchBnupdateAddRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
add
=
relu_input
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
add
);
auto
tuple_getitem
=
add
->
input
(
1
);
if
(
tuple_getitem
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
tuple_getitem
)
==
prim
::
kPrimTupleGetItem
->
name
())
{
auto
getitem
=
tuple_getitem
->
cast
<
CNodePtr
>
();
auto
bnupdate
=
getitem
->
input
(
1
);
if
(
bnupdate
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
bnupdate
)
==
kBNTrainingUpdateOpName
)
{
std
::
vector
<
int
>
output_used_num
(
AnfAlgo
::
GetOutputTensorNum
(
bnupdate
),
0
);
for
(
auto
out_getitem
:
manager
->
node_users
()[
bnupdate
])
{
auto
out_getitem_ptr
=
out_getitem
.
first
->
cast
<
CNodePtr
>
();
auto
input2
=
out_getitem_ptr
->
input
(
2
);
auto
output_idx
=
GetValue
<
int
>
(
GetValueNode
(
input2
));
output_used_num
[
output_idx
]
=
SizeToInt
(
manager
->
node_users
()[
out_getitem
.
first
].
size
());
}
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
bnupdate
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
relu_input
,
bnupdate
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
());
}
}
}
...
...
@@ -470,15 +566,14 @@ void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kBNTrainingReduceOpName
)
{
auto
conv
=
cnode
->
input
(
1
);
if
(
conv
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
conv
)
==
prim
::
kPrimConv2D
->
name
())
{
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
&
users
=
manager
->
node_users
();
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
users
[
conv
].
size
()),
conv
);
std
::
unordered_set
<
AnfNodePtr
>
record
({
cnode
,
conv
});
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
());
MatchConvBnreduce
(
cnode
,
kernel_graph
,
fused_set
,
candidate_fusion
);
}
else
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kReluV2OpName
||
AnfAlgo
::
GetCNodeName
(
cnode
)
==
prim
::
kPrimRelu
->
name
())
{
auto
relu_input
=
cnode
->
input
(
1
);
if
(
relu_input
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
relu_input
)
==
prim
::
kPrimTensorAdd
->
name
())
{
MatchBnupdateAddRelu
(
cnode
,
relu_input
,
kernel_graph
,
fused_set
,
candidate_fusion
);
}
else
if
(
relu_input
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
relu_input
)
==
prim
::
kPrimTupleGetItem
->
name
())
{
MatchBnupdateRelu
(
cnode
,
relu_input
,
kernel_graph
,
fused_set
,
candidate_fusion
);
}
}
}
...
...
@@ -536,27 +631,23 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
}
}
// namespace
void
BufferFusion
::
GetBufferFusionInfo
(
const
session
::
KernelGraph
&
kernel_graph
,
void
BufferFusion
::
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
kernel_graph
.
get_return
());
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
kernel_graph
->
get_return
());
for
(
auto
&
node
:
node_list
)
{
if
(
!
AnfAlgo
::
IsRealCNodeKernel
(
node
))
{
continue
;
}
int32_t
cur_fusion_id
=
-
1
;
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
cnode
))
{
cur_fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
cnode
,
kOpAttrFusionId
);
auto
cur_fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
cnode
,
kOpAttrFusionId
);
CheckCurrentNodeIsInput
(
cnode
,
cur_fusion_id
,
buffer_fusion_infos
);
}
// Check if current node is output
CheckCurrentNodeIsOutput
(
cnode
,
cur_fusion_id
,
buffer_fusion_infos
);
}
GetFusionScopeNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScopeComputeNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScope
Output
NodeList
(
kernel_graph
,
buffer_fusion_infos
);
for
(
auto
&
buffer_fusion_info
:
*
buffer_fusion_infos
)
{
buffer_fusion_info
.
second
.
kernel_build_info
=
CreateFusionOpKernelInfo
(
buffer_fusion_info
.
second
.
inputs_list_in
,
buffer_fusion_info
.
second
.
inputs_list
,
...
...
@@ -569,7 +660,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
bool
change
=
false
;
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
buffer_fusion_infos
;
buffer_fusion_infos
.
clear
();
GetBufferFusionInfo
(
*
kernel_graph
,
&
buffer_fusion_infos
);
GetBufferFusionInfo
(
kernel_graph
,
&
buffer_fusion_infos
);
std
::
vector
<
mindspore
::
kernel
::
FusionScopeInfo
>
fusion_scope_infos
;
for
(
auto
&
buffer_fusion_info
:
buffer_fusion_infos
)
{
...
...
@@ -600,7 +691,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
MS_LOG
(
DEBUG
)
<<
"fusion id: "
<<
fusion_id
<<
", fusion op compiling failed"
;
continue
;
}
change
=
ReplaceFusionOp
(
buffer_fusion_infos
[
fusion_id
]
,
kernel_mods
[
fusion_id
],
kernel_graph
);
change
=
ReplaceFusionOp
(
&
buffer_fusion_infos
,
fusion_id
,
kernel_mods
[
fusion_id
],
kernel_graph
);
}
MS_LOG
(
DEBUG
)
<<
"End Buffer Fusion"
;
return
change
;
...
...
@@ -630,8 +721,10 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
return
true
;
}
bool
BufferFusion
::
ReplaceFusionOp
(
const
BufferFusionInfo_t
&
buffer_fusion_info
,
const
kernel
::
KernelModPtr
&
kernel_ptr
,
bool
BufferFusion
::
ReplaceFusionOp
(
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
,
int32_t
fusion_id
,
const
kernel
::
KernelModPtr
&
kernel_ptr
,
session
::
KernelGraph
*
kernel_graph
)
const
{
auto
buffer_fusion_info
=
(
*
buffer_fusion_infos
)[
fusion_id
];
auto
buffer_fusion
=
CreateFusionOp
(
buffer_fusion_info
.
inputs_list
,
buffer_fusion_info
.
outputs_list
,
buffer_fusion_info
.
anf_nodes
,
kernel_graph
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
buffer_fusion_info
.
kernel_build_info
,
buffer_fusion
.
get
());
...
...
@@ -651,7 +744,7 @@ bool BufferFusion::ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info,
AnfAlgo
::
SetOutputInferTypeAndShape
(
types
,
shapes
,
buffer_fusion
.
get
());
AnfAlgo
::
SetKernelMod
(
kernel_ptr
,
buffer_fusion
.
get
());
// replace node
ReplaceOldNode
(
buffer_fusion_info
.
outputs_list
,
buffer_fusion
,
kernel_graph
);
ReplaceOldNode
(
buffer_fusion_info
s
,
fusion_id
,
buffer_fusion
,
kernel_graph
);
return
true
;
}
...
...
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
浏览文件 @
fb5cfe31
...
...
@@ -44,10 +44,10 @@ class BufferFusion : public Pass {
bool
Run
(
const
FuncGraphPtr
&
graph
)
override
;
private:
void
GetBufferFusionInfo
(
const
session
::
KernelGraph
&
kernel_graph
,
void
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
;
bool
ReplaceFusionOp
(
const
BufferFusionInfo_t
&
buffer_fusion_info
,
const
kernel
::
KernelModPtr
&
kernel_ptr
,
session
::
KernelGraph
*
kernel_graph
)
const
;
bool
ReplaceFusionOp
(
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
,
int32_t
fusion_id
,
const
kernel
::
KernelModPtr
&
kernel_ptr
,
session
::
KernelGraph
*
kernel_graph
)
const
;
bool
MatchBufferFusionPattern
(
const
session
::
KernelGraph
&
kernel_graph
)
const
;
bool
FuseBufferFusionPattern
(
session
::
KernelGraph
*
kernel_graph
)
const
;
};
...
...
tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc
已删除
100644 → 0
浏览文件 @
e00f1736
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录