Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
691b0648
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
691b0648
编写于
5月 12, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert unsupported kernel in aicore to aicpu
上级
7ab3f5c3
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
266 addition
and
471 deletion
+266
-471
mindspore/ccsrc/common/trans.cc
mindspore/ccsrc/common/trans.cc
+42
-9
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+26
-33
mindspore/ccsrc/device/ascend/kernel_select_ascend.h
mindspore/ccsrc/device/ascend/kernel_select_ascend.h
+7
-2
mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc
mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc
+1
-1
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+39
-7
mindspore/ccsrc/kernel/kernel_query.h
mindspore/ccsrc/kernel/kernel_query.h
+3
-1
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+0
-5
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+1
-0
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
+7
-6
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
...end/format_type/convert_unsupported_transnode_to_aicpu.cc
+54
-0
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h
...cend/format_type/convert_unsupported_transnode_to_aicpu.h
+37
-0
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h
...c/pre_activate/ascend/format_type/insert_cast_for_runop.h
+3
-3
mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h
..._activate/ascend/format_type/insert_transdata_for_runop.h
+3
-3
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
+1
-1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
...e_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
+1
-1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h
...re_activate/ascend/ir_fusion/transpose_transdata_fusion.h
+4
-2
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+2
-2
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+17
-2
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+5
-15
tests/ut/cpp/device/ascend_kernel_select_test.cc
tests/ut/cpp/device/ascend_kernel_select_test.cc
+0
-345
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
.../ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
+1
-1
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
...ivate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
+10
-32
未找到文件。
mindspore/ccsrc/common/trans.cc
浏览文件 @
691b0648
...
...
@@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
} while (0)
template
<
typename
T
>
T
Ceil
(
T
n1
,
T
n2
)
{
T
Div
Ceil
(
T
n1
,
T
n2
)
{
return
(
n2
!=
0
)
?
(
n1
-
1
)
/
n2
+
1
:
0
;
}
...
...
@@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
device_shape
.
push_back
(
kCubeSize
);
return
device_shape
;
}
std
::
vector
<
size_t
>
FracZc04DeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
size_t
c0
=
4
;
size_t
first_dim
=
DivCeil
(
c0
*
shape
[
2
]
*
shape
[
3
],
kCubeSize
);
size_t
no
=
DivCeil
(
DivCeil
(
shape
[
0
],
kCubeSize
)
*
kCubeSize
,
kCubeSize
);
device_shape
.
push_back
(
first_dim
);
device_shape
.
push_back
(
no
);
device_shape
.
push_back
(
kCubeSize
);
device_shape
.
push_back
(
kCubeSize
);
return
device_shape
;
}
std
::
vector
<
size_t
>
Nc1hwc04DeviceShape
(
const
std
::
vector
<
size_t
>
&
shape
)
{
if
(
!
CheckDims
(
shape
))
{
MS_LOG
(
EXCEPTION
)
<<
"Check dims failed."
;
}
std
::
vector
<
size_t
>
device_shape
;
size_t
C1
=
1
;
size_t
C0
=
4
;
device_shape
.
push_back
(
shape
[
0
]);
device_shape
.
push_back
(
C1
);
device_shape
.
push_back
(
shape
[
2
]);
device_shape
.
push_back
(
shape
[
3
]);
device_shape
.
push_back
(
C0
);
return
device_shape
;
}
}
// namespace
std
::
vector
<
size_t
>
TransShapeToDevice
(
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
string
&
format
)
{
using
DeviceShapeTransfer
=
std
::
function
<
std
::
vector
<
size_t
>
(
const
std
::
vector
<
size_t
>
&
)
>
;
const
std
::
map
<
std
::
string
,
DeviceShapeTransfer
>
device_shape_map
{
{
kOpFormat_NCHW
,
NchwDeviceShape
},
{
kOpFormat_NHWC
,
NhwcDeviceShape
},
{
kOpFormat_HWCN
,
HwchDeviceShape
},
{
kOpFormat_FRAC_Z
,
FracZDeviceShape
},
{
kOpFormat_NC1HWC0
,
Nc1hwc0DeviceShape
},
{
kOpFormat_C1HWNCoC0
,
C1hwncoc0DeviceShape
},
};
const
std
::
map
<
std
::
string
,
DeviceShapeTransfer
>
device_shape_map
{{
kOpFormat_NCHW
,
NchwDeviceShape
},
{
kOpFormat_NHWC
,
NhwcDeviceShape
},
{
kOpFormat_HWCN
,
HwchDeviceShape
},
{
kOpFormat_FRAC_Z
,
FracZDeviceShape
},
{
kOpFormat_NC1HWC0
,
Nc1hwc0DeviceShape
},
{
kOpFormat_C1HWNCoC0
,
C1hwncoc0DeviceShape
},
{
kOpFormat_FRACTAL_Z_C04
,
FracZc04DeviceShape
},
{
kOpFormat_NC1HWC0_C04
,
Nc1hwc04DeviceShape
}};
if
(
format
==
kOpFormat_ND
||
format
==
kOpFormat_DEFAULT
)
{
return
shape
;
...
...
@@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
MS_LOG
(
ERROR
)
<<
"Illegal dtype."
;
return
false
;
}
size_t
c1
=
Ceil
(
c
,
c0
);
size_t
c1
=
Div
Ceil
(
c
,
c0
);
size_t
hw
=
h
*
w
;
size_t
chw
=
c
*
hw
;
size_t
hwc0
=
hw
*
c0
;
size_t
nchw
=
n
*
chw
;
size_t
hf_cnt
=
Ceil
(
n
,
kCubeSize
);
size_t
hf_cnt
=
Div
Ceil
(
n
,
kCubeSize
);
size_t
vf_cnt
=
c1
*
hw
;
size_t
fractal_ele_cnt
=
c0
*
kCubeSize
;
size_t
total_ele_cnt
=
hf_cnt
*
vf_cnt
*
fractal_ele_cnt
;
...
...
@@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
MS_LOG
(
ERROR
)
<<
"Illegal dtype."
;
return
false
;
}
size_t
c1
=
Ceil
(
c
,
c0
);
size_t
c1
=
Div
Ceil
(
c
,
c0
);
size_t
hw
=
h
*
w
;
size_t
chw
=
c
*
hw
;
size_t
c1hwc0
=
c1
*
hw
*
c0
;
...
...
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
691b0648
...
...
@@ -34,6 +34,7 @@ namespace ascend {
namespace
{
const
float
kWegihtBaseScore
=
1
;
const
float
kFeatureMapBaseScore
=
10
;
constexpr
auto
kPriChoosenFormat
=
"pri_format"
;
enum
MatchCountPriority
:
int
{
MATCH_COUNT_PRIORITY_BEGIN
=
0
,
MATCH_DTYPE_COUNT
=
MATCH_COUNT_PRIORITY_BEGIN
,
...
...
@@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
if
(
need_change_nd
)
{
priority_matched_format
=
kOpFormat_DEFAULT
;
}
AnfAlgo
::
SetNodeAttr
(
kPriChoosenFormat
,
MakeValue
(
priority_matched_format
),
cnode
);
return
priority_matched_format
;
}
/**
...
...
@@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
std
::
ostringstream
buffer
;
buffer
<<
cnode
->
DebugString
();
if
(
precision_reduce
)
{
buffer
<<
" reduce precision, node datatype: "
;
buffer
<<
" reduce precision, node datatype:
\n
"
;
}
else
{
buffer
<<
" raise precision, node datatype: "
;
buffer
<<
" raise precision, node datatype:
\n
"
;
}
PrintInputAndOutputInferType
(
buffer
,
cnode
);
buffer
<<
", select kernel:"
<<
selected_kernel_build_info
->
ToString
();
...
...
@@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
}
// namespace
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
CanHitKernelInfo
(
int
*
status
,
const
CNodePtr
&
kernel_node
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
)
{
KernelSelectStatus
SetMatchedKernelInfo
(
const
CNodePtr
&
kernel_node
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
KernelSelectStatus
select_status
=
kNoMatched
;
bool
precision_reduce
=
false
;
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
selected_kernel_info
=
nullptr
;
// Matched kernel info
// Filter kernel info matched with me infered type
auto
filtered_kernel_info_list
=
GetAllMatchedFilteredKernelInfo
(
kernel_node
,
kernel_info_list
);
if
(
!
filtered_kernel_info_list
.
empty
())
{
selected_kernel_info
=
ChooseMatchedKernelInfo
(
kernel_node
,
filtered_kernel_info_list
);
select_status
=
kStatusAllMatched
;
}
else
{
// selected kernel info using raised precision or reduce precision
filtered_kernel_info_list
=
FilterRaisedOrReducePrecisionMatchedKernelInfo
(
kernel_node
,
kernel_info_list
,
&
precision_reduce
);
selected_kernel_info
=
ChooseMatchedKernelInfo
(
kernel_node
,
filtered_kernel_info_list
);
if
(
selected_kernel_info
==
nullptr
)
{
return
nullptr
;
return
select_status
;
}
else
{
PrintRaiseOrReducePrecisionSelectedInfo
(
kernel_node
,
selected_kernel_info
,
precision_reduce
);
*
status
=
precision_reduce
?
kStatusReducePrecision
:
kStatusRaisePrecision
;
select_
status
=
precision_reduce
?
kStatusReducePrecision
:
kStatusRaisePrecision
;
}
}
return
selected_kernel_info
;
// Set kernel info to the anfnode
AnfAlgo
::
SetSelectKernelBuildInfo
(
selected_kernel_info
,
kernel_node
.
get
());
// Set format and data type for input tensor.
SetTensorDeviceInfo
(
*
selected_kernel_info
,
kernel_node
);
return
select_status
;
}
int
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
KernelSelectStatus
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
int
status
=
kStatusAllMatched
;
MS_EXCEPTION_IF_NULL
(
kernel_node
);
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
// filter kernel info matched with me infered type
auto
selected_kernel_info
=
CanHitKernelInfo
(
&
status
,
kernel_node
,
kernel_info_list
);
if
(
select
ed_kernel_info
==
nullptr
)
{
auto
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
if
(
select
_status
==
kNoMatched
)
{
MS_LOG
(
WARNING
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
kernel
::
A
ic
puQuery
(
kernel_node
,
&
kernel_info_list
);
select
ed_kernel_info
=
CanHitKernelInfo
(
&
status
,
kernel_node
,
kernel_info_list
);
kernel
::
A
IC
puQuery
(
kernel_node
,
&
kernel_info_list
);
select
_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
}
if
(
selected_kernel_info
==
nullptr
)
{
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
if
(
select_status
==
kNoMatched
)
{
std
::
ostringstream
buffer
;
PrintInputAndOutputInferType
(
buffer
,
kernel_node
);
MS_EXCEPTION
(
TypeError
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid kernel info, not supported the type "
<<
buffer
.
str
();
}
AnfAlgo
::
SetSelectKernelBuildInfo
(
selected_kernel_info
,
kernel_node
.
get
());
// Set format and data type for input tensor.
SetTensorDeviceInfo
(
*
selected_kernel_info
,
kernel_node
);
return
status
;
}
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfoPtr
&
new_kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
auto
result
=
std
::
find_if
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
new_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
return
*
item
==
*
new_kernel_build_info
;
});
return
result
!=
kernel_info_list
.
end
();
return
select_status
;
}
}
// namespace ascend
}
// namespace device
...
...
mindspore/ccsrc/device/ascend/kernel_select_ascend.h
浏览文件 @
691b0648
...
...
@@ -21,8 +21,13 @@
namespace
mindspore
{
namespace
device
{
namespace
ascend
{
int
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
);
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfoPtr
&
new_kernel_build_info
);
enum
KernelSelectStatus
{
kNoMatched
=
-
1
,
kStatusAllMatched
=
0
,
kStatusReducePrecision
=
1
,
kStatusRaisePrecision
=
2
,
};
KernelSelectStatus
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
);
}
// namespace ascend
}
// namespace device
}
// namespace mindspore
...
...
mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc
浏览文件 @
691b0648
...
...
@@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std
::
vector
<
std
::
string
>
input_format
,
output_format
;
std
::
vector
<
TypeId
>
input_type
,
output_type
;
for
(
const
auto
&
data_type
:
data_type_list
)
{
for
(
const
auto
&
format
:
k
4DSupportForma
t
)
{
for
(
const
auto
&
format
:
k
OpFormatLis
t
)
{
auto
builder
=
std
::
make_shared
<
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
input_format
.
clear
();
input_format
.
push_back
(
format
);
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
691b0648
...
...
@@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
return
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
)
==
kernel_build_info
->
GetOutputNum
()
&&
AnfAlgo
::
GetInputTensorNum
(
kernel_node
)
==
kernel_build_info
->
GetInputNum
();
});
kernel_info_list
->
clear
();
if
(
!
filtered_list
.
empty
())
{
kernel_info_list
->
clear
();
(
void
)
std
::
copy
(
filtered_list
.
begin
(),
filtered_list
.
end
(),
std
::
back_inserter
(
*
kernel_info_list
));
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"node"
<<
kernel_node
->
DebugString
()
<<
"'s output size : ["
<<
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
)
<<
"]"
<<
"input size : ["
<<
AnfAlgo
::
GetInputTensorNum
(
kernel_node
)
<<
"] cannot match any kernelInfo !"
;
MS_LOG
(
WARNING
)
<<
"All kernel Info list does not match any kernel info "
;
for
(
size_t
index
;
index
<
kernel_info_list
->
size
();
++
index
)
{
MS_EXCEPTION_IF_NULL
(
kernel_info_list
->
at
(
index
));
MS_LOG
(
WARNING
)
<<
"kernel [ "
<<
index
<<
" ] :"
<<
kernel_info_list
->
at
(
index
)
->
ToString
();
}
MS_LOG
(
WARNING
)
<<
"node"
<<
kernel_node
->
DebugString
()
<<
"'s output size : ["
<<
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
)
<<
"]"
<<
"input size : ["
<<
AnfAlgo
::
GetInputTensorNum
(
kernel_node
)
<<
"] cannot match any kernelInfo !"
;
}
}
}
// namespace
...
...
@@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
TbeMetadataInfo
(
kernel_node
,
kernel_info_list
);
if
(
kernel_info_list
->
empty
())
{
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
}
...
...
@@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
}
void
A
ic
puQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
void
A
IC
puQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
kernel_info_list
->
clear
();
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
}
bool
IsSupportedByAiCpu
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
auto
cnode
=
kernel_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
AicpuMetadataInfo
(
cnode
,
&
kernel_info_list
);
FilterInvalidKernelInfo
(
cnode
,
&
kernel_info_list
);
return
std
::
any_of
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
select_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
return
*
item
==
*
select_kernel_build_info
;
});
}
bool
IsSupportedByAiCore
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
auto
cnode
=
kernel_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
TbeMetadataInfo
(
cnode
,
&
kernel_info_list
);
FilterInvalidKernelInfo
(
cnode
,
&
kernel_info_list
);
return
std
::
any_of
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
select_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
return
*
item
==
*
select_kernel_build_info
;
});
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/kernel_query.h
浏览文件 @
691b0648
...
...
@@ -26,7 +26,9 @@
namespace
mindspore
{
namespace
kernel
{
void
KernelQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
void
AicpuQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
void
AICpuQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
bool
IsSupportedByAiCpu
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
bool
IsSupportedByAiCore
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
浏览文件 @
691b0648
...
...
@@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
}
bool
IsShapeMatchFormat
(
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
string
&
format
)
{
const
std
::
set
<
std
::
string
>
kOpFormatList
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_Z
,
kOpFormat_C1HWNCoC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
// if format is default, it remarkes support all format
if
(
kOpFormatList
.
find
(
format
)
==
kOpFormatList
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Got the unknown format "
<<
format
;
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
691b0648
...
...
@@ -54,6 +54,7 @@
#include "pre_activate/pass/optimize_dependence.h"
#include "pre_activate/pass/erase_visit_attr.h"
#include "pre_activate/ascend/format_type/insert_cast.h"
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "pre_activate/pass/eliminate_redundant_op.h"
#include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
...
...
@@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
MergeCastToOp
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
LayerNormBetaGammaBackpropFusion
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
mixed_precision_pm
->
AddPass
(
std
::
make_shared
<
ConvertUnSupportNodeToAICPU
>
());
optimizer
->
AddPassManager
(
mixed_precision_pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
691b0648
...
...
@@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
}
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
cast
.
get
());
AnfAlgo
::
SetOutputInferTypeAndShape
({
origin_type
},
{
origin_shape
},
cast
.
get
());
AnfAlgo
::
SetNodeAttr
(
kIsBackendCast
,
MakeValue
(
true
),
cast
);
return
cast
;
}
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
浏览文件 @
691b0648
...
...
@@ -30,10 +30,6 @@ class KernelSelect {
KernelSelect
()
=
default
;
virtual
~
KernelSelect
()
=
default
;
virtual
void
SelectKernel
(
const
CNodePtr
&
cnode
)
{
device
::
ascend
::
SelectKernelInfo
(
cnode
);
}
virtual
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfoPtr
&
new_kernel_build_info
)
{
return
device
::
ascend
::
CheckKernelAccuracySupported
(
kernel_node
,
new_kernel_build_info
);
}
};
using
KernelSelectPtr
=
std
::
shared_ptr
<
KernelSelect
>
;
...
...
@@ -41,8 +37,13 @@ class SupportedChecker {
public:
SupportedChecker
()
=
default
;
virtual
~
SupportedChecker
()
=
default
;
virtual
bool
CheckSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
return
kernel
::
CheckSupported
(
anf_node
,
select_kernel_build_info
);
virtual
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
return
kernel
::
IsSupportedByAiCore
(
anf_node
,
select_kernel_build_info
);
}
virtual
bool
CheckAiCpuSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
return
kernel
::
IsSupportedByAiCpu
(
anf_node
,
select_kernel_build_info
);
}
};
using
SupportedCheckerPtr
=
std
::
shared_ptr
<
SupportedChecker
>
;
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
0 → 100644
浏览文件 @
691b0648
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_query.h"
namespace
mindspore
{
namespace
opt
{
const
BaseRef
ConvertUnSupportNodeToAICPU
::
DefinePattern
()
const
{
VarPtr
X
=
std
::
make_shared
<
Var
>
();
VarPtr
Xs
=
std
::
make_shared
<
SeqVar
>
();
return
VectorRef
({
X
,
Xs
});
}
const
AnfNodePtr
ConvertUnSupportNodeToAICPU
::
Process
(
const
mindspore
::
FuncGraphPtr
&
,
const
mindspore
::
AnfNodePtr
&
node
,
const
mindspore
::
EquivPtr
&
)
const
{
if
(
node
==
nullptr
||
!
node
->
isa
<
CNode
>
())
{
return
nullptr
;
}
auto
node_name
=
AnfAlgo
::
GetCNodeName
(
node
);
if
(
node_name
!=
prim
::
KPrimTransData
->
name
()
||
node_name
!=
prim
::
kPrimCast
->
name
())
{
return
nullptr
;
}
auto
kernel_builder_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
node
);
if
(
supported_checker_
->
CheckAiCoreSupported
(
node
,
kernel_builder_info
))
{
return
node
;
}
else
if
(
supported_checker_
->
CheckAiCpuSupported
(
node
,
kernel_builder_info
))
{
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
kernel_builder_info
);
builder
->
SetKernelType
(
AICPU_KERNEL
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
node
.
get
());
}
else
{
MS_LOG
(
EXCEPTION
)
<<
" kernel "
<<
kernel_builder_info
->
ToString
()
<<
"is not supported in AiCPU & AiCore : node ["
<<
node
->
DebugString
()
<<
"]"
;
}
return
node
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h
0 → 100644
浏览文件 @
691b0648
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
namespace
mindspore
{
namespace
opt
{
class
ConvertUnSupportNodeToAICPU
:
public
PatternProcessPass
{
public:
explicit
ConvertUnSupportNodeToAICPU
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"convert_unsupported_node_to_aicpu"
,
multigraph
),
supported_checker_
(
std
::
make_shared
<
SupportedChecker
>
())
{}
~
ConvertUnSupportNodeToAICPU
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
SupportedCheckerPtr
supported_checker_
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h
浏览文件 @
691b0648
...
...
@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_CAST_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_CAST_FOR_RUNOP_H_
#include <string>
#include "pre_activate/common/optimizer.h"
...
...
@@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass {
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_CAST_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_CAST_FOR_RUNOP_H_
mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h
浏览文件 @
691b0648
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_TRANSDATA_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_TRANSDATA_FOR_RUNOP_H_
#include <string>
#include <utility>
...
...
@@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass {
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_
DEVICE_OPTIMIZER_FORMAT_TYPE_PASS
_INSERT_TRANSDATA_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_
PRE_ACTIVATE_ASCEND_FORMAT_TYPE
_INSERT_TRANSDATA_FOR_RUNOP_H_
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
浏览文件 @
691b0648
...
...
@@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto
indices_const
=
CreateValueNode
(
new_cnode
);
new_cnode
->
add_input
(
indices_const
);
MS_EXCEPTION_IF_NULL
(
supported_checker_
);
if
(
!
supported_checker_
->
CheckSupported
(
new_cnode
,
CreateKernelBuildInfo
()))
{
if
(
!
supported_checker_
->
Check
AiCore
Supported
(
new_cnode
,
CreateKernelBuildInfo
()))
{
return
nullptr
;
}
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
浏览文件 @
691b0648
...
...
@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder
->
SetProcessor
(
transdata_kernel_build_info
->
processor
());
auto
new_fusion_transdata
=
std
::
make_shared
<
Primitive
>
(
kTransDataOpName
);
if
(
kernel_select_
->
CheckKernelAccuracy
Supported
(
transdata_cnode
,
new_transdata_builder
->
Build
()))
{
if
(
supported_checker_
->
CheckAiCore
Supported
(
transdata_cnode
,
new_transdata_builder
->
Build
()))
{
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
new_fusion_transdata
),
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input_varptr_
])};
auto
new_node
=
func_graph
->
NewCNode
(
inputs
);
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h
浏览文件 @
691b0648
...
...
@@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass {
explicit
TransposeTransDataFusion
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"transpose_transdata_fusion"
,
multigraph
)
{
input_varptr_
=
std
::
make_shared
<
Var
>
();
kernel_select_
=
std
::
make_shared
<
KernelSelect
>
();
supported_checker_
=
std
::
make_shared
<
SupportedChecker
>
();
}
~
TransposeTransDataFusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
...
...
@@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass {
private:
VarPtr
input_varptr_
;
KernelSelectPtr
kernel_select_
;
private:
SupportedCheckerPtr
supported_checker_
;
};
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
691b0648
...
...
@@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
size_t
reduce_precision_count
=
0
;
for
(
const
auto
&
cnode
:
kernel_graph
.
execution_order
())
{
auto
status
=
device
::
ascend
::
SelectKernelInfo
(
cnode
);
if
(
status
==
kStatusRaisePrecision
)
{
if
(
status
==
device
::
ascend
::
kStatusRaisePrecision
)
{
raise_precision_count
++
;
}
else
if
(
status
==
kStatusReducePrecision
)
{
}
else
if
(
status
==
device
::
ascend
::
kStatusReducePrecision
)
{
reduce_precision_count
++
;
}
MS_LOG
(
INFO
)
<<
"Select ApplyKernel: "
<<
cnode
->
DebugString
();
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
691b0648
...
...
@@ -27,6 +27,8 @@
namespace
mindspore
{
namespace
session
{
namespace
{
constexpr
auto
kIsFeatureMapOutput
=
"IsFeatureMapOutput"
;
constexpr
auto
kIsFeatureMapInputList
=
"IsFeatureMapInputList"
;
void
PushNoVisitedNode
(
const
AnfNodePtr
&
node
,
std
::
queue
<
AnfNodePtr
>
*
que
,
std
::
unordered_set
<
AnfNodePtr
>
*
visited_nodes
)
{
MS_EXCEPTION_IF_NULL
(
que
);
...
...
@@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
cnode
->
set_abstract
(
std
::
make_shared
<
abstract
::
AbstractNone
>
());
// create kernel_info from new parameter
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
std
::
vector
<
size_t
>
feature_map_input_indexs
;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
if
(
inputs
.
size
()
==
1
||
std
::
any_of
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
[
&
](
const
AnfNodePtr
&
node
)
{
return
AnfAlgo
::
IsFeatureMapOutput
(
node
);
}))
{
for
(
size_t
index
=
1
;
index
<
inputs
.
size
();
++
index
)
{
auto
node
=
inputs
[
index
];
if
(
AnfAlgo
::
IsFeatureMapOutput
(
node
))
{
feature_map_input_indexs
.
push_back
(
index
);
}
}
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
==
prim
::
kPrimCast
->
name
())
{
AnfAlgo
::
SetNodeAttr
(
kIsBackendCast
,
MakeValue
(
false
),
cnode
);
}
if
(
inputs
.
size
()
==
1
||
!
feature_map_input_indexs
.
empty
())
{
kernel_info
->
SetFeatureMapFlag
(
true
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
true
),
cnode
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
cnode
);
}
else
{
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
false
),
cnode
);
}
cnode
->
set_kernel_info
(
kernel_info
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
cnode
.
get
());
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
691b0648
...
...
@@ -139,6 +139,7 @@ constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2D
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
constexpr
auto
kIsBackendCast
=
"is_backed_cast"
;
constexpr
auto
kAttrOutputNames
=
"output_names"
;
constexpr
auto
kAttrVisited
=
"visited"
;
constexpr
auto
kAttrShape
=
"shape"
;
...
...
@@ -196,10 +197,6 @@ constexpr auto kControlDependBehindIndex = 2;
// index define of depend
constexpr
auto
kRealInputIndexInDepend
=
1
;
constexpr
auto
kDependAttachNodeIndex
=
2
;
// status of kernel select result
const
int
kStatusReducePrecision
=
-
1
;
const
int
kStatusRaisePrecision
=
1
;
const
int
kStatusAllMatched
=
0
;
// format
constexpr
auto
kOpFormat_DEFAULT
=
"DefaultFormat"
;
constexpr
auto
kOpFormat_NC1KHKWHWC0
=
"NC1KHKWHWC0"
;
...
...
@@ -213,18 +210,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr
auto
kOpFormat_C1HWNCoC0
=
"C1HWNCoC0"
;
constexpr
auto
kOpFormat_NC1HWC0_C04
=
"NC1HWC0_C04"
;
constexpr
auto
kOpFormat_FRACTAL_Z_C04
=
"FRACTAL_Z_C04"
;
const
std
::
set
<
std
::
string
>
k1DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_NC1HWC0
,
kOpFormat_C1HWNCoC0
};
const
std
::
set
<
std
::
string
>
k2DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
};
const
std
::
set
<
std
::
string
>
k3DSupportFormat
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
};
const
std
::
set
<
std
::
string
>
k4DSupportFormat
=
k1DSupportFormat
;
const
std
::
vector
<
std
::
set
<
std
::
string
>>
kShapeSupportFormatMap
=
{
k1DSupportFormat
,
k2DSupportFormat
,
k3DSupportFormat
,
k4DSupportFormat
};
const
std
::
set
<
std
::
string
>
kOpFormatList
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_Z
,
kOpFormat_C1HWNCoC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
const
std
::
set
<
std
::
string
>
kDefaultCompatibleFormat
=
{
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
};
const
std
::
set
<
std
::
string
>
kOptOperatorSet
=
{
kMomentumOpName
,
kApplyMomentumOpName
,
kApplyAdadeltaOpName
,
kApplyAdagradOpName
,
kApplyAdagradDAName
,
kApplyAdamOpName
,
...
...
tests/ut/cpp/device/ascend_kernel_select_test.cc
已删除
100644 → 0
浏览文件 @
7ab3f5c3
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h"
#include "common/common_test.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "mindspore/ccsrc/device/kernel_info.h"
#include "mindspore/ccsrc/kernel/kernel_build_info.h"
#include <vector>
namespace
mindspore
{
namespace
device
{
namespace
ascend
{
namespace
{
using
KernelInfo
=
device
::
KernelInfo
;
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
using
KernelBuildInfo
=
kernel
::
KernelBuildInfo
;
using
KernelGraph
=
session
::
KernelGraph
;
using
KernelBuildInfoPtr
=
std
::
shared_ptr
<
KernelBuildInfo
>
;
using
KernelBuilderPtr
=
std
::
shared_ptr
<
KernelBuildInfoBuilder
>
;
using
Shape
=
std
::
vector
<
size_t
>
;
using
ShapeList
=
std
::
vector
<
Shape
>
;
// Priority slots used to score a candidate kernel build info against a node.
// Slot order matters: PriorityChooseItem compares vectors lexicographically,
// so earlier slots dominate later ones.
enum MatchCountPriority {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN,  // input format matches producer's output format
  MATCH_DTYPE_COUNT,                                // input device dtype matches producer's device dtype
  MATCH_NZ_FORMAT_COUNT,                            // input uses FRACTAL_NZ
  MATCH_5D_FORMAT_COUNT,                            // input uses NC1HWC0
  MATCH_OUTPUT_DTYPE_COUNT,                         // output device dtype matches inferred dtype
  MATCH_COUNT_PRIORITY_END                          // sentinel: number of slots
};
// Formats recognized by this test; IsShapeMatchFormat raises ArgumentError for
// anything outside this set. NOTE(review): this file-local copy shadows the
// global kOpFormatList in utils/utils.h — keep the two in sync.
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                             kOpFormat_NCHW,    kOpFormat_NHWC,        kOpFormat_HWCN,
                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,      kOpFormat_C1HWNCoC0,
                                             kOpFormat_FRAC_NZ};
// Returns true when a tensor of the given shape can be laid out in `format`.
// Throws ArgumentError for a format that is not in kOpFormatList at all.
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // Unknown formats are a programming error, not a mismatch.
  if (kOpFormatList.count(format) == 0) {
    MS_EXCEPTION(ArgumentError) << "got the unknow format " << format;
  }
  // The default format is compatible with every shape.
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // An empty shape denotes a scalar, which any format can carry.
  if (shape.empty()) {
    return true;
  }
  const size_t rank = shape.size();
  if (rank > kShapeSupportFormatMap.size()) {
    return false;
  }
  // FRACTAL_NZ is only picked when neither of the two innermost dimensions is
  // already 16-aligned (otherwise plainer formats suffice).
  if (format == kOpFormat_FRAC_NZ && rank >= 2) {
    return shape[rank - 1] % 16 != 0 && shape[rank - 2] % 16 != 0;
  }
  const auto &supported_for_rank = kShapeSupportFormatMap[rank - 1];
  return supported_for_rank.find(format) != supported_for_rank.end();
}
// Checks that every output and input shape of `kernel_node` is representable
// in the format the candidate build info assigns to it. Throws ArgumentError
// if any visited shape contains a zero-sized dimension.
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  // Format must match the shape first; only then are the dimensions validated,
  // mirroring the original check order (mismatched formats never throw).
  auto shape_fits_format = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
    if (!IsShapeMatchFormat(shape, format)) {
      return false;
    }
    for (auto dim : shape) {
      if (dim == 0) {
        MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got ["
                                    << dim << "]";
      }
    }
    return true;
  };
  const size_t output_num = kernel_build_info.GetOutputNum();
  for (size_t i = 0; i < output_num; ++i) {
    if (!shape_fits_format(AnfAlgo::GetOutputInferShape(kernel_node, i), kernel_build_info.GetOutputFormat(i))) {
      return false;
    }
  }
  const size_t input_num = kernel_build_info.GetInputNum();
  for (size_t i = 0; i < input_num; ++i) {
    if (!shape_fits_format(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i), kernel_build_info.GetInputFormat(i))) {
      return false;
    }
  }
  return true;
}
// Returns true when the candidate build info's device dtypes agree with the
// node's dtypes: weights are compared against their already-bound device
// dtype, feature maps against the producer's inferred dtype, and outputs
// against the node's own inferred dtypes.
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Check input data type
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    AnfNodePtr cur_input = cnode->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(cur_input);
    TypeId expected_type = kTypeUnknown;
    const bool is_weight = cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>());
    if (is_weight) {
      // weight: use the device dtype already selected for the parameter
      expected_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
    } else {
      // feature map: use the inferred dtype of the producing node
      expected_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    }
    // An unknown dtype carries no constraint, so it never disqualifies.
    if (expected_type == kTypeUnknown) {
      continue;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) != expected_type) {
      return false;
    }
  }
  // Check output data type
  for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
    if (kernel_build_info.GetOutputDeviceType(output_index) !=
        AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
      return false;
    }
  }
  return true;
}
/**
* compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if
* equal then next num location
* example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
*/
bool
PriorityChooseItem
(
const
std
::
vector
<
int
>
&
cur_item
,
std
::
vector
<
int
>
*
best_item
)
{
MS_EXCEPTION_IF_NULL
(
best_item
);
if
(
cur_item
.
size
()
!=
best_item
->
size
())
{
MS_LOG
(
ERROR
)
<<
"item size should be same!"
;
return
false
;
}
// Update the best_item by comparing the cur_item and best_item
for
(
size_t
i
=
0
;
i
<
cur_item
.
size
();
i
++
)
{
if
(
cur_item
[
i
]
>
best_item
->
at
(
i
))
{
*
best_item
=
cur_item
;
return
true
;
}
else
if
(
cur_item
[
i
]
==
best_item
->
at
(
i
))
{
continue
;
}
else
{
return
false
;
}
}
return
false
;
}
// Scores one candidate build info against `kernel_node`, incrementing the
// MatchCountPriority slots of *cur_kernelinfo_match_counts. The vector must
// already have at least MATCH_COUNT_PRIORITY_END entries.
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  auto &counts = *cur_kernelinfo_match_counts;
  const size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t i = 0; i < input_num; ++i) {
    AnfNodePtr input_node = kernel_node->input(i + 1);
    MS_EXCEPTION_IF_NULL(input_node);
    // A weight parameter whose device dtype is still unknown carries no layout
    // information, so it must not influence the score.
    if (input_node->isa<Parameter>()) {
      auto param = input_node->cast<ParameterPtr>();
      if (AnfAlgo::IsParameterWeight(param) && AnfAlgo::GetOutputDeviceDataType(param, 0) == kTypeUnknown) {
        continue;
      }
    }
    const auto input_format = kernel_build_info.GetInputFormat(i);
    if (input_format == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)) {
      counts[MATCH_FORMAT_COUNT]++;
    }
    if (kernel_build_info.GetInputDeviceType(i) == AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)) {
      counts[MATCH_DTYPE_COUNT]++;
    }
    if (input_format == kOpFormat_FRAC_NZ) {
      counts[MATCH_NZ_FORMAT_COUNT]++;
    }
    if (input_format == kOpFormat_NC1HWC0) {
      counts[MATCH_5D_FORMAT_COUNT]++;
    }
  }
  const size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t i = 0; i < output_num; ++i) {
    // Count matches between the inferred output dtype and the candidate's.
    if (kernel_build_info.GetOutputDeviceType(i) == AnfAlgo::GetOutputInferDataType(kernel_node, i)) {
      counts[MATCH_OUTPUT_DTYPE_COUNT]++;
    }
  }
}
void
SetKernelBuildInfo
(
KernelBuilderPtr
builder
)
{
builder
->
SetFusionType
(
kernel
::
OPAQUE
);
builder
->
SetKernelType
(
AUTO_DIFF_KERNEL
);
builder
->
SetProcessor
(
kernel
::
AICORE
);
}
void
test_select
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
)
{
std
::
vector
<
int
>
most_match_counts
=
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
};
int
selected_index
=
-
1
;
for
(
size_t
info_index
=
0
;
info_index
<
kernel_info_list
.
size
();
++
info_index
)
{
std
::
vector
<
int
>
cur_kernel_info_match_counts
=
{
0
,
0
,
0
,
0
,
0
};
if
(
!
IsValidKernelInfo
(
kernel_node
,
*
(
kernel_info_list
[
info_index
])))
{
continue
;
}
if
(
!
MatchInferOutputDataType
(
kernel_node
,
*
(
kernel_info_list
[
info_index
])))
{
continue
;
}
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
kernel_info_ptr
=
kernel_info_list
[
info_index
];
UpdateCurMatchCounts
(
*
kernel_info_ptr
,
kernel_node
,
&
cur_kernel_info_match_counts
);
// Currently the selection policy is the match format count first, and then is datatype counts.
if
(
PriorityChooseItem
(
cur_kernel_info_match_counts
,
&
most_match_counts
))
{
selected_index
=
SizeToInt
(
info_index
);
}
}
if
(
selected_index
==
-
1
)
{
MS_EXCEPTION
(
NotExistsError
)
<<
""
<<
kernel_node
->
DebugString
()
<<
" Cannot find valid kernel Info !"
;
}
auto
index
=
IntToSize
(
selected_index
);
if
(
index
>=
kernel_info_list
.
size
())
{
MS_EXCEPTION
(
ArgumentError
)
<<
"index outof range"
;
}
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
selected_kernel_info_ptr
=
kernel_info_list
[
index
];
MS_EXCEPTION_IF_NULL
(
selected_kernel_info_ptr
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
selected_kernel_info_ptr
,
kernel_node
.
get
());
}
// Assigns the same inferred output types and shapes to every node in
// `parent_list`.
// Parameters are now taken by const reference (previously by value): all three
// vectors were copied on every call although the function only reads them
// (clang-tidy: performance-unnecessary-value-param). Call sites, including
// ones passing temporaries such as ShapeList{...}, are unaffected.
void SetParentAbstract(const std::vector<AnfNodePtr> &parent_list, const std::vector<std::vector<size_t>> &shapes,
                       const std::vector<TypeId> &types) {
  for (const auto &node : parent_list) {
    AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get());
  }
}
}
// namespace
// Test fixture for the Ascend kernel-selection cases below.
// NOTE(review): "Selct" looks like a typo for "Select", but the identifier is
// kept as-is because TEST_F references the fixture by this exact name.
class AscendKernelSelctTest : public UT::Common {
 public:
  AscendKernelSelctTest() = default;
  // No per-test setup/teardown is needed; each test builds its own graph.
  void SetUp() override {}
  void TearDown() override {}
};
// End-to-end check of the selection policy in test_select():
// scenario 1 expects the ND candidate to win when c's shape is 16-aligned;
// scenario 2 expects per-input formats to follow each producer (5D vs NZ).
TEST_F(AscendKernelSelctTest, TestSelect) {
  std::vector<KernelBuilderPtr> build_list;
  std::vector<TypeId> type_list = {kNumberTypeFloat32};
  // Five builders (indices 0..4), all float32 in/out; 0-1 describe producers
  // a/b, 2-3 are the selectable candidates for c, 4 is unused.
  for (size_t i = 0; i <= 4; ++i) {
    build_list.push_back(std::make_shared<KernelBuildInfoBuilder>());
    SetKernelBuildInfo(build_list[i]);
    build_list[i]->SetInputsDeviceType(type_list);
    build_list[i]->SetOutputsDeviceType(type_list);
  }
  std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT};
  std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ};
  auto anf_graph = std::make_shared<KernelGraph>();
  // 16's multiple should not chose format NZ
  Shape nd_shapes = {2, 32, 224, 224};
  Shape nz_shapes = {3, 3, 5, 5};
  auto add_value = NewValueNode(prim::kPrimTensorAdd);
  auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node};
  auto c_node = anf_graph->NewCNode(parent_list);
  //       a    b
  //        \  /
  //         c
  // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  //        infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // c:     infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}}
  // set a & b's info
  SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  // set abstract c
  AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get());
  // set format of kernel info
  build_list[0]->SetOutputsFormat(nz_fmt);
  build_list[1]->SetOutputsFormat(nz_fmt);
  build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]});
  build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]});
  build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[2]->SetOutputsFormat(nd_fmt);
  build_list[3]->SetOutputsFormat(nz_fmt);
  std::vector<KernelBuildInfoPtr> select_info_list;
  // set select info list
  select_info_list.emplace_back(build_list[2]->Build());
  select_info_list.emplace_back(build_list[3]->Build());
  // set device info for a & b
  AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
  test_select(c_node, select_info_list);
  // c's shape is 16-aligned, so FRAC_NZ is rejected and DEFAULT wins for both inputs.
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT);
  // set a & b's info
  //       a    b
  //        \  /
  //         c
  // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}}
  //    infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  //    infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  // set a & b's info
  SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  // set abstract c
  AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get());
  // set format of kernel info
  build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0});
  build_list[1]->SetOutputsFormat(nz_fmt);
  build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]});
  build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]});
  build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  build_list[2]->SetOutputsFormat(nd_fmt);
  build_list[3]->SetOutputsFormat(nz_fmt);
  // set select info list
  // NOTE(review): select_info_list still holds the two candidates from the
  // first scenario; the new builds are appended, not replacing them. The test
  // passes regardless, but confirm whether a clear() was intended here.
  select_info_list.emplace_back(build_list[2]->Build());
  select_info_list.emplace_back(build_list[3]->Build());
  // set device info for a & b
  AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
  test_select(c_node, select_info_list);
  // Input 0's producer is 5D (no NZ candidate matches it), input 1's is NZ.
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ);
}
}
// namespace ascend
}
// namespace device
}
// namespace mindspore
\ No newline at end of file
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
浏览文件 @
691b0648
...
...
@@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker
()
=
default
;
~
MockSupportedChecker
()
override
=
default
;
bool
CheckSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
override
{
bool
Check
AiCore
Supported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
override
{
return
true
;
}
};
// namespace opt
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
浏览文件 @
691b0648
...
...
@@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon {
UT
::
PyFuncGraphFetcher
get_py_fun_
;
};
class
MockSupportedChecker
:
public
SupportedChecker
{
public:
MockSupportedChecker
()
=
default
;
~
MockSupportedChecker
()
override
=
default
;
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
override
{
return
true
;
}
};
class
MockInsertTransOpKernelSelectTrans4Dto5D
:
public
KernelSelect
{
public:
MockInsertTransOpKernelSelectTrans4Dto5D
()
=
default
;
...
...
@@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
}
};
class
MockTransposeTransdataFusionKernelSelect
:
public
KernelSelect
{
public:
MockTransposeTransdataFusionKernelSelect
()
=
default
;
~
MockTransposeTransdataFusionKernelSelect
()
override
=
default
;
bool
CheckKernelAccuracySupported
(
const
CNodePtr
&
kernel_node
,
const
kernel
::
KernelBuildInfoPtr
&
new_kernel_build_info
)
override
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
kOpFormat_NCHW
});
builder
.
SetOutputsFormat
({
kOpFormat_DEFAULT
});
builder
.
SetInputsDeviceType
({
kNumberTypeFloat16
});
builder
.
SetOutputsDeviceType
({
kNumberTypeFloat16
});
builder
.
SetKernelType
(
KernelType
::
AUTO_DIFF_KERNEL
);
builder
.
SetFusionType
(
kernel
::
FusionType
::
OPAQUE
);
builder
.
SetProcessor
(
kernel
::
Processor
::
AICORE
);
kernel_info_list
.
push_back
(
builder
.
Build
());
MS_LOG
(
INFO
)
<<
"transpose transdata fusion success"
;
MS_LOG
(
INFO
)
<<
"new transdata build info input format:"
<<
new_kernel_build_info
->
GetInputFormat
(
0
)
<<
",outputformat:"
<<
new_kernel_build_info
->
GetOutputFormat
(
0
)
<<
",kerneltype:"
<<
new_kernel_build_info
->
kernel_type
()
<<
",fusiontype:"
<<
new_kernel_build_info
->
fusion_type
()
<<
",process:"
<<
new_kernel_build_info
->
processor
();
auto
result
=
std
::
find_if
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
new_kernel_build_info
](
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
return
*
item
==
*
new_kernel_build_info
;
});
return
result
!=
kernel_info_list
.
end
();
}
};
TEST_F
(
TestHWTransposeTransdataFusion
,
test_transpose_transdata_fusion
)
{
/*
* def before(input0, input1):
...
...
@@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
insert_trans_op_pass
->
kernel_select_
=
std
::
make_shared
<
MockInsertTransOpKernelSelectTrans4Dto5D
>
();
pm
->
AddPass
(
insert_trans_op_pass
);
auto
transpose_transdata_pass
=
std
::
make_shared
<
opt
::
TransposeTransDataFusion
>
();
transpose_transdata_pass
->
kernel_select_
=
std
::
make_shared
<
MockTransposeTransdataFusionKernelSelect
>
();
transpose_transdata_pass
->
supported_checker_
=
std
::
make_shared
<
MockSupportedChecker
>
();
pm
->
AddPass
(
transpose_transdata_pass
);
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
kg
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录