Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5365678e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5365678e
编写于
4月 14, 2020
作者:
L
lianliguang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor kernel select
上级
9c9c7091
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
35 addition
and
48 deletion
+35
-48
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+35
-48
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
5365678e
...
...
@@ -31,12 +31,13 @@ namespace mindspore {
namespace
device
{
namespace
ascend
{
namespace
{
const
float
kWegihtBaseScore
=
1
;
const
float
kFeatureMapBaseScore
=
10
;
enum
MatchCountPriority
:
int
{
MATCH_COUNT_PRIORITY_BEGIN
=
0
,
MATCH_DTYPE_COUNT
=
MATCH_COUNT_PRIORITY_BEGIN
,
MATCH_FORMAT_COUNT
,
MATCH_SPECIAL_FORMAT_COUNT
,
MATCH_5D_FORMAT_COUNT
,
MATCH_OUTPUT_DTYPE_COUNT
,
MATCH_COUNT_PRIORITY_END
};
...
...
@@ -82,13 +83,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
}
return
true
;
};
if
(
AnfAlgo
::
GetCNodeName
(
kernel_node
)
==
"Adam"
)
{
auto
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
AnfAlgo
::
GetPrevNodeOutputFormat
(
kernel_node
,
input_num
-
1
)
!=
kernel_build_info
.
GetInputFormat
(
input_num
-
1
))
{
return
false
;
}
}
if
(
AnfAlgo
::
GetCNodeName
(
kernel_node
)
==
prim
::
kPrimCast
->
name
())
{
return
AnfAlgo
::
GetOutputInferDataType
(
kernel_node
,
0
)
==
kernel_build_info
.
GetOutputDeviceType
(
0
)
&&
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
kernel_node
,
0
)
==
kernel_build_info
.
GetInputDeviceType
(
0
);
...
...
@@ -112,21 +106,7 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
MS_EXCEPTION_IF_NULL
(
cnode
);
// Check input data type
for
(
size_t
input_index
=
0
;
input_index
<
kernel_build_info
.
GetInputNum
();
++
input_index
)
{
AnfNodePtr
cur_input
=
AnfAlgo
::
GetInputNode
(
cnode
,
input_index
);
MS_EXCEPTION_IF_NULL
(
cur_input
);
TypeId
input_origin_type
;
if
(
cur_input
->
isa
<
Parameter
>
()
&&
AnfAlgo
::
IsParameterWeight
(
cur_input
->
cast
<
ParameterPtr
>
()))
{
// weight
input_origin_type
=
AnfAlgo
::
GetOutputDeviceDataType
(
cur_input
,
0
);
}
else
if
(
cur_input
->
isa
<
ValueNode
>
())
{
input_origin_type
=
AnfAlgo
::
GetOutputDeviceDataType
(
cur_input
,
0
);
}
else
{
// feature map
input_origin_type
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
cnode
,
input_index
);
}
if
(
input_origin_type
==
kTypeUnknown
)
{
continue
;
}
TypeId
input_origin_type
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
cnode
,
input_index
);
if
(
kernel_build_info
.
GetInputDeviceType
(
input_index
)
!=
input_origin_type
)
{
return
false
;
}
...
...
@@ -140,6 +120,29 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
return
true
;
}
string
GetPriorityMatchFormat
(
const
CNodePtr
&
cnode
)
{
string
priority_matched_format
=
kOpFormat_NC1HWC0
;
bool
is_init
=
false
;
bool
need_change_nd
=
false
;
for
(
size_t
index
=
0
;
index
<
AnfAlgo
::
GetInputTensorNum
(
cnode
);
++
index
)
{
auto
pre_output_format
=
AnfAlgo
::
GetPrevNodeOutputFormat
(
cnode
,
index
);
if
(
AnfAlgo
::
IsFeatureMapInput
(
cnode
,
index
)
&&
kNeedTransFormatSet
.
find
(
pre_output_format
)
!=
kNeedTransFormatSet
.
end
())
{
priority_matched_format
=
!
is_init
?
priority_matched_format
:
pre_output_format
;
is_init
=
true
;
}
// feature map has two or more special format;
if
(
priority_matched_format
!=
pre_output_format
&&
pre_output_format
!=
kOpFormat_DEFAULT
)
{
priority_matched_format
=
kOpFormat_DEFAULT
;
}
auto
input_shape_size
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
cnode
,
index
).
size
();
need_change_nd
=
(
need_change_nd
||
(
input_shape_size
!=
4
&&
input_shape_size
>
1
));
}
if
(
need_change_nd
)
{
priority_matched_format
=
kOpFormat_DEFAULT
;
}
return
priority_matched_format
;
}
/**
* compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
* if equal then next num location
...
...
@@ -172,34 +175,18 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
if
(
cur_kernelinfo_match_counts
->
size
()
<
MATCH_COUNT_PRIORITY_END
)
{
MS_LOG
(
EXCEPTION
)
<<
"Out of range cur_kernelinfo_match_counts "
<<
MATCH_COUNT_PRIORITY_END
;
}
auto
pri_match_format
=
GetPriorityMatchFormat
(
kernel_node
);
for
(
size_t
input_index
=
0
;
input_index
<
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
++
input_index
)
{
AnfNodePtr
input_anf_node
=
AnfAlgo
::
GetInputNode
(
kernel_node
,
input_index
);
MS_EXCEPTION_IF_NULL
(
input_anf_node
);
// if a input parameter is a weight with default format, the input shouldn't participate the judge
if
(
input_anf_node
->
isa
<
Parameter
>
())
{
auto
para
=
input_anf_node
->
cast
<
ParameterPtr
>
();
if
(
AnfAlgo
::
IsParameterWeight
(
para
)
&&
AnfAlgo
::
GetOutputDeviceDataType
(
para
,
0
)
==
kTypeUnknown
)
{
continue
;
}
}
auto
base_score
=
AnfAlgo
::
IsFeatureMapInput
(
kernel_node
,
input_index
)
?
kFeatureMapBaseScore
:
kWegihtBaseScore
;
if
(
kernel_build_info
.
GetInputFormat
(
input_index
)
==
AnfAlgo
::
GetPrevNodeOutputFormat
(
kernel_node
,
input_index
))
{
if
(
AnfAlgo
::
IsFeatureMapInput
(
kernel_node
,
input_index
)
&&
kNeedTransFormatSet
.
find
(
kernel_build_info
.
GetInputFormat
(
input_index
))
!=
kNeedTransFormatSet
.
end
())
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_SPECIAL_FORMAT_COUNT
]
++
;
}
(
*
cur_kernelinfo_match_counts
)[
MATCH_FORMAT_COUNT
]
++
;
(
*
cur_kernelinfo_match_counts
)[
MATCH_FORMAT_COUNT
]
+=
base_score
;
}
if
(
kernel_build_info
.
GetInputDeviceType
(
input_index
)
==
AnfAlgo
::
GetPrevNodeOutputDeviceDataType
(
kernel_node
,
input_index
))
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_DTYPE_COUNT
]
++
;
(
*
cur_kernelinfo_match_counts
)[
MATCH_DTYPE_COUNT
]
+=
base_score
;
}
if
(
kernel_build_info
.
GetInputFormat
(
input_index
)
==
kOpFormat_NC1HWC0
)
{
// input is from a feature map & this input's shape is not 4d
if
(
AnfAlgo
::
IsFeatureMapInput
(
kernel_node
,
input_index
)
&&
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
input_index
).
size
()
!=
kShape4dDims
)
{
continue
;
}
(
*
cur_kernelinfo_match_counts
)[
MATCH_5D_FORMAT_COUNT
]
++
;
if
(
kernel_build_info
.
GetInputFormat
(
input_index
)
==
pri_match_format
)
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_SPECIAL_FORMAT_COUNT
]
+=
base_score
;
}
}
...
...
@@ -207,7 +194,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
// cal count of same output dtype between abstract and kernel info
if
(
kernel_build_info
.
GetOutputDeviceType
(
output_index
)
==
AnfAlgo
::
GetOutputInferDataType
(
kernel_node
,
output_index
))
{
(
*
cur_kernelinfo_match_counts
)[
MATCH_OUTPUT_DTYPE_COUNT
]
++
;
(
*
cur_kernelinfo_match_counts
)[
MATCH_OUTPUT_DTYPE_COUNT
]
+=
1
;
}
}
}
...
...
@@ -517,7 +504,7 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
MS_EXCEPTION_IF_NULL
(
kernel_node
);
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
};
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
};
int
selected_index
=
-
1
;
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
...
...
@@ -527,7 +514,7 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
std
::
vector
<
int
>
node_mix_precision_datatype_index
;
std
::
vector
<
TypeId
>
node_mix_precision_datatype
;
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
};
auto
kernel_build_info
=
*
(
kernel_info_list
[
info_index
]);
if
(
!
IsValidKernelInfo
(
kernel_node
,
kernel_build_info
))
{
continue
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录