Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2e2e7a28
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2e2e7a28
编写于
4月 27, 2020
作者:
E
Etone.Chan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor buffer fusion
上级
da5b10b6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
58 addition
and
90 deletion
+58
-90
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
.../ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
+54
-89
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
...e/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
+0
-1
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+2
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+2
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
浏览文件 @
2e2e7a28
...
@@ -261,23 +261,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
...
@@ -261,23 +261,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
return
buffer_fusion_kernel
;
return
buffer_fusion_kernel
;
}
}
kernel
::
KernelBuildInfoPtr
CreateFusionOpKernelInfo
(
const
std
::
vector
<
AnfNodePtr
>
&
inputs_list_in
,
kernel
::
KernelBuildInfoPtr
CreateFusionOpKernelInfo
(
const
std
::
vector
<
AnfNodePtr
>
&
inputs_list
,
const
std
::
vector
<
AnfNodePtr
>
&
inputs_list
,
const
std
::
vector
<
AnfNodePtr
>
&
outputs_list
)
{
const
std
::
vector
<
AnfNodePtr
>
&
outputs_list
)
{
MS_LOG
(
DEBUG
)
<<
"Start Create Kernel Info"
;
MS_LOG
(
DEBUG
)
<<
"Start Create Kernel Info"
;
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
builder
;
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
builder
;
// inputs format and data type
// inputs format and data type
std
::
vector
<
std
::
string
>
inputs_format
;
std
::
vector
<
std
::
string
>
inputs_format
;
std
::
vector
<
TypeId
>
inputs_data_type
;
std
::
vector
<
TypeId
>
inputs_data_type
;
for
(
auto
node
:
inputs_list_in
)
{
for
(
const
auto
&
input
:
inputs_list
)
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
input
->
isa
<
CNode
>
()
&&
AnfAlgo
::
GetCNodeName
(
input
)
==
prim
::
kPrimTupleGetItem
->
name
())
{
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
tuple_getitem
=
input
->
cast
<
CNodePtr
>
();
auto
&
inputs
=
cnode
->
inputs
();
MS_EXCEPTION_IF_NULL
(
tuple_getitem
);
for
(
size_t
input_index
=
1
;
input_index
<
inputs
.
size
();
++
input_index
)
{
inputs_format
.
push_back
(
AnfAlgo
::
GetOutputFormat
(
if
(
std
::
find
(
inputs_list
.
begin
(),
inputs_list
.
end
(),
inputs
[
input_index
])
!=
inputs_list
.
end
())
{
tuple_getitem
->
input
(
1
),
IntToSize
(
GetValue
<
int
>
(
GetValueNode
(
tuple_getitem
->
input
(
2
))))));
inputs_format
.
push_back
(
AnfAlgo
::
GetInputFormat
(
node
,
input_index
-
1
));
inputs_data_type
.
push_back
(
AnfAlgo
::
GetOutputDeviceDataType
(
inputs_data_type
.
push_back
(
AnfAlgo
::
GetInputDeviceDataType
(
node
,
input_index
-
1
));
tuple_getitem
->
input
(
1
),
IntToSize
(
GetValue
<
int
>
(
GetValueNode
(
tuple_getitem
->
input
(
2
))))));
}
}
else
{
inputs_format
.
push_back
(
AnfAlgo
::
GetOutputFormat
(
input
,
0
));
inputs_data_type
.
push_back
(
AnfAlgo
::
GetOutputDeviceDataType
(
input
,
0
));
}
}
}
}
// outputs format and data type
// outputs format and data type
...
@@ -360,62 +361,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
...
@@ -360,62 +361,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
}
}
}
}
void
GetInputList
(
const
CNodePtr
&
node
,
const
int32_t
cur_fusion_id
,
std
::
vector
<
AnfNodePtr
>
*
inputs_list
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
inputs_list
);
auto
&
inputs
=
node
->
inputs
();
for
(
size_t
input_index
=
1
;
input_index
<
inputs
.
size
();
++
input_index
)
{
auto
input
=
inputs
[
input_index
];
if
(
AnfAlgo
::
IsRealCNodeKernel
(
input
))
{
if
(
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
input
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
input
,
kOpAttrFusionId
);
if
(
fusion_id
!=
cur_fusion_id
)
{
inputs_list
->
push_back
(
input
);
}
}
else
{
inputs_list
->
push_back
(
input
);
}
}
else
if
(
input
->
isa
<
CNode
>
())
{
for
(
auto
&
input_in
:
input
->
cast
<
CNodePtr
>
()
->
inputs
())
{
if
(
AnfAlgo
::
IsRealCNodeKernel
(
input_in
))
{
if
(
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
input_in
))
{
auto
fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
input_in
,
kOpAttrFusionId
);
if
(
fusion_id
!=
cur_fusion_id
)
{
inputs_list
->
push_back
(
input
);
}
}
else
{
inputs_list
->
push_back
(
input
);
}
}
}
}
else
{
inputs_list
->
push_back
(
input
);
}
}
}
void
CheckCurrentNodeIsInput
(
const
CNodePtr
&
node
,
const
int32_t
&
cur_fusion_id
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
if
((
*
buffer_fusion_infos
).
find
(
cur_fusion_id
)
==
(
*
buffer_fusion_infos
).
end
())
{
BufferFusionInfo_t
buffer_fusion_info
;
(
*
buffer_fusion_infos
)[
cur_fusion_id
]
=
buffer_fusion_info
;
}
std
::
vector
<
AnfNodePtr
>
inputs_list
;
GetInputList
(
node
,
cur_fusion_id
,
&
inputs_list
);
if
(
!
inputs_list
.
empty
())
{
if
(
!
(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list
.
empty
())
{
(
void
)(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list
.
insert
(
(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list
.
end
(),
inputs_list
.
begin
(),
inputs_list
.
end
());
(
void
)(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list_in
.
insert
(
(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list_in
.
end
(),
node
);
}
else
{
(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list
=
inputs_list
;
(
*
buffer_fusion_infos
)[
cur_fusion_id
].
inputs_list_in
.
push_back
(
node
);
}
}
}
void
GetFusionScopeComputeNodeList
(
session
::
KernelGraph
*
kernel_graph
,
void
GetFusionScopeComputeNodeList
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
...
@@ -429,6 +374,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
...
@@ -429,6 +374,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
}
}
}
}
void
GetFusionScopeInputNodeList
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
auto
manager
=
kernel_graph
->
manager
();
MS_EXCEPTION_IF_NULL
(
manager
);
for
(
auto
&
buffer_fusion_info
:
*
buffer_fusion_infos
)
{
auto
fusion_id
=
buffer_fusion_info
.
first
;
auto
fusion_info
=
buffer_fusion_info
.
second
;
for
(
const
auto
&
node
:
fusion_info
.
anf_nodes
)
{
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
for
(
size_t
idx
=
1
;
idx
<
cnode
->
inputs
().
size
();
++
idx
)
{
auto
real_input
=
AnfAlgo
::
VisitKernel
(
cnode
->
input
(
idx
),
0
);
if
(
std
::
find
(
fusion_info
.
anf_nodes
.
begin
(),
fusion_info
.
anf_nodes
.
end
(),
real_input
.
first
)
==
fusion_info
.
anf_nodes
.
end
())
{
if
(
std
::
find
((
*
buffer_fusion_infos
)[
fusion_id
].
inputs_list
.
begin
(),
(
*
buffer_fusion_infos
)[
fusion_id
].
inputs_list
.
end
(),
cnode
->
input
(
idx
))
==
(
*
buffer_fusion_infos
)[
fusion_id
].
inputs_list
.
end
())
{
(
*
buffer_fusion_infos
)[
fusion_id
].
inputs_list
.
push_back
(
cnode
->
input
(
idx
));
}
}
}
}
}
}
bool
TupleGetitemNodeCompare
(
const
AnfNodePtr
&
node1
,
const
AnfNodePtr
&
node2
)
{
MS_EXCEPTION_IF_NULL
(
node1
);
MS_EXCEPTION_IF_NULL
(
node2
);
auto
getitem1
=
node1
->
cast
<
CNodePtr
>
();
auto
getitem2
=
node2
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
getitem1
);
MS_EXCEPTION_IF_NULL
(
getitem2
);
auto
output_idx1
=
GetValue
<
int
>
(
GetValueNode
(
getitem1
->
input
(
2
)));
auto
output_idx2
=
GetValue
<
int
>
(
GetValueNode
(
getitem2
->
input
(
2
)));
return
output_idx1
<
output_idx2
;
}
void
GetFusionScopeOutputNodeList
(
session
::
KernelGraph
*
kernel_graph
,
void
GetFusionScopeOutputNodeList
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
...
@@ -454,14 +438,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
...
@@ -454,14 +438,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
std
::
transform
(
manager
->
node_users
()[
node
].
begin
(),
manager
->
node_users
()[
node
].
end
(),
std
::
transform
(
manager
->
node_users
()[
node
].
begin
(),
manager
->
node_users
()[
node
].
end
(),
std
::
back_inserter
(
tuple_getitem_nodes
),
std
::
back_inserter
(
tuple_getitem_nodes
),
[](
const
std
::
pair
<
AnfNodePtr
,
int
>
&
use_node
)
{
return
use_node
.
first
;
});
[](
const
std
::
pair
<
AnfNodePtr
,
int
>
&
use_node
)
{
return
use_node
.
first
;
});
std
::
sort
(
tuple_getitem_nodes
.
begin
(),
tuple_getitem_nodes
.
end
(),
std
::
sort
(
tuple_getitem_nodes
.
begin
(),
tuple_getitem_nodes
.
end
(),
TupleGetitemNodeCompare
);
[](
const
AnfNodePtr
&
node1
,
const
AnfNodePtr
&
node2
)
{
auto
getitem1
=
node1
->
cast
<
CNodePtr
>
();
auto
getitem2
=
node2
->
cast
<
CNodePtr
>
();
auto
output_idx1
=
GetValue
<
int
>
(
GetValueNode
(
getitem1
->
input
(
2
)));
auto
output_idx2
=
GetValue
<
int
>
(
GetValueNode
(
getitem2
->
input
(
2
)));
return
output_idx1
<
output_idx2
;
});
for
(
auto
getitem
:
tuple_getitem_nodes
)
{
for
(
auto
getitem
:
tuple_getitem_nodes
)
{
auto
getitem_ptr
=
getitem
->
cast
<
CNodePtr
>
();
auto
getitem_ptr
=
getitem
->
cast
<
CNodePtr
>
();
auto
input2
=
getitem_ptr
->
input
(
2
);
auto
input2
=
getitem_ptr
->
input
(
2
);
...
@@ -634,24 +611,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
...
@@ -634,24 +611,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
void
BufferFusion
::
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
void
BufferFusion
::
GetBufferFusionInfo
(
session
::
KernelGraph
*
kernel_graph
,
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
{
std
::
unordered_map
<
int32_t
,
BufferFusionInfo_t
>
*
buffer_fusion_infos
)
const
{
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
MS_EXCEPTION_IF_NULL
(
buffer_fusion_infos
);
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
kernel_graph
->
get_return
());
for
(
auto
&
node
:
node_list
)
{
if
(
!
AnfAlgo
::
IsRealCNodeKernel
(
node
))
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
if
(
AnfAlgo
::
HasNodeAttr
(
kOpAttrFusionId
,
cnode
))
{
auto
cur_fusion_id
=
AnfAlgo
::
GetNodeAttr
<
int32_t
>
(
cnode
,
kOpAttrFusionId
);
CheckCurrentNodeIsInput
(
cnode
,
cur_fusion_id
,
buffer_fusion_infos
);
}
}
GetFusionScopeComputeNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScopeComputeNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScopeInputNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScopeOutputNodeList
(
kernel_graph
,
buffer_fusion_infos
);
GetFusionScopeOutputNodeList
(
kernel_graph
,
buffer_fusion_infos
);
for
(
auto
&
buffer_fusion_info
:
*
buffer_fusion_infos
)
{
for
(
auto
&
buffer_fusion_info
:
*
buffer_fusion_infos
)
{
buffer_fusion_info
.
second
.
kernel_build_info
=
buffer_fusion_info
.
second
.
kernel_build_info
=
CreateFusionOpKernelInfo
(
buffer_fusion_info
.
second
.
inputs_list_in
,
buffer_fusion_info
.
second
.
inputs_list
,
CreateFusionOpKernelInfo
(
buffer_fusion_info
.
second
.
inputs_list
,
buffer_fusion_info
.
second
.
outputs_list
);
buffer_fusion_info
.
second
.
outputs_list
);
}
}
}
}
...
...
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h
浏览文件 @
2e2e7a28
...
@@ -30,7 +30,6 @@ namespace opt {
...
@@ -30,7 +30,6 @@ namespace opt {
struct
BufferFusionInfo_t
{
struct
BufferFusionInfo_t
{
std
::
vector
<
AnfNodePtr
>
anf_nodes
;
std
::
vector
<
AnfNodePtr
>
anf_nodes
;
std
::
vector
<
AnfNodePtr
>
inputs_list
;
std
::
vector
<
AnfNodePtr
>
inputs_list
;
std
::
vector
<
AnfNodePtr
>
inputs_list_in
;
std
::
vector
<
AnfNodePtr
>
outputs_list
;
std
::
vector
<
AnfNodePtr
>
outputs_list
;
kernel
::
KernelBuildInfoPtr
kernel_build_info
;
kernel
::
KernelBuildInfoPtr
kernel_build_info
;
};
};
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
2e2e7a28
...
@@ -816,6 +816,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
...
@@ -816,6 +816,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
MS_EXCEPTION_IF_NULL
(
anf_node
);
MS_EXCEPTION_IF_NULL
(
anf_node
);
static
std
::
map
<
std
::
string
,
std
::
map
<
size_t
,
size_t
>>
spec_node_list
=
{
static
std
::
map
<
std
::
string
,
std
::
map
<
size_t
,
size_t
>>
spec_node_list
=
{
{
prim
::
kPrimConv2DBackpropInput
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
prim
::
kPrimConv2DBackpropInput
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
kFusionOpConv2DBackpropInputReluGradV2Name
,
{{
0
,
1
},
{
1
,
0
},
{
2
,
2
}}},
{
kFusionOpConv2DBackpropInputAddNReluGradV2Name
,
{{
0
,
1
},
{
1
,
0
},
{
2
,
2
},
{
3
,
3
}}},
{
prim
::
kPrimConv2DBackpropFilter
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
prim
::
kPrimConv2DBackpropFilter
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
prim
::
kPrimLogSoftmaxGrad
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
prim
::
kPrimLogSoftmaxGrad
->
name
(),
{{
0
,
1
},
{
1
,
0
}}},
{
prim
::
kPrimLayerNormGrad
->
name
(),
{{
0
,
1
},
{
1
,
0
},
{
2
,
2
},
{
3
,
3
},
{
4
,
4
}}},
{
prim
::
kPrimLayerNormGrad
->
name
(),
{{
0
,
1
},
{
1
,
0
},
{
2
,
2
},
{
3
,
3
},
{
4
,
4
}}},
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
2e2e7a28
...
@@ -122,6 +122,8 @@ constexpr auto kSendOpName = "Send";
...
@@ -122,6 +122,8 @@ constexpr auto kSendOpName = "Send";
constexpr
auto
kRecvOpName
=
"Recv"
;
constexpr
auto
kRecvOpName
=
"Recv"
;
constexpr
auto
kReluV2OpName
=
"ReLUV2"
;
constexpr
auto
kReluV2OpName
=
"ReLUV2"
;
constexpr
auto
kReluGradV2OpName
=
"ReluGradV2"
;
constexpr
auto
kReluGradV2OpName
=
"ReluGradV2"
;
constexpr
auto
kFusionOpConv2DBackpropInputReluGradV2Name
=
"FusionOp_Conv2DBackpropInput_ReluGradV2"
;
constexpr
auto
kFusionOpConv2DBackpropInputAddNReluGradV2Name
=
"FusionOp_Conv2DBackpropInput_AddN_ReluGradV2"
;
// attr key name
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
constexpr
auto
kAttrInputNames
=
"input_names"
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录