Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
52e97dbb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
52e97dbb
编写于
5月 15, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
using device dtype to create transdata kernel build info
上级
94883f9b
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
60 addition
and
42 deletion
+60
-42
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+1
-1
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+4
-5
mindspore/ccsrc/kernel/kernel_query.h
mindspore/ccsrc/kernel/kernel_query.h
+3
-3
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+18
-8
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
+1
-1
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+10
-19
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
+2
-2
mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
...activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
+1
-1
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+20
-2
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
52e97dbb
...
@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
...
@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
if
(
select_status
==
kNoMatched
)
{
if
(
select_status
==
kNoMatched
)
{
MS_LOG
(
WARNING
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
MS_LOG
(
WARNING
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
kernel
::
AIC
pu
Query
(
kernel_node
,
&
kernel_info_list
);
kernel
::
AIC
PU
Query
(
kernel_node
,
&
kernel_info_list
);
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
}
}
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
52e97dbb
...
@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
...
@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
}
}
void
AIC
pu
Query
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
void
AIC
PU
Query
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
kernel_info_list
->
clear
();
kernel_info_list
->
clear
();
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
}
}
bool
IsSupportedByA
iCpu
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
bool
IsSupportedByA
ICPU
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
auto
cnode
=
kernel_node
->
cast
<
CNodePtr
>
();
auto
cnode
=
kernel_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
AicpuMetadataInfo
(
cnode
,
&
kernel_info_list
);
AICPUQuery
(
cnode
,
&
kernel_info_list
);
FilterInvalidKernelInfo
(
cnode
,
&
kernel_info_list
);
return
std
::
any_of
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
return
std
::
any_of
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
select_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
[
&
select_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
MS_EXCEPTION_IF_NULL
(
item
);
...
@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
...
@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
});
});
}
}
bool
IsSupportedByA
i
Core
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
bool
IsSupportedByA
I
Core
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
MS_EXCEPTION_IF_NULL
(
select_kernel_build_info
);
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
...
...
mindspore/ccsrc/kernel/kernel_query.h
浏览文件 @
52e97dbb
...
@@ -26,9 +26,9 @@
...
@@ -26,9 +26,9 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
void
KernelQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
void
KernelQuery
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
void
AIC
pu
Query
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
void
AIC
PU
Query
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
*
kernel_info_list
);
bool
IsSupportedByA
iCpu
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
bool
IsSupportedByA
ICPU
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
bool
IsSupportedByA
i
Core
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
bool
IsSupportedByA
I
Core
(
const
AnfNodePtr
&
kernel_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
浏览文件 @
52e97dbb
...
@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
...
@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
if
(
format
==
kOpFormat_DEFAULT
)
{
if
(
format
==
kOpFormat_DEFAULT
)
{
return
true
;
return
true
;
}
}
if
(
format
==
kOpFormat_NDHWC
&&
shape
.
size
()
!=
kShape5dDims
)
{
return
false
;
}
// if shape size is 0, the shape will be a scalar
// if shape size is 0, the shape will be a scalar
if
(
shape
.
empty
())
{
if
(
shape
.
empty
())
{
return
true
;
return
true
;
...
@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
...
@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
bool
IsValidKernelInfo
(
const
std
::
shared_ptr
<
CNode
>
&
kernel_node
,
const
kernel
::
KernelBuildInfo
&
kernel_build_info
)
{
bool
IsValidKernelInfo
(
const
std
::
shared_ptr
<
CNode
>
&
kernel_node
,
const
kernel
::
KernelBuildInfo
&
kernel_build_info
)
{
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
auto
check_function
=
[](
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
string
&
format
)
->
bool
{
const
size_t
kCAxis
=
1
;
if
(
!
IsShapeMatchFormat
(
shape
,
format
))
{
return
false
;
}
return
true
;
};
for
(
size_t
index
=
0
;
index
<
kernel_build_info
.
GetOutputNum
();
++
index
)
{
for
(
size_t
index
=
0
;
index
<
kernel_build_info
.
GetOutputNum
();
++
index
)
{
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
index
);
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
index
);
if
(
!
check_function
(
output_shape
,
kernel_build_info
.
GetOutputFormat
(
index
)))
{
if
(
kernel_build_info
.
GetOutputFormat
(
index
)
==
kOpFormat_FRACTAL_Z_C04
)
{
if
(
output_shape
.
size
()
!=
kShape4dDims
||
output_shape
[
kCAxis
]
>
4
)
{
return
false
;
}
return
false
;
}
if
(
!
IsShapeMatchFormat
(
output_shape
,
kernel_build_info
.
GetOutputFormat
(
index
)))
{
return
false
;
return
false
;
}
}
}
}
for
(
size_t
index
=
0
;
index
<
kernel_build_info
.
GetInputNum
();
++
index
)
{
for
(
size_t
index
=
0
;
index
<
kernel_build_info
.
GetInputNum
();
++
index
)
{
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
index
);
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
index
);
if
(
!
check_function
(
input_shape
,
kernel_build_info
.
GetInputFormat
(
index
)))
{
if
(
!
IsShapeMatchFormat
(
input_shape
,
kernel_build_info
.
GetInputFormat
(
index
)))
{
return
false
;
}
if
(
kernel_build_info
.
GetInputFormat
(
index
)
==
kOpFormat_FRACTAL_Z_C04
)
{
if
(
input_shape
.
size
()
!=
kShape4dDims
||
input_shape
[
kCAxis
]
>
4
)
{
return
false
;
}
return
false
;
return
false
;
}
}
}
}
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
浏览文件 @
52e97dbb
...
@@ -20,12 +20,12 @@
...
@@ -20,12 +20,12 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <memory>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_build_info.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
void
TbeMetadataInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
KernelBuildInfo
>>
*
kernel_info_list
);
void
TbeMetadataInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
KernelBuildInfo
>>
*
kernel_info_list
);
bool
CheckSupported
(
const
AnfNodePtr
&
anf_node
,
const
KernelBuildInfoPtr
&
select_kernel_build_info
);
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
52e97dbb
...
@@ -32,13 +32,13 @@ namespace opt {
...
@@ -32,13 +32,13 @@ namespace opt {
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
namespace
{
namespace
{
kernel
::
KernelBuildInfoPtr
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
kernel
::
KernelBuildInfoPtr
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
AnfNodePtr
&
node
,
const
AnfNodePtr
&
node
,
const
TypeId
device_type
,
const
kernel
::
KernelBuildInfo
ori_build_info
)
{
const
kernel
::
KernelBuildInfo
&
ori_build_info
)
{
KernelBuildInfoBuilder
builder
;
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
input_format
});
builder
.
SetInputsFormat
({
input_format
});
builder
.
SetOutputsFormat
({
output_format
});
builder
.
SetOutputsFormat
({
output_format
});
builder
.
SetInputsDeviceType
({
ori_build_info
.
GetInputDeviceType
(
0
)
});
builder
.
SetInputsDeviceType
({
device_type
});
builder
.
SetOutputsDeviceType
({
ori_build_info
.
GetOutputDeviceType
(
0
)
});
builder
.
SetOutputsDeviceType
({
device_type
});
builder
.
SetKernelType
(
ori_build_info
.
kernel_type
());
builder
.
SetKernelType
(
ori_build_info
.
kernel_type
());
builder
.
SetFusionType
(
ori_build_info
.
fusion_type
());
builder
.
SetFusionType
(
ori_build_info
.
fusion_type
());
builder
.
SetProcessor
(
ori_build_info
.
processor
());
builder
.
SetProcessor
(
ori_build_info
.
processor
());
...
@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
...
@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr
trans_node
=
func_graph
->
NewCNode
(
trans_inputs
);
CNodePtr
trans_node
=
func_graph
->
NewCNode
(
trans_inputs
);
MS_EXCEPTION_IF_NULL
(
trans_node
);
MS_EXCEPTION_IF_NULL
(
trans_node
);
std
::
vector
<
kernel
::
Axis
>
padding_axis
;
std
::
vector
<
kernel
::
Axis
>
padding_axis
;
if
(
AnfAlgo
::
IsRealKernel
(
input
))
{
padding_axis
=
AnfAlgo
::
GetOutputReshapeType
(
input
,
0
);
padding_axis
=
AnfAlgo
::
GetOutputReshapeType
(
input
,
0
);
}
else
{
padding_axis
=
AnfAlgo
::
GetPrevNodeOutputReshapeType
(
input
,
0
);
}
if
(
need_padding
)
{
if
(
need_padding
)
{
// if need padding we should set the transdata node's shape to the padding shape
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
...
@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
...
@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
AnfNodePtr
InsertTransOpForSingleOutput
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
AnfNodePtr
InsertTransOpForSingleOutput
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
KernelSelectPtr
&
kernel_select
)
{
const
KernelSelectPtr
&
kernel_select
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
std
::
string
output_format
;
std
::
string
output_format
=
AnfAlgo
::
GetOutputFormat
(
node
,
0
);
std
::
vector
<
size_t
>
origin_shape
;
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
0
);
if
(
!
AnfAlgo
::
IsRealKernel
(
node
))
{
output_format
=
AnfAlgo
::
GetPrevNodeOutputFormat
(
node
,
0
);
origin_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
node
,
0
);
}
else
{
output_format
=
AnfAlgo
::
GetOutputFormat
(
node
,
0
);
origin_shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
0
);
}
if
(
output_format
==
kOpFormat_NC1KHKWHWC0
)
{
if
(
output_format
==
kOpFormat_NC1KHKWHWC0
)
{
MS_LOG
(
EXCEPTION
)
<<
"got the hw format "
<<
output_format
<<
"when insert the transdata node "
MS_LOG
(
EXCEPTION
)
<<
"got the hw format "
<<
output_format
<<
"when insert the transdata node "
<<
node
->
DebugString
();
<<
node
->
DebugString
();
...
@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
...
@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr
trans_node
=
nullptr
;
AnfNodePtr
trans_node
=
nullptr
;
AnfNodePtr
input_node
=
node
;
AnfNodePtr
input_node
=
node
;
AnfNodePtr
trans_data
=
nullptr
;
AnfNodePtr
trans_data
=
nullptr
;
TypeId
dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
node
,
0
);
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
origin_format
.
empty
()
||
dest_format
.
empty
())
{
if
(
origin_format
.
empty
()
||
dest_format
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"trans op format is error, origin = "
<<
origin_format
<<
", dest "
<<
origin_format
;
MS_LOG
(
EXCEPTION
)
<<
"trans op format is error, origin = "
<<
origin_format
<<
", dest "
<<
origin_format
;
...
@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
...
@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_LOG
(
EXCEPTION
)
<<
"cannot insert a transdata node to a node's input which the node is not a cnode"
;
MS_LOG
(
EXCEPTION
)
<<
"cannot insert a transdata node to a node's input which the node is not a cnode"
;
}
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
dtype
=
AnfAlgo
::
GetInputDeviceDataType
(
cnode
,
insert_index
);
MS_EXCEPTION_IF_NULL
(
cnode
);
MS_EXCEPTION_IF_NULL
(
cnode
);
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
}
}
...
@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
...
@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL
(
trans_data
);
MS_EXCEPTION_IF_NULL
(
trans_data
);
MS_EXCEPTION_IF_NULL
(
trans_data
->
kernel_info
());
MS_EXCEPTION_IF_NULL
(
trans_data
->
kernel_info
());
auto
trans_ori_build_info
=
trans_data
->
kernel_info
()
->
select_kernel_build_info
();
auto
trans_ori_build_info
=
trans_data
->
kernel_info
()
->
select_kernel_build_info
();
auto
kernel_build_info
=
RefreshKernelBuildInfo
(
origin_format
,
dest_format
,
input_node
,
*
trans_ori_build_info
);
auto
kernel_build_info
=
RefreshKernelBuildInfo
(
origin_format
,
dest_format
,
input_node
,
dtype
,
*
trans_ori_build_info
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info
,
trans_data
.
get
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info
,
trans_data
.
get
());
return
trans_node
;
return
trans_node
;
}
}
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
浏览文件 @
52e97dbb
...
@@ -39,11 +39,11 @@ class SupportedChecker {
...
@@ -39,11 +39,11 @@ class SupportedChecker {
virtual
~
SupportedChecker
()
=
default
;
virtual
~
SupportedChecker
()
=
default
;
virtual
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
virtual
bool
CheckAiCoreSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
return
kernel
::
IsSupportedByA
i
Core
(
anf_node
,
select_kernel_build_info
);
return
kernel
::
IsSupportedByA
I
Core
(
anf_node
,
select_kernel_build_info
);
}
}
virtual
bool
CheckAiCpuSupported
(
const
AnfNodePtr
&
anf_node
,
virtual
bool
CheckAiCpuSupported
(
const
AnfNodePtr
&
anf_node
,
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
return
kernel
::
IsSupportedByA
iCpu
(
anf_node
,
select_kernel_build_info
);
return
kernel
::
IsSupportedByA
ICPU
(
anf_node
,
select_kernel_build_info
);
}
}
};
};
using
SupportedCheckerPtr
=
std
::
shared_ptr
<
SupportedChecker
>
;
using
SupportedCheckerPtr
=
std
::
shared_ptr
<
SupportedChecker
>
;
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
浏览文件 @
52e97dbb
...
@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
...
@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
auto
param_dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
final_node
,
0
);
auto
param_dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
final_node
,
0
);
auto
cast
=
trans_road
[
1
];
auto
cast
=
trans_road
[
1
];
AnfAlgo
::
SetSelectKernelBuildInfo
(
GetKernelBuildInfo
(
cast
,
format
,
param_dtype
,
dtype
),
cast
.
get
());
if
(
param_format
==
format
&&
param_dtype
!=
dtype
)
{
if
(
param_format
==
format
&&
param_dtype
!=
dtype
)
{
AnfAlgo
::
SetSelectKernelBuildInfo
(
GetKernelBuildInfo
(
cast
,
format
,
param_dtype
,
dtype
),
cast
.
get
());
manager
->
Replace
(
trans_road
[
2
],
final_node
);
manager
->
Replace
(
trans_road
[
2
],
final_node
);
manager
->
Replace
(
cur_transop
,
cast
);
manager
->
Replace
(
cur_transop
,
cast
);
}
}
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
52e97dbb
...
@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
...
@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
<<
" is out of the node output range :"
<<
GetOutputTensorNum
(
node
)
<<
" #node ["
<<
" is out of the node output range :"
<<
GetOutputTensorNum
(
node
)
<<
" #node ["
<<
node
->
DebugString
()
<<
"]"
;
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
AnfAlgo
::
IsRealKernel
(
node
))
{
return
AnfAlgo
::
GetPrevNodeOutputFormat
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
...
@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
<<
" is out of the number node Input range :"
<<
GetInputTensorNum
(
node
)
<<
"#node ["
<<
" is out of the number node Input range :"
<<
GetInputTensorNum
(
node
)
<<
"#node ["
<<
node
->
DebugString
()
<<
"]"
;
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
IsRealKernel
(
node
))
{
GetPrevNodeOutputFormat
(
node
,
input_idx
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
...
@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
}
else
if
(
b_shp
->
isa
<
abstract
::
NoShape
>
())
{
}
else
if
(
b_shp
->
isa
<
abstract
::
NoShape
>
())
{
return
std
::
vector
<
size_t
>
();
return
std
::
vector
<
size_t
>
();
}
else
{
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"The output type of ApplyKernel
should be a NoShape , ArrayShape or a TupleShape, but it is "
MS_LOG
(
EXCEPTION
)
<<
"The output type of ApplyKernel
index:"
<<
output_idx
<<
base_shape
->
ToString
();
<<
" should be a NoShape , ArrayShape or a TupleShape, but it is "
<<
base_shape
->
ToString
();
}
}
}
else
if
(
base_shape
->
isa
<
abstract
::
NoShape
>
())
{
}
else
if
(
base_shape
->
isa
<
abstract
::
NoShape
>
())
{
return
std
::
vector
<
size_t
>
();
return
std
::
vector
<
size_t
>
();
...
@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
...
@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
<<
" is out of range of the node's input size : "
<<
GetInputTensorNum
(
node
)
<<
"#node["
<<
" is out of range of the node's input size : "
<<
GetInputTensorNum
(
node
)
<<
"#node["
<<
node
->
DebugString
()
<<
"]"
;
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputReshapeType
(
node
,
input_idx
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
...
@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
<<
GetOutputTensorNum
(
node
)
<<
"#node[ "
<<
node
->
DebugString
()
<<
"]"
;
<<
GetOutputTensorNum
(
node
)
<<
"#node[ "
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputReshapeType
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
...
@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
<<
GetOutputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
<<
GetOutputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputDeviceDataType
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
...
@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
input_idx
<<
"] is out of range of the node's input size [ "
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
input_idx
<<
"] is out of range of the node's input size [ "
<<
GetInputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
<<
GetInputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
}
}
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputDeviceDataType
(
node
,
0
);
}
auto
kernel_info
=
node
->
kernel_info
();
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录