Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a04e8486
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a04e8486
编写于
4月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!650 Match format when kernel selecting using raise or reduce precision
Merge pull request !650 from liubuyu/r0.2
上级
16ac0f29
05e001fc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
138 addition
and
145 deletion
+138
-145
mindspore/ccsrc/debug/anf_ir_dump.cc
mindspore/ccsrc/debug/anf_ir_dump.cc
+8
-0
mindspore/ccsrc/debug/anf_ir_dump.h
mindspore/ccsrc/debug/anf_ir_dump.h
+3
-1
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+124
-142
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+3
-2
未找到文件。
mindspore/ccsrc/debug/anf_ir_dump.cc
浏览文件 @
a04e8486
...
@@ -91,6 +91,14 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) {
...
@@ -91,6 +91,14 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) {
}
}
}
}
void
PrintInputAndOutputInferType
(
std
::
ostringstream
&
buffer
,
const
AnfNodePtr
&
nd
)
{
buffer
<<
" : ("
;
PrintNodeInputType
(
buffer
,
nd
);
buffer
<<
") -> ("
;
PrintNodeOutputType
(
buffer
,
nd
);
buffer
<<
")"
;
}
struct
SubGraphIRInfo
{
struct
SubGraphIRInfo
{
int32_t
local_var
;
int32_t
local_var
;
std
::
ostringstream
buffer
;
std
::
ostringstream
buffer
;
...
...
mindspore/ccsrc/debug/anf_ir_dump.h
浏览文件 @
a04e8486
...
@@ -18,12 +18,14 @@
...
@@ -18,12 +18,14 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "ir/dtype/type.h"
#include "ir/anf.h"
#include "ir/anf.h"
namespace
mindspore
{
namespace
mindspore
{
constexpr
char
PARALLEL_STRATEGY
[]
=
"strategy"
;
constexpr
char
PARALLEL_STRATEGY
[]
=
"strategy"
;
void
DumpIR
(
const
std
::
string
&
filename
,
const
FuncGraphPtr
&
func_graph
,
bool
dump_full_name
=
false
);
void
DumpIR
(
const
std
::
string
&
filename
,
const
FuncGraphPtr
&
func_graph
,
bool
dump_full_name
=
false
);
void
PrintInputAndOutputInferType
(
std
::
ostringstream
&
buffer
,
const
AnfNodePtr
&
nd
);
const
std
::
string
ToShortString
(
const
TypeId
&
typeId
);
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_DUMP_H_
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_DUMP_H_
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
a04e8486
...
@@ -18,14 +18,15 @@
...
@@ -18,14 +18,15 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <memory>
#include <memory>
#include <
set
>
#include <
utility
>
#include <
unordered_
map>
#include <map>
#include "kernel/oplib/oplib.h"
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_query.h"
#include "kernel/kernel_query.h"
#include "session/anf_runtime_algorithm.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_build_info.h"
#include "utils/context/ms_context.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
device
{
namespace
device
{
...
@@ -180,6 +181,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
...
@@ -180,6 +181,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
}
}
void
AddSupportMixedPrecisionDataTypeIndex
(
TypeId
data_type
,
std
::
vector
<
int
>
*
support_index
)
{
void
AddSupportMixedPrecisionDataTypeIndex
(
TypeId
data_type
,
std
::
vector
<
int
>
*
support_index
)
{
MS_EXCEPTION_IF_NULL
(
support_index
);
int
index
=
kUnSupportMixedDataTypeIndex
;
int
index
=
kUnSupportMixedDataTypeIndex
;
switch
(
data_type
)
{
switch
(
data_type
)
{
case
kNumberTypeFloat16
:
case
kNumberTypeFloat16
:
...
@@ -197,6 +199,7 @@ void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *s
...
@@ -197,6 +199,7 @@ void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *s
void
AddKernelInputSupportDataType
(
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
size_t
input_index
,
void
AddKernelInputSupportDataType
(
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
size_t
input_index
,
std
::
vector
<
int
>
*
support_datatype_index
,
std
::
vector
<
TypeId
>
*
support_datatype
)
{
std
::
vector
<
int
>
*
support_datatype_index
,
std
::
vector
<
TypeId
>
*
support_datatype
)
{
MS_EXCEPTION_IF_NULL
(
support_datatype
);
auto
data_type
=
kernel_build_info
.
GetInputDeviceType
(
input_index
);
auto
data_type
=
kernel_build_info
.
GetInputDeviceType
(
input_index
);
support_datatype
->
push_back
(
data_type
);
support_datatype
->
push_back
(
data_type
);
AddSupportMixedPrecisionDataTypeIndex
(
data_type
,
support_datatype_index
);
AddSupportMixedPrecisionDataTypeIndex
(
data_type
,
support_datatype_index
);
...
@@ -204,6 +207,7 @@ void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_i
...
@@ -204,6 +207,7 @@ void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_i
void
AddKernelOutputSupportDataType
(
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
size_t
output_index
,
void
AddKernelOutputSupportDataType
(
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
size_t
output_index
,
std
::
vector
<
int
>
*
support_datatype_index
,
std
::
vector
<
TypeId
>
*
support_datatype
)
{
std
::
vector
<
int
>
*
support_datatype_index
,
std
::
vector
<
TypeId
>
*
support_datatype
)
{
MS_EXCEPTION_IF_NULL
(
support_datatype
);
auto
data_type
=
kernel_build_info
.
GetOutputDeviceType
(
output_index
);
auto
data_type
=
kernel_build_info
.
GetOutputDeviceType
(
output_index
);
support_datatype
->
push_back
(
data_type
);
support_datatype
->
push_back
(
data_type
);
AddSupportMixedPrecisionDataTypeIndex
(
data_type
,
support_datatype_index
);
AddSupportMixedPrecisionDataTypeIndex
(
data_type
,
support_datatype_index
);
...
@@ -238,8 +242,8 @@ void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
...
@@ -238,8 +242,8 @@ void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
void
CheckDataTypeInputs
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
void
CheckDataTypeInputs
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
unordered_
map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
const
std
::
map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
std
::
unordered_
map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
std
::
map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
if
(
node_mix_precision_datatype_index
.
size
()
!=
node_mix_precision_datatype
.
size
())
{
if
(
node_mix_precision_datatype_index
.
size
()
!=
node_mix_precision_datatype
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"node datatype index size "
<<
node_mix_precision_datatype_index
.
size
()
<<
" != datatype size "
MS_LOG
(
EXCEPTION
)
<<
"node datatype index size "
<<
node_mix_precision_datatype_index
.
size
()
<<
" != datatype size "
<<
node_mix_precision_datatype
.
size
();
<<
node_mix_precision_datatype
.
size
();
...
@@ -251,10 +255,11 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind
...
@@ -251,10 +255,11 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind
}
}
}
}
int
RaiseDataTypePrecisionSelect
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
bool
RaiseDataTypePrecisionSelect
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
unordered_map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
const
std
::
map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
std
::
unordered_map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
std
::
map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
MS_EXCEPTION_IF_NULL
(
kernel_match_datatype_idx
);
CheckDataTypeInputs
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatypes
,
CheckDataTypeInputs
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatypes
,
kernel_match_datatype_idx
);
kernel_match_datatype_idx
);
for
(
size_t
i
=
0
;
i
<
node_mix_precision_datatype_index
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
node_mix_precision_datatype_index
.
size
();
++
i
)
{
...
@@ -289,40 +294,16 @@ int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_data
...
@@ -289,40 +294,16 @@ int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_data
}
}
}
}
}
}
return
!
kernel_match_datatype_idx
->
empty
();
if
(
kernel_match_datatype_idx
->
size
()
>=
1
)
{
return
SizeToInt
(
kernel_match_datatype_idx
->
begin
()
->
first
);
}
return
-
1
;
}
int
GetMinReducePrecisionCountIndex
(
std
::
unordered_map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
,
const
std
::
unordered_map
<
size_t
,
size_t
>
&
precision_reduce_count
)
{
int
selected_index
=
-
1
;
size_t
min_reduce_precision_count
=
kMaxCount
;
auto
iter
=
kernel_match_datatype_idx
->
begin
();
while
(
iter
!=
kernel_match_datatype_idx
->
end
())
{
auto
find_iter
=
precision_reduce_count
.
find
(
iter
->
first
);
if
(
find_iter
==
precision_reduce_count
.
end
())
{
continue
;
}
if
(
min_reduce_precision_count
>
find_iter
->
second
)
{
selected_index
=
SizeToInt
(
iter
->
first
);
min_reduce_precision_count
=
find_iter
->
second
;
}
++
iter
;
}
return
selected_index
;
}
}
int
RaiseOrReduceDataTypePrecisionSelect
(
bool
RaiseOrReduceDataTypePrecisionSelect
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
unordered_map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
const
std
::
map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatypes
,
std
::
unordered_map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
std
::
map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
)
{
MS_EXCEPTION_IF_NULL
(
kernel_match_datatype_idx
);
CheckDataTypeInputs
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatypes
,
CheckDataTypeInputs
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatypes
,
kernel_match_datatype_idx
);
kernel_match_datatype_idx
);
// reduce / raise
std
::
unordered_map
<
size_t
,
size_t
>
precision_reduce_count
;
for
(
size_t
i
=
0
;
i
<
node_mix_precision_datatype_index
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
node_mix_precision_datatype_index
.
size
();
++
i
)
{
if
(
node_mix_precision_datatype
[
i
]
==
kTypeUnknown
)
{
if
(
node_mix_precision_datatype
[
i
]
==
kTypeUnknown
)
{
continue
;
continue
;
...
@@ -351,26 +332,18 @@ int RaiseOrReduceDataTypePrecisionSelect(
...
@@ -351,26 +332,18 @@ int RaiseOrReduceDataTypePrecisionSelect(
if
(
datatype_indexes
[
i
]
==
kUnSupportMixedDataTypeIndex
)
{
if
(
datatype_indexes
[
i
]
==
kUnSupportMixedDataTypeIndex
)
{
iter
=
kernel_match_datatype_idx
->
erase
(
iter
);
iter
=
kernel_match_datatype_idx
->
erase
(
iter
);
}
else
{
}
else
{
if
(
datatype_indexes
[
i
]
<
node_mix_precision_datatype_index
[
i
])
{
auto
count_iter
=
precision_reduce_count
.
find
(
iter
->
first
);
if
(
count_iter
!=
precision_reduce_count
.
end
())
{
count_iter
->
second
++
;
}
else
{
precision_reduce_count
[
iter
->
first
]
=
1
;
}
}
++
iter
;
++
iter
;
}
}
}
}
}
}
return
!
kernel_match_datatype_idx
->
empty
();
return
GetMinReducePrecisionCountIndex
(
kernel_match_datatype_idx
,
precision_reduce_count
);
}
}
void
AddNodeAndKernelDataType
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
void
AddNodeAndKernelDataType
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfo
&
kernel_build_info
,
std
::
vector
<
int
>
*
support_indexes
,
std
::
vector
<
TypeId
>
*
node_mix_precision_datatype
,
std
::
vector
<
int
>
*
support_indexes
,
std
::
vector
<
TypeId
>
*
node_mix_precision_datatype
,
std
::
vector
<
TypeId
>
*
support_datatypes
,
std
::
vector
<
TypeId
>
*
support_datatypes
,
std
::
vector
<
int
>
*
node_mix_precision_datatype_index
)
{
std
::
vector
<
int
>
*
node_mix_precision_datatype_index
)
{
MS_EXCEPTION_IF_NULL
(
node_mix_precision_datatype
);
bool
add_node_datatype_flag
=
false
;
bool
add_node_datatype_flag
=
false
;
if
(
node_mix_precision_datatype
->
size
()
==
0
)
{
if
(
node_mix_precision_datatype
->
size
()
==
0
)
{
add_node_datatype_flag
=
true
;
add_node_datatype_flag
=
true
;
...
@@ -390,104 +363,58 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
...
@@ -390,104 +363,58 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
}
}
}
}
int
PrecisionReduce
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
void
PrecisionReduce
(
const
std
::
vector
<
int
>
&
node_mix_precision_datatype_index
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
vector
<
TypeId
>
&
node_mix_precision_datatype
,
const
std
::
unordered_map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatype
,
const
std
::
map
<
size_t
,
std
::
vector
<
TypeId
>>
&
kernel_support_datatype
,
std
::
unordered_map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
,
bool
*
precision_reduce
)
{
std
::
map
<
size_t
,
std
::
vector
<
int
>>
*
kernel_match_datatype_idx
,
bool
*
precision_reduce
)
{
MS_EXCEPTION_IF_NULL
(
kernel_match_datatype_idx
);
auto
context_ptr
=
MsContext
::
GetInstance
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
MS_EXCEPTION_IF_NULL
(
context_ptr
);
MS_EXCEPTION_IF_NULL
(
precision_reduce
);
MS_EXCEPTION_IF_NULL
(
precision_reduce
);
std
::
unordered_
map
<
size_t
,
std
::
vector
<
int
>>
kernel_match_datatype_idx_copy
=
*
kernel_match_datatype_idx
;
std
::
map
<
size_t
,
std
::
vector
<
int
>>
kernel_match_datatype_idx_copy
=
*
kernel_match_datatype_idx
;
// raise precision
// raise precision
int
selected_index
=
RaiseDataTypePrecisionSelect
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
bool
selected_ret
=
RaiseDataTypePrecisionSelect
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatype
,
kernel_match_datatype_idx
);
kernel_support_datatype
,
kernel_match_datatype_idx
);
if
(
selected_index
!=
-
1
)
{
if
(
selected_ret
)
{
int
max_match
=
0
;
return
;
auto
iter
=
kernel_match_datatype_idx
->
begin
();
int
match_count
=
0
;
while
(
iter
!=
kernel_match_datatype_idx
->
end
())
{
auto
kernel_datatypes
=
kernel_support_datatype
.
find
(
iter
->
first
);
if
(
kernel_datatypes
==
kernel_support_datatype
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Can not find kernel index"
<<
iter
->
first
<<
"'s datatype."
;
}
if
(
kernel_datatypes
->
second
.
size
()
<
node_mix_precision_datatype
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Kernel datatype size is not equal to node datatype size!"
;
}
for
(
size_t
i
=
0
;
i
<
node_mix_precision_datatype
.
size
();
++
i
)
{
if
(
node_mix_precision_datatype
[
i
]
==
kernel_datatypes
->
second
[
i
])
{
++
match_count
;
}
}
if
(
match_count
>
max_match
)
{
selected_index
=
SizeToInt
(
iter
->
first
);
}
++
iter
;
}
}
}
if
(
selected_index
==
-
1
&&
context_ptr
->
enable_reduce_precision
())
{
if
(
context_ptr
->
enable_reduce_precision
())
{
selected_
index
=
selected_
ret
=
RaiseOrReduceDataTypePrecisionSelect
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
RaiseOrReduceDataTypePrecisionSelect
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatype
,
&
kernel_match_datatype_idx_copy
);
kernel_support_datatype
,
&
kernel_match_datatype_idx_copy
);
}
if
(
selected_index
!=
-
1
)
{
if
(
selected_ret
)
{
*
precision_reduce
=
true
;
*
precision_reduce
=
true
;
}
*
kernel_match_datatype_idx
=
kernel_match_datatype_idx_copy
;
}
}
return
selected_index
;
}
}
void
SelectKernel
(
const
CNodePtr
&
kernel_node
,
bool
precision_reduce
,
const
std
::
vector
<
TypeId
>
&
node_datatype
,
void
PrintRaiseOrReducePrecisionSelectedInfo
(
const
CNodePtr
&
cnode
,
const
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
&
selected_kernel_info_ptr
)
{
const
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
&
selected_kernel_build_info
,
MS_EXCEPTION_IF_NULL
(
selected_kernel_info_ptr
);
bool
precision_reduce
)
{
MS_EXCEPTION_IF_NULL
(
selected_kernel_build_info
);
MS_EXCEPTION_IF_NULL
(
cnode
);
std
::
ostringstream
buffer
;
buffer
<<
cnode
->
DebugString
();
if
(
precision_reduce
)
{
if
(
precision_reduce
)
{
std
::
ostringstream
datatype
;
buffer
<<
" reduce precision, node datatype: "
;
size_t
input_num
=
selected_kernel_info_ptr
->
GetInputNum
();
}
else
{
size_t
i
=
0
;
buffer
<<
" raise precision, node datatype: "
;
datatype
<<
"("
;
for
(;
i
<
input_num
&&
i
<
node_datatype
.
size
();
++
i
)
{
datatype
<<
static_cast
<
int
>
(
node_datatype
[
i
]);
if
(
i
<
input_num
-
1
)
{
datatype
<<
", "
;
}
}
datatype
<<
") -> ("
;
for
(;
i
<
node_datatype
.
size
();
++
i
)
{
datatype
<<
static_cast
<
int
>
(
node_datatype
[
i
]);
if
(
i
<
node_datatype
.
size
()
-
1
)
{
datatype
<<
", "
;
}
}
datatype
<<
")"
;
MS_LOG
(
WARNING
)
<<
kernel_node
->
DebugString
()
<<
" reduce precision, node datatype: "
<<
datatype
.
str
()
<<
", select kernel: %s"
<<
selected_kernel_info_ptr
->
ToString
();
}
}
AnfAlgo
::
SetSelectKernelBuildInfo
(
selected_kernel_info_ptr
,
kernel_node
.
get
()
);
PrintInputAndOutputInferType
(
buffer
,
cnode
);
// Set format and data type for input tensor.
buffer
<<
", select kernel:"
<<
selected_kernel_build_info
->
ToString
();
SetTensorDeviceInfo
(
*
selected_kernel_info_ptr
,
kernel_node
);
MS_LOG
(
INFO
)
<<
buffer
.
str
(
);
}
}
}
// namespace
void
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
ChooseMatchedKernelInfo
(
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
const
CNodePtr
&
kernel_node
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
if
(
kernel_info_list
.
empty
())
{
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
return
nullptr
;
}
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
};
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
};
int
selected_index
=
-
1
;
size_t
selected_index
=
0
;
std
::
unordered_map
<
size_t
,
std
::
vector
<
int
>>
kernel_match_datatype_idx
;
std
::
unordered_map
<
size_t
,
std
::
vector
<
TypeId
>>
kernel_support_datatype
;
std
::
vector
<
int
>
node_mix_precision_datatype_index
;
std
::
vector
<
TypeId
>
node_mix_precision_datatype
;
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
};
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
};
auto
kernel_build_info
=
*
(
kernel_info_list
[
info_index
]);
auto
kernel_build_info
=
*
(
kernel_info_list
[
info_index
]);
std
::
vector
<
int
>
support_indexes
;
std
::
vector
<
TypeId
>
support_datatypes
;
AddNodeAndKernelDataType
(
kernel_node
,
kernel_build_info
,
&
support_indexes
,
&
node_mix_precision_datatype
,
&
support_datatypes
,
&
node_mix_precision_datatype_index
);
kernel_match_datatype_idx
[
info_index
]
=
support_indexes
;
kernel_support_datatype
[
info_index
]
=
support_datatypes
;
if
(
!
MatchInferOutputDataType
(
kernel_node
,
kernel_build_info
))
{
continue
;
}
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
kernel_info_ptr
=
kernel_info_list
[
info_index
];
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
kernel_info_ptr
=
kernel_info_list
[
info_index
];
UpdateCurMatchCounts
(
*
kernel_info_ptr
,
kernel_node
,
&
cur_kernel_info_match_counts
);
UpdateCurMatchCounts
(
*
kernel_info_ptr
,
kernel_node
,
&
cur_kernel_info_match_counts
);
// Currently the selection policy is the match format count first, and then is datatype counts.
// Currently the selection policy is the match format count first, and then is datatype counts.
...
@@ -495,22 +422,77 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
...
@@ -495,22 +422,77 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
selected_index
=
SizeToInt
(
info_index
);
selected_index
=
SizeToInt
(
info_index
);
}
}
}
}
return
kernel_info_list
[
selected_index
];
}
bool
precision_reduce
=
false
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
GetAllMatchedFilteredKernelInfo
(
if
(
selected_index
==
-
1
)
{
const
CNodePtr
&
cnode
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
)
{
selected_index
=
PrecisionReduce
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
result
;
kernel_support_datatype
,
&
kernel_match_datatype_idx
,
&
precision_reduce
);
for
(
const
auto
&
kernel_build_info
:
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info
);
if
(
!
MatchInferOutputDataType
(
cnode
,
*
kernel_build_info
))
{
continue
;
}
result
.
push_back
(
kernel_build_info
);
}
}
if
(
selected_index
==
-
1
)
{
return
result
;
MS_LOG
(
EXCEPTION
)
<<
kernel_node
->
DebugString
()
<<
"Cannot find valid kernel Info !"
;
}
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
FilterRaisedOrReducePrecisionMatchedKernelInfo
(
const
CNodePtr
&
cnode
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
,
bool
*
precision_reduce
)
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
filtered_kernel_info_list
;
std
::
map
<
size_t
,
std
::
vector
<
int
>>
kernel_match_datatype_idx
;
std
::
map
<
size_t
,
std
::
vector
<
TypeId
>>
kernel_support_datatype
;
std
::
vector
<
int
>
node_mix_precision_datatype_index
;
std
::
vector
<
TypeId
>
node_mix_precision_datatype
;
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
std
::
vector
<
int
>
support_indexes
;
std
::
vector
<
TypeId
>
support_datatypes
;
MS_EXCEPTION_IF_NULL
(
kernel_info_list
[
info_index
]);
AddNodeAndKernelDataType
(
cnode
,
*
kernel_info_list
[
info_index
],
&
support_indexes
,
&
node_mix_precision_datatype
,
&
support_datatypes
,
&
node_mix_precision_datatype_index
);
kernel_match_datatype_idx
[
info_index
]
=
support_indexes
;
kernel_support_datatype
[
info_index
]
=
support_datatypes
;
}
}
auto
index
=
IntToSize
(
selected_index
);
PrecisionReduce
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatype
,
if
(
index
>=
kernel_info_list
.
size
())
{
&
kernel_match_datatype_idx
,
precision_reduce
);
MS_LOG
(
EXCEPTION
)
<<
"index outof range"
;
std
::
transform
(
kernel_match_datatype_idx
.
begin
(),
kernel_match_datatype_idx
.
end
(),
std
::
back_inserter
(
filtered_kernel_info_list
),
[
&
](
const
std
::
pair
<
size_t
,
std
::
vector
<
int
>>
&
matched_idx
)
->
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
{
return
kernel_info_list
[
matched_idx
.
first
];
});
return
filtered_kernel_info_list
;
}
}
// namespace
void
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
MS_EXCEPTION_IF_NULL
(
kernel_node
);
bool
precision_reduce
=
false
;
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
selected_kernel_info
=
nullptr
;
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
// filter kernel info matched with me infered type
auto
filtered_kernel_info_list
=
GetAllMatchedFilteredKernelInfo
(
kernel_node
,
kernel_info_list
);
if
(
!
filtered_kernel_info_list
.
empty
())
{
selected_kernel_info
=
ChooseMatchedKernelInfo
(
kernel_node
,
filtered_kernel_info_list
);
}
else
{
// selected kernel info using raised precision or reduce precision
filtered_kernel_info_list
=
FilterRaisedOrReducePrecisionMatchedKernelInfo
(
kernel_node
,
kernel_info_list
,
&
precision_reduce
);
selected_kernel_info
=
ChooseMatchedKernelInfo
(
kernel_node
,
filtered_kernel_info_list
);
if
(
selected_kernel_info
==
nullptr
)
{
std
::
ostringstream
buffer
;
PrintInputAndOutputInferType
(
buffer
,
kernel_node
);
MS_LOG
(
EXCEPTION
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid kernel info, not supported the type"
<<
buffer
.
str
();
}
else
{
PrintRaiseOrReducePrecisionSelectedInfo
(
kernel_node
,
selected_kernel_info
,
precision_reduce
);
}
}
}
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
selected_kernel_info_ptr
=
kernel_info_list
[
index
]
;
AnfAlgo
::
SetSelectKernelBuildInfo
(
selected_kernel_info
,
kernel_node
.
get
())
;
MS_EXCEPTION_IF_NULL
(
selected_kernel_info_ptr
);
// Set format and data type for input tensor.
Se
lectKernel
(
kernel_node
,
precision_reduce
,
node_mix_precision_datatype
,
selected_kernel_info_ptr
);
Se
tTensorDeviceInfo
(
*
selected_kernel_info
,
kernel_node
);
}
}
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
...
...
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
a04e8486
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_build_info.h"
#include <algorithm>
#include <algorithm>
#include "utils/log_adapter.h"
#include "utils/log_adapter.h"
#include "debug/anf_ir_dump.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
std
::
string
KernelBuildInfo
::
GetInputFormat
(
size_t
input_index
)
const
{
std
::
string
KernelBuildInfo
::
GetInputFormat
(
size_t
input_index
)
const
{
...
@@ -82,14 +83,14 @@ std::string KernelBuildInfo::ToString() const {
...
@@ -82,14 +83,14 @@ std::string KernelBuildInfo::ToString() const {
if
(
index
!=
0
)
{
if
(
index
!=
0
)
{
output_buffer
<<
", "
;
output_buffer
<<
", "
;
}
}
output_buffer
<<
"<"
<<
static_cast
<
int
>
(
GetInputDeviceType
(
index
))
<<
"x"
<<
GetInputFormat
(
index
)
<<
">"
;
output_buffer
<<
"<"
<<
ToShortString
(
GetInputDeviceType
(
index
))
<<
"x"
<<
GetInputFormat
(
index
)
<<
">"
;
}
}
output_buffer
<<
") -> ("
;
output_buffer
<<
") -> ("
;
for
(
size_t
index
=
0
;
index
<
GetOutputNum
();
++
index
)
{
for
(
size_t
index
=
0
;
index
<
GetOutputNum
();
++
index
)
{
if
(
index
!=
0
)
{
if
(
index
!=
0
)
{
output_buffer
<<
", "
;
output_buffer
<<
", "
;
}
}
output_buffer
<<
"<"
<<
static_cast
<
int
>
(
GetOutputDeviceType
(
index
))
<<
"x"
<<
GetOutputFormat
(
index
)
<<
">"
;
output_buffer
<<
"<"
<<
ToShortString
(
GetOutputDeviceType
(
index
))
<<
"x"
<<
GetOutputFormat
(
index
)
<<
">"
;
}
}
output_buffer
<<
")"
;
output_buffer
<<
")"
;
return
output_buffer
.
str
();
return
output_buffer
.
str
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录