Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
38e0d98e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
38e0d98e
编写于
5月 07, 2020
作者:
E
etone-chan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor fusion id implement of buffer fusion
上级
67013077
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
135 addition
and
48 deletion
+135
-48
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
.../ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
+32
-47
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
...e/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
+14
-1
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc
+46
-0
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h
+42
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
浏览文件 @
38e0d98e
...
...
@@ -32,6 +32,7 @@
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#include "pre_activate/common/fusion_id_allocator.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -79,20 +80,6 @@ void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
}
#endif
void
SetAnfNodeFusionId
(
const
FusedNodeRecord
&
record_node
)
{
MS_LOG
(
DEBUG
)
<<
"Size of opt vector to be fused is "
<<
record_node
.
size
();
int32_t
id
=
1
;
for
(
auto
&
record
:
record_node
)
{
MS_LOG
(
DEBUG
)
<<
"No"
<<
id
<<
", opt vector to be fused contain "
<<
record
.
size
()
<<
" opt."
;
for
(
const
auto
&
candidate
:
record
)
{
ValuePtr
fusion_id_v
=
MakeValue
(
id
);
AnfAlgo
::
SetNodeAttr
(
kOpAttrFusionId
,
fusion_id_v
,
candidate
);
MS_LOG
(
DEBUG
)
<<
"No "
<<
id
<<
": "
<<
candidate
->
DebugString
();
}
id
++
;
}
}
bool
CheckEltWiseNode
(
FuncGraphManager
*
manager
,
std
::
unordered_set
<
AnfNodePtr
>
*
record
,
const
CNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
manager
);
MS_EXCEPTION_IF_NULL
(
record
);
...
...
@@ -482,11 +469,18 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
}
}
}
}
// namespace
void
MatchConvBnreduce
(
const
CNodePtr
&
cnode
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
void
BufferFusion
::
SetRecordFusionId
(
const
std
::
unordered_set
<
AnfNodePtr
>
&
record
)
{
auto
id
=
fusion_id_allocator
.
AllocateFusionId
();
for
(
auto
node
:
record
)
{
fusion_id_allocator
.
SetFusionId
(
node
,
id
);
}
}
void
BufferFusion
::
MatchConvBnreduce
(
const
CNodePtr
&
cnode
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
...
...
@@ -496,14 +490,13 @@ void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
conv
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
conv
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
()
);
SetRecordFusionId
(
record
);
}
}
void
MatchBnupdateRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
void
BufferFusion
::
MatchBnupdateRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
...
...
@@ -520,14 +513,13 @@ void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, cons
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
bnupdate
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
bnupdate
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
()
);
SetRecordFusionId
(
record
);
}
}
void
MatchBnupdateAddRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
void
BufferFusion
::
MatchBnupdateAddRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
...
...
@@ -548,41 +540,37 @@ void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, c
AnfAlgo
::
SetNodeAttr
(
kAttrOutputUsedNum
,
MakeValue
(
output_used_num
),
bnupdate
);
std
::
unordered_set
<
AnfNodePtr
>
record
{
cnode
,
relu_input
,
bnupdate
};
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
()
);
SetRecordFusionId
(
record
);
}
}
}
void
MatchOpNamePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
fused_set
);
void
BufferFusion
::
MatchOpNamePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
)
{
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
kernel_graph
.
get_return
());
for
(
auto
&
node
:
node_list
)
{
if
(
!
AnfAlgo
::
IsRealCNodeKernel
(
node
)
||
fus
ed_set
->
find
(
node
)
!=
fused_set
->
end
(
))
{
if
(
!
AnfAlgo
::
IsRealCNodeKernel
(
node
)
||
fus
ion_id_allocator
.
HasFusionIdAttr
(
node
))
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kBNTrainingReduceOpName
)
{
MatchConvBnreduce
(
cnode
,
kernel_graph
,
fused_set
,
candidate_fusion
);
MatchConvBnreduce
(
cnode
,
kernel_graph
,
candidate_fusion
);
}
else
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
kReluV2OpName
||
AnfAlgo
::
GetCNodeName
(
cnode
)
==
prim
::
kPrimRelu
->
name
())
{
auto
relu_input
=
cnode
->
input
(
1
);
if
(
relu_input
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
relu_input
)
==
prim
::
kPrimTensorAdd
->
name
())
{
MatchBnupdateAddRelu
(
cnode
,
relu_input
,
kernel_graph
,
fused_set
,
candidate_fusion
);
MatchBnupdateAddRelu
(
cnode
,
relu_input
,
kernel_graph
,
candidate_fusion
);
}
else
if
(
relu_input
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
relu_input
)
==
prim
::
kPrimTupleGetItem
->
name
())
{
MatchBnupdateRelu
(
cnode
,
relu_input
,
kernel_graph
,
fused_set
,
candidate_fusion
);
MatchBnupdateRelu
(
cnode
,
relu_input
,
kernel_graph
,
candidate_fusion
);
}
}
}
}
void
MatchFusionTypePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
std
::
unordered_set
<
AnfNodePtr
>
*
fused_set
,
FusedNodeRecord
*
candidate_fusion
)
{
void
BufferFusion
::
MatchFusionTypePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
)
{
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
MS_EXCEPTION_IF_NULL
(
fused_set
);
MS_EXCEPTION_IF_NULL
(
candidate_fusion
);
auto
return_node
=
kernel_graph
.
get_return
();
...
...
@@ -599,7 +587,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
MS_EXCEPTION_IF_NULL
(
node
);
todo
.
pop_front
();
std
::
unordered_set
<
AnfNodePtr
>
record
;
if
(
visited_set
.
find
(
node
)
!=
visited_set
.
end
()
||
fus
ed_set
->
find
(
node
)
!=
fused_set
->
end
(
))
{
if
(
visited_set
.
find
(
node
)
!=
visited_set
.
end
()
||
fus
ion_id_allocator
.
HasFusionIdAttr
(
node
))
{
continue
;
}
// Only fuse real cnode
...
...
@@ -616,7 +604,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
cnode
=
FindFusionAnfNode
(
manager
.
get
(),
&
visited_set
,
&
record
,
&
todo
,
cnode
);
if
(
record
.
size
()
>=
MIN_PATTERN_SIZE
&&
record
.
size
()
<=
MAX_PATTERN_SIZE
)
{
candidate_fusion
->
push_back
(
record
);
fused_set
->
insert
(
record
.
begin
(),
record
.
end
()
);
SetRecordFusionId
(
record
);
}
if
(
record
.
find
(
cnode
)
==
record
.
end
())
{
todo
.
push_back
(
cnode
);
...
...
@@ -628,7 +616,6 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
(
void
)
todo
.
insert
(
todo
.
end
(),
cnode
->
inputs
().
begin
()
+
1
,
cnode
->
inputs
().
end
());
}
}
}
// namespace
void
BufferFusion
::
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
{
...
...
@@ -684,7 +671,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
return
change
;
}
bool
BufferFusion
::
MatchBufferFusionPattern
(
const
session
::
KernelGraph
&
kernel_graph
)
const
{
bool
BufferFusion
::
MatchBufferFusionPattern
(
const
session
::
KernelGraph
&
kernel_graph
)
{
auto
manager
=
kernel_graph
.
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
auto
return_node
=
kernel_graph
.
get_return
();
...
...
@@ -694,14 +681,11 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
}
MS_LOG
(
DEBUG
)
<<
"MatchBufferFusionPattern start..."
;
FusedNodeRecord
candidate_fusion
;
std
::
unordered_set
<
AnfNodePtr
>
fused_set
;
MatchOpNamePattern
(
kernel_graph
,
&
fused_set
,
&
candidate_fusion
);
MatchFusionTypePattern
(
kernel_graph
,
&
fused_set
,
&
candidate_fusion
);
MatchOpNamePattern
(
kernel_graph
,
&
candidate_fusion
);
MatchFusionTypePattern
(
kernel_graph
,
&
candidate_fusion
);
if
(
!
candidate_fusion
.
empty
())
{
SetAnfNodeFusionId
(
candidate_fusion
);
}
else
{
if
(
candidate_fusion
.
empty
())
{
return
false
;
}
MS_LOG
(
DEBUG
)
<<
"MatchBufferFusionPattern Success..."
;
...
...
@@ -741,13 +725,14 @@ bool BufferFusion::Run(const FuncGraphPtr &graph) {
auto
kernel_graph
=
graph
->
cast
<
std
::
shared_ptr
<
session
::
KernelGraph
>>
();
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
fusion_id_allocator
.
Init
();
if
(
MatchBufferFusionPattern
(
*
kernel_graph
))
{
changed
=
FuseBufferFusionPattern
(
kernel_graph
.
get
());
}
// clear fusion_id attr
for
(
auto
&
node
:
graph
->
nodes
())
{
if
(
node
!=
nullptr
&&
node
->
isa
<
CNode
>
())
{
AnfAlgo
::
EraseNodeAttr
(
k
Op
AttrFusionId
,
node
);
AnfAlgo
::
EraseNodeAttr
(
kAttrFusionId
,
node
);
}
}
return
changed
;
...
...
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
浏览文件 @
38e0d98e
...
...
@@ -21,6 +21,7 @@
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
#include "pre_activate/common/fusion_id_allocator.h"
#include "device/kernel_info.h"
#include "kernel/kernel.h"
#include "session/kernel_graph.h"
...
...
@@ -43,12 +44,24 @@ class BufferFusion : public Pass {
bool
Run
(
const
FuncGraphPtr
&
graph
)
override
;
private:
void
SetRecordFusionId
(
const
std
::
unordered_set
<
AnfNodePtr
>
&
record
);
void
MatchConvBnreduce
(
const
CNodePtr
&
cnode
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
);
void
MatchBnupdateRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
);
void
MatchBnupdateAddRelu
(
const
CNodePtr
&
cnode
,
const
AnfNodePtr
&
relu_input
,
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
);
void
MatchOpNamePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
);
void
MatchFusionTypePattern
(
const
session
::
KernelGraph
&
kernel_graph
,
FusedNodeRecord
*
candidate_fusion
);
void
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
;
bool
ReplaceFusionOp
(
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
,
int32_t
fusion_id
,
const
kernel
::
KernelModPtr
&
kernel_ptr
,
session
::
KernelGraph
*
kernel_graph
)
const
;
bool
MatchBufferFusionPattern
(
const
session
::
KernelGraph
&
kernel_graph
)
const
;
bool
MatchBufferFusionPattern
(
const
session
::
KernelGraph
&
kernel_graph
);
bool
FuseBufferFusionPattern
(
session
::
KernelGraph
*
kernel_graph
)
const
;
FusionIdAllocator
fusion_id_allocator
;
};
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc
0 → 100644
浏览文件 @
38e0d98e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/common/fusion_id_allocator.h"
#include "session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
opt
{
FusionIdAllocator
::
FusionIdAllocator
()
{
fusion_id
=
0
;
}
FusionIdAllocator
::~
FusionIdAllocator
()
{}
void
FusionIdAllocator
::
Init
()
{
fusion_id
=
0
;
}
int32_t
FusionIdAllocator
::
AllocateFusionId
()
{
fusion_id
++
;
return
fusion_id
;
}
bool
FusionIdAllocator
::
HasFusionIdAttr
(
const
AnfNodePtr
&
node
)
{
return
AnfAlgo
::
HasNodeAttr
(
kAttrFusionId
,
node
);
}
int32_t
FusionIdAllocator
::
GetFusionId
(
const
AnfNodePtr
&
node
)
{
if
(
HasFusionIdAttr
(
node
))
{
return
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
node
,
kAttrFusionId
);
}
return
-
1
;
}
void
FusionIdAllocator
::
SetFusionId
(
const
AnfNodePtr
&
node
,
int32_t
id
)
{
ValuePtr
fusion_id_v
=
MakeValue
(
id
);
AnfAlgo
::
SetNodeAttr
(
kAttrFusionId
,
fusion_id_v
,
node
);
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h
0 → 100644
浏览文件 @
38e0d98e
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
#include "ir/base.h"
namespace
mindspore
{
namespace
opt
{
class
FusionIdAllocator
{
public:
FusionIdAllocator
();
virtual
~
FusionIdAllocator
();
FusionIdAllocator
(
const
FusionIdAllocator
&
in
)
=
delete
;
FusionIdAllocator
&
operator
=
(
const
FusionIdAllocator
&
in
)
=
delete
;
void
Init
();
int32_t
AllocateFusionId
();
bool
HasFusionIdAttr
(
const
AnfNodePtr
&
node
);
int32_t
GetFusionId
(
const
AnfNodePtr
&
node
);
void
SetFusionId
(
const
AnfNodePtr
&
node
,
int32_t
id
);
private:
int32_t
fusion_id
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
mindspore/ccsrc/utils/utils.h
浏览文件 @
38e0d98e
...
...
@@ -165,6 +165,7 @@ constexpr auto kAttrFusion = "fusion";
constexpr
auto
kAttrGroup
=
"group"
;
constexpr
auto
kAttrOp
=
"op"
;
constexpr
auto
kAttrIsTraining
=
"is_training"
;
constexpr
auto
kAttrFusionId
=
"fusion_id"
;
// attr value
constexpr
auto
kValueTargetSwitch
=
"target_switch"
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录