Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
338d7c1a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
338d7c1a
编写于
5月 21, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
decoupled of insert transdata and deal ref and split transdata
上级
d402b944
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
144 addition
and
132 deletion
+144
-132
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+11
-1
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+2
-2
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+5
-0
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+100
-111
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
+4
-4
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
...re_activate/ascend/format_type/deal_ref_trans_and_cast.cc
+2
-2
mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc
...e/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc
+16
-8
mindspore/ccsrc/session/kernel_graph.cc
mindspore/ccsrc/session/kernel_graph.cc
+3
-3
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-1
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
338d7c1a
...
@@ -503,6 +503,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
...
@@ -503,6 +503,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
KernelSelectStatus
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
KernelSelectStatus
SelectKernelInfo
(
const
CNodePtr
&
kernel_node
)
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
kernel_info_list
;
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
aicpu_kernel_info_list
;
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_node
);
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
kernel
::
KernelQuery
(
kernel_node
,
&
kernel_info_list
);
auto
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
auto
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
...
@@ -510,7 +511,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
...
@@ -510,7 +511,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
if
(
select_status
==
kNoMatched
)
{
if
(
select_status
==
kNoMatched
)
{
MS_LOG
(
WARNING
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
MS_LOG
(
WARNING
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
kernel
::
AICPUQuery
(
kernel_node
,
&
kernel_info_list
);
kernel
::
AICPUQuery
(
kernel_node
,
&
aicpu_
kernel_info_list
);
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
}
}
...
@@ -518,6 +519,15 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
...
@@ -518,6 +519,15 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
if
(
select_status
==
kNoMatched
)
{
if
(
select_status
==
kNoMatched
)
{
std
::
ostringstream
buffer
;
std
::
ostringstream
buffer
;
PrintInputAndOutputInferType
(
buffer
,
kernel_node
);
PrintInputAndOutputInferType
(
buffer
,
kernel_node
);
MS_LOG
(
WARNING
)
<<
"=========================kernel info list====================================="
;
for
(
size_t
index
=
0
;
index
<
kernel_info_list
.
size
();
++
index
)
{
MS_LOG
(
WARNING
)
<<
"kernel ["
<<
index
<<
"] :"
<<
kernel_info_list
[
index
]
->
ToString
();
}
for
(
size_t
index
=
0
;
index
<
aicpu_kernel_info_list
.
size
();
++
index
)
{
MS_LOG
(
WARNING
)
<<
"kernel ["
<<
(
kernel_info_list
.
size
()
+
index
)
<<
"] :"
<<
aicpu_kernel_info_list
[
index
]
->
ToString
();
}
MS_LOG
(
WARNING
)
<<
"========================= end ===================================="
;
MS_EXCEPTION
(
TypeError
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
MS_EXCEPTION
(
TypeError
)
<<
"The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid kernel info, not supported the type "
<<
buffer
.
str
();
<<
"] cannot find valid kernel info, not supported the type "
<<
buffer
.
str
();
}
}
...
...
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
338d7c1a
...
@@ -110,9 +110,9 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
...
@@ -110,9 +110,9 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
return
!
(
inputs_device_type_
!=
other
.
inputs_device_type_
||
outputs_device_type_
!=
other
.
outputs_device_type_
);
return
!
(
inputs_device_type_
!=
other
.
inputs_device_type_
||
outputs_device_type_
!=
other
.
outputs_device_type_
);
}
}
bool
KernelBuildInfo
::
IsInputDefaultPadding
()
const
{
return
out
put_reshape_type_
.
empty
();
}
bool
KernelBuildInfo
::
IsInputDefaultPadding
()
const
{
return
in
put_reshape_type_
.
empty
();
}
bool
KernelBuildInfo
::
IsOutputDefaultPadding
()
const
{
return
in
put_reshape_type_
.
empty
();
}
bool
KernelBuildInfo
::
IsOutputDefaultPadding
()
const
{
return
out
put_reshape_type_
.
empty
();
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetKernelType
(
const
KernelType
&
kernel_type
)
{
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetKernelType
(
const
KernelType
&
kernel_type
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
338d7c1a
...
@@ -56,6 +56,11 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
...
@@ -56,6 +56,11 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
TbeMetadataInfo
(
kernel_node
,
kernel_info_list
);
TbeMetadataInfo
(
kernel_node
,
kernel_info_list
);
if
(
kernel_info_list
->
empty
())
{
if
(
kernel_info_list
->
empty
())
{
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
if
(
!
kernel_info_list
->
empty
())
{
MS_LOG
(
INFO
)
<<
"Warning The node ["
<<
kernel_node
->
DebugString
()
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
}
}
}
if
(
kernel_info_list
->
empty
())
{
if
(
kernel_info_list
->
empty
())
{
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
338d7c1a
...
@@ -31,54 +31,6 @@ namespace mindspore {
...
@@ -31,54 +31,6 @@ namespace mindspore {
namespace
opt
{
namespace
opt
{
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
namespace
{
namespace
{
kernel
::
KernelBuildInfoPtr
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
AnfNodePtr
&
node
,
const
TypeId
device_type
,
const
kernel
::
KernelBuildInfo
&
ori_build_info
)
{
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
input_format
});
builder
.
SetOutputsFormat
({
output_format
});
builder
.
SetInputsDeviceType
({
device_type
});
builder
.
SetOutputsDeviceType
({
device_type
});
builder
.
SetKernelType
(
ori_build_info
.
kernel_type
());
builder
.
SetFusionType
(
ori_build_info
.
fusion_type
());
builder
.
SetProcessor
(
ori_build_info
.
processor
());
return
builder
.
Build
();
}
CNodePtr
NewTransOpNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input
,
const
KernelSelectPtr
&
kernel_select
,
const
bool
need_padding
,
const
std
::
string
&
op_name
)
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
input
);
std
::
vector
<
AnfNodePtr
>
trans_inputs
;
auto
prim
=
std
::
make_shared
<
Primitive
>
(
op_name
);
trans_inputs
.
push_back
(
NewValueNode
(
prim
));
trans_inputs
.
push_back
(
input
);
CNodePtr
trans_node
=
func_graph
->
NewCNode
(
trans_inputs
);
MS_EXCEPTION_IF_NULL
(
trans_node
);
std
::
vector
<
kernel
::
Axis
>
padding_axis
;
padding_axis
=
AnfAlgo
::
GetOutputReshapeType
(
input
,
0
);
if
(
need_padding
)
{
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
{
trans
::
PaddingShapeTo4d
(
AnfAlgo
::
GetOutputInferShape
(
input
,
0
),
padding_axis
)},
trans_node
.
get
());
}
else
{
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
{
AnfAlgo
::
GetOutputInferShape
(
input
,
0
)},
trans_node
.
get
());
}
// special handle for ut
if
(
trans_node
->
kernel_info
()
==
nullptr
)
{
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
trans_node
->
set_kernel_info
(
kernel_info
);
}
MS_EXCEPTION_IF_NULL
(
kernel_select
);
kernel_select
->
SelectKernel
(
trans_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrVisited
,
MakeValue
(
true
),
trans_node
);
MS_EXCEPTION_IF_NULL
(
trans_node
);
trans_node
->
set_scope
(
input
->
scope
());
return
trans_node
;
}
AnfNodePtr
CreateReshapeNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input_node
,
AnfNodePtr
CreateReshapeNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input_node
,
const
KernelSelectPtr
&
kernel_select
,
const
std
::
vector
<
size_t
>
&
dst_shape
)
{
const
KernelSelectPtr
&
kernel_select
,
const
std
::
vector
<
size_t
>
&
dst_shape
)
{
std
::
vector
<
AnfNodePtr
>
trans_inputs
;
std
::
vector
<
AnfNodePtr
>
trans_inputs
;
...
@@ -94,6 +46,58 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
...
@@ -94,6 +46,58 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
return
reshape
;
return
reshape
;
}
}
AnfNodePtr
AddTransOpNodeToGraph
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
KernelSelectPtr
&
kernel_select
,
size_t
insert_index
,
bool
is_insert_input
)
{
AnfNodePtr
trans_node
=
nullptr
;
AnfNodePtr
input_node
=
node
;
CNodePtr
trans_data
=
nullptr
;
std
::
string
input_format
=
is_insert_input
?
kOpFormat_DEFAULT
:
AnfAlgo
::
GetOutputFormat
(
node
,
0
);
std
::
string
dst_format
=
is_insert_input
?
AnfAlgo
::
GetInputFormat
(
node
,
0
)
:
kOpFormat_DEFAULT
;
TypeId
dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
node
,
0
);
std
::
vector
<
kernel
::
Axis
>
padding_axis
=
AnfAlgo
::
GetOutputReshapeType
(
node
,
0
);
MS_EXCEPTION_IF_NULL
(
node
);
// if insert transdata for input we need to change the input
if
(
is_insert_input
)
{
if
(
!
node
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"cannot insert a transdata node to a node's input which the node is not a cnode"
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
dtype
=
AnfAlgo
::
GetInputDeviceDataType
(
cnode
,
insert_index
);
dst_format
=
AnfAlgo
::
GetInputFormat
(
cnode
,
insert_index
);
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
padding_axis
=
AnfAlgo
::
GetInputReshapeType
(
node
,
0
);
}
bool
need_padding
=
false
;
if
(
is_insert_input
)
{
need_padding
=
(
trans
::
IsNeedPadding
(
dst_format
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
).
size
()));
}
else
{
need_padding
=
(
trans
::
IsNeedPadding
(
input_format
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
).
size
()));
}
if
(
!
need_padding
)
{
// don't need padding insert transdata only
trans_data
=
NewTransOpNode
(
func_graph
,
input_node
,
kernel_select
,
need_padding
,
prim
::
KPrimTransData
->
name
());
trans_node
=
trans_data
;
}
else
if
(
is_insert_input
)
{
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
auto
padding_shape
=
trans
::
PaddingShapeTo4d
(
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
),
AnfAlgo
::
GetInputReshapeType
(
node
,
0
));
auto
reshape_node
=
CreateReshapeNode
(
func_graph
,
input_node
,
kernel_select
,
padding_shape
);
trans_data
=
NewTransOpNode
(
func_graph
,
reshape_node
,
kernel_select
,
need_padding
,
prim
::
KPrimTransData
->
name
());
trans_node
=
trans_data
;
}
else
{
// if need padding & is output need insert a transdata
// node -> transdata[padding shape] -> reshape[ori_shape]
trans_data
=
NewTransOpNode
(
func_graph
,
input_node
,
kernel_select
,
need_padding
,
prim
::
KPrimTransData
->
name
());
auto
reshape_node
=
CreateReshapeNode
(
func_graph
,
trans_data
,
kernel_select
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
));
trans_node
=
reshape_node
;
}
// refresh the transdata's format to ori format & dst format
RefreshKernelBuildInfo
(
input_format
,
dst_format
,
dtype
,
trans_data
,
padding_axis
);
return
trans_node
;
}
AnfNodePtr
GetTransInputNodePtr
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
size_t
index
,
AnfNodePtr
GetTransInputNodePtr
(
const
FuncGraphPtr
&
func_graph
,
const
CNodePtr
&
node
,
size_t
index
,
const
KernelSelectPtr
&
kernel_select
)
{
const
KernelSelectPtr
&
kernel_select
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
);
...
@@ -111,13 +115,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
...
@@ -111,13 +115,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
<<
"when inserting the transdata node "
<<
node
->
DebugString
();
<<
"when inserting the transdata node "
<<
node
->
DebugString
();
}
}
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
node
,
index
);
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
node
,
index
);
std
::
string
origin_format
=
kOpFormat_DEFAULT
;
std
::
string
dest_format
=
AnfAlgo
::
GetInputFormat
(
node
,
index
);
std
::
string
dest_format
=
AnfAlgo
::
GetInputFormat
(
node
,
index
);
if
(
kNeedTransFormatSet
.
find
(
dest_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
kNeedTransFormatSet
.
find
(
dest_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
MS_LOG
(
DEBUG
)
<<
node
->
DebugString
()
<<
"Insert transdata "
<<
AnfAlgo
::
GetInputFormat
(
node
,
index
)
MS_LOG
(
DEBUG
)
<<
node
->
DebugString
()
<<
"Insert transdata "
<<
AnfAlgo
::
GetInputFormat
(
node
,
index
)
<<
" To DefaultFormat , index: "
<<
index
;
<<
" To DefaultFormat , index: "
<<
index
;
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
index
,
origin_format
,
dest_format
,
kTransDataOpName
,
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
index
,
true
);
true
);
}
}
return
input_node
;
return
input_node
;
}
}
...
@@ -131,12 +133,9 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
...
@@ -131,12 +133,9 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
MS_LOG
(
EXCEPTION
)
<<
"got the hw format "
<<
output_format
<<
"when insert the transdata node "
MS_LOG
(
EXCEPTION
)
<<
"got the hw format "
<<
output_format
<<
"when insert the transdata node "
<<
node
->
DebugString
();
<<
node
->
DebugString
();
}
}
std
::
string
origin_format
=
output_format
;
std
::
string
dest_format
=
kOpFormat_DEFAULT
;
if
(
kNeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
kNeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
MS_LOG
(
DEBUG
)
<<
"Inserted Transdata "
<<
output_format
<<
" To default , index :0"
;
MS_LOG
(
DEBUG
)
<<
"Inserted Transdata "
<<
output_format
<<
" To default , index :0"
;
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
0
,
origin_format
,
dest_format
,
kTransDataOpName
,
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
0
,
false
);
false
);
}
}
return
node
;
return
node
;
}
}
...
@@ -155,10 +154,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
...
@@ -155,10 +154,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
}
auto
tuple_getitem
=
CreatTupleGetItemNode
(
func_graph
,
node
,
output_idx
);
auto
tuple_getitem
=
CreatTupleGetItemNode
(
func_graph
,
node
,
output_idx
);
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
output_idx
);
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
output_idx
);
std
::
string
dest_format
=
kOpFormat_DEFAULT
;
if
(
kNeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
kNeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTransFormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
make_tuple_inputs
.
emplace_back
(
AddTransOpNodeToGraph
(
func_graph
,
tuple_getitem
,
kernel_select
,
0
,
output_format
,
make_tuple_inputs
.
emplace_back
(
AddTransOpNodeToGraph
(
func_graph
,
tuple_getitem
,
kernel_select
,
0
,
false
));
dest_format
,
kTransDataOpName
,
false
));
}
else
{
}
else
{
// No need insert trans op.
// No need insert trans op.
make_tuple_inputs
.
push_back
(
tuple_getitem
);
make_tuple_inputs
.
push_back
(
tuple_getitem
);
...
@@ -168,62 +165,54 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
...
@@ -168,62 +165,54 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
return
make_tuple
;
return
make_tuple
;
}
}
}
// namespace
}
// namespace
AnfNodePtr
AddTransOpNodeToGraph
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
void
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
TypeId
device_type
,
const
KernelSelectPtr
&
kernel_select
,
size_t
insert_index
,
const
AnfNodePtr
&
trans_data
,
const
std
::
vector
<
kernel
::
Axis
>
&
reshape_type
)
{
const
std
::
string
&
origin_format
,
const
std
::
string
&
dest_format
,
MS_EXCEPTION_IF_NULL
(
trans_data
);
const
std
::
string
&
op_name
,
bool
is_insert_input
)
{
MS_EXCEPTION_IF_NULL
(
trans_data
->
kernel_info
());
AnfNodePtr
trans_node
=
nullptr
;
auto
ori_build_info
=
trans_data
->
kernel_info
()
->
select_kernel_build_info
();
AnfNodePtr
input_node
=
node
;
KernelBuildInfoBuilder
builder
;
AnfNodePtr
trans_data
=
nullptr
;
builder
.
SetInputsFormat
({
input_format
});
TypeId
dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
node
,
0
);
builder
.
SetInputReshapeType
({
reshape_type
});
MS_EXCEPTION_IF_NULL
(
node
);
builder
.
SetInputReshapeType
({
reshape_type
});
if
(
origin_format
.
empty
()
||
dest_format
.
empty
())
{
builder
.
SetOutputsFormat
({
output_format
});
MS_LOG
(
EXCEPTION
)
<<
"trans op format is error, origin = "
<<
origin_format
<<
", dest "
<<
origin_format
;
builder
.
SetInputsDeviceType
({
device_type
});
}
builder
.
SetOutputsDeviceType
({
device_type
});
// if insert transdata for input we need to change the input
builder
.
SetKernelType
(
ori_build_info
->
kernel_type
());
if
(
is_insert_input
)
{
builder
.
SetFusionType
(
ori_build_info
->
fusion_type
());
if
(
!
node
->
isa
<
CNode
>
())
{
builder
.
SetProcessor
(
ori_build_info
->
processor
());
MS_LOG
(
EXCEPTION
)
<<
"cannot insert a transdata node to a node's input which the node is not a cnode"
;
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
.
Build
(),
trans_data
.
get
());
}
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
dtype
=
AnfAlgo
::
GetInputDeviceDataType
(
cnode
,
insert_index
);
CNodePtr
NewTransOpNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input
,
const
KernelSelectPtr
&
kernel_select
,
MS_EXCEPTION_IF_NULL
(
cnode
);
const
bool
need_padding
,
const
std
::
string
&
op_name
)
{
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
MS_EXCEPTION_IF_NULL
(
func_graph
);
}
MS_EXCEPTION_IF_NULL
(
input
);
bool
need_padding
=
false
;
std
::
vector
<
AnfNodePtr
>
trans_inputs
;
if
(
is_insert_input
)
{
auto
prim
=
std
::
make_shared
<
Primitive
>
(
op_name
);
need_padding
=
(
trans
::
IsNeedPadding
(
dest_format
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
).
size
())
&&
trans_inputs
.
push_back
(
NewValueNode
(
prim
));
op_name
==
kTransDataOpName
);
trans_inputs
.
push_back
(
input
);
CNodePtr
trans_node
=
func_graph
->
NewCNode
(
trans_inputs
);
MS_EXCEPTION_IF_NULL
(
trans_node
);
auto
padding_axis
=
AnfAlgo
::
GetOutputReshapeType
(
input
,
0
);
if
(
need_padding
)
{
// if need padding we should set the transdata node's shape to the padding shape
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
{
trans
::
PaddingShapeTo4d
(
AnfAlgo
::
GetOutputInferShape
(
input
,
0
),
padding_axis
)},
trans_node
.
get
());
}
else
{
}
else
{
need_padding
=
(
trans
::
IsNeedPadding
(
origin_format
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
).
size
())
&&
AnfAlgo
::
SetOutputInferTypeAndShape
({
AnfAlgo
::
GetOutputInferDataType
(
input
,
0
)},
op_name
==
kTransDataOpName
);
{
AnfAlgo
::
GetOutputInferShape
(
input
,
0
)},
trans_node
.
get
()
);
}
}
if
(
!
need_padding
)
{
// special handle for ut
// don't need padding insert transdata only
if
(
trans_node
->
kernel_info
()
==
nullptr
)
{
trans_data
=
NewTransOpNode
(
func_graph
,
input_node
,
kernel_select
,
need_padding
,
op_name
);
auto
kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
trans_node
=
trans_data
;
trans_node
->
set_kernel_info
(
kernel_info
);
}
else
if
(
is_insert_input
)
{
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
auto
padding_shape
=
trans
::
PaddingShapeTo4d
(
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
),
AnfAlgo
::
GetInputReshapeType
(
node
,
0
));
auto
reshape_node
=
CreateReshapeNode
(
func_graph
,
input_node
,
kernel_select
,
padding_shape
);
trans_data
=
NewTransOpNode
(
func_graph
,
reshape_node
,
kernel_select
,
need_padding
,
op_name
);
trans_node
=
trans_data
;
}
else
{
// if need padding & is output need insert a transdata
// node -> transdata[padding shape] -> reshape[ori_shape]
trans_data
=
NewTransOpNode
(
func_graph
,
input_node
,
kernel_select
,
need_padding
,
op_name
);
auto
reshape_node
=
CreateReshapeNode
(
func_graph
,
trans_data
,
kernel_select
,
AnfAlgo
::
GetOutputInferShape
(
input_node
,
0
));
trans_node
=
reshape_node
;
}
}
// refresh the transdata's format to ori format & dst format
MS_EXCEPTION_IF_NULL
(
kernel_select
);
MS_EXCEPTION_IF_NULL
(
trans_data
);
kernel_select
->
SelectKernel
(
trans_node
);
MS_EXCEPTION_IF_NULL
(
trans_data
->
kernel_info
());
AnfAlgo
::
SetNodeAttr
(
kAttrVisited
,
MakeValue
(
true
),
trans_node
);
auto
trans_ori_build_info
=
trans_data
->
kernel_info
()
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
trans_node
);
auto
kernel_build_info
=
RefreshKernelBuildInfo
(
origin_format
,
dest_format
,
input_node
,
dtype
,
*
trans_ori_build_info
);
trans_node
->
set_scope
(
input
->
scope
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info
,
trans_data
.
get
());
return
trans_node
;
return
trans_node
;
}
}
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
浏览文件 @
338d7c1a
...
@@ -58,11 +58,11 @@ class KernelQuery {
...
@@ -58,11 +58,11 @@ class KernelQuery {
}
}
};
};
using
KernelQueryPtr
=
std
::
shared_ptr
<
KernelQuery
>
;
using
KernelQueryPtr
=
std
::
shared_ptr
<
KernelQuery
>
;
void
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
TypeId
device_type
,
const
AnfNodePtr
&
trans_data
,
const
std
::
vector
<
kernel
::
Axis
>
&
reshape_type
=
{});
AnfNodePtr
AddTransOpNodeToGraph
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
CNodePtr
NewTransOpNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input
,
const
KernelSelectPtr
&
kernel_select
,
const
KernelSelectPtr
&
kernel_select
,
size_t
insert_index
,
const
bool
need_padding
,
const
std
::
string
&
op_name
);
const
std
::
string
&
origin_format
,
const
std
::
string
&
dest_format
,
const
std
::
string
&
op_name
,
bool
is_insert_input
);
AnfNodePtr
AddCastOpNodeToGraph
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input
,
const
std
::
string
&
format
,
AnfNodePtr
AddCastOpNodeToGraph
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input
,
const
std
::
string
&
format
,
const
TypeId
&
input_type
,
const
TypeId
&
output_type
,
const
TypeId
&
input_type
,
const
TypeId
&
output_type
,
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
浏览文件 @
338d7c1a
...
@@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
...
@@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
// insert trans
// insert trans
if
(
origin_format
!=
cur_format
&&
cur_shape
.
size
()
>
1
)
{
if
(
origin_format
!=
cur_format
&&
cur_shape
.
size
()
>
1
)
{
auto
kernel_select
=
std
::
make_shared
<
KernelSelect
>
();
auto
kernel_select
=
std
::
make_shared
<
KernelSelect
>
();
final_node
=
AddTransOpNodeToGraph
(
func_graph
,
final_node
,
kernel_select
,
0
,
cur_format
,
origin_format
,
final_node
=
NewTransOpNode
(
func_graph
,
final_node
,
kernel_select
,
false
,
prim
::
KPrimTransData
->
name
());
kTransDataOpName
,
fals
e
);
RefreshKernelBuildInfo
(
cur_format
,
origin_format
,
origin_type
,
final_nod
e
);
final_index
=
0
;
final_index
=
0
;
MS_EXCEPTION_IF_NULL
(
final_node
);
MS_EXCEPTION_IF_NULL
(
final_node
);
MS_LOG
(
INFO
)
<<
"DealRefTransAndCast add trans op, op debug info is "
<<
final_node
->
DebugString
();
MS_LOG
(
INFO
)
<<
"DealRefTransAndCast add trans op, op debug info is "
<<
final_node
->
DebugString
();
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc
浏览文件 @
338d7c1a
...
@@ -67,22 +67,30 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
...
@@ -67,22 +67,30 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
// if output_format=default transdata need split transdata->transpose else transpose->transdata
// if output_format=default transdata need split transdata->transpose else transpose->transdata
if
(
output_format
==
kOpFormat_DEFAULT
||
output_format
==
kOpFormat_NCHW
)
{
if
(
output_format
==
kOpFormat_DEFAULT
||
output_format
==
kOpFormat_NCHW
)
{
// trans input_format to hwcn
// trans input_format to hwcn
new_transdata_node
=
new_transdata_node
=
NewTransOpNode
(
func_graph
,
AnfAlgo
::
GetInputNode
(
node
->
cast
<
CNodePtr
>
(),
0
),
kernel_select_
,
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select_
,
0
,
input_format
,
kOpFormat_HWCN
,
kTransDataOpName
,
true
);
false
,
prim
::
KPrimTransData
->
name
());
RefreshKernelBuildInfo
(
input_format
,
kOpFormat_HWCN
,
AnfAlgo
::
GetOutputDeviceDataType
(
new_transdata_node
,
0
),
new_transdata_node
);
// trans hwcn to default_format
// trans hwcn to default_format
new_transpose_node
=
AddTransOpNodeToGraph
(
func_graph
,
new_transdata_node
,
kernel_select_
,
0
,
kOpFormat_HWCN
,
new_transpose_node
=
output_format
,
prim
::
kPrimTranspose
->
name
(),
false
);
NewTransOpNode
(
func_graph
,
new_transdata_node
,
kernel_select_
,
false
,
prim
::
kPrimTranspose
->
name
());
RefreshKernelBuildInfo
(
kOpFormat_HWCN
,
output_format
,
AnfAlgo
::
GetOutputDeviceDataType
(
new_transpose_node
,
0
),
new_transpose_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrPerm
,
MakeValue
(
std
::
vector
<
int
>
{
3
,
2
,
0
,
1
}),
new_transpose_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrPerm
,
MakeValue
(
std
::
vector
<
int
>
{
3
,
2
,
0
,
1
}),
new_transpose_node
);
new_replace_node
=
new_transpose_node
;
new_replace_node
=
new_transpose_node
;
}
else
{
}
else
{
// trans default to hwcn
// trans default to hwcn
new_transpose_node
=
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select_
,
0
,
input_format
,
kOpFormat_HWCN
,
new_transpose_node
=
NewTransOpNode
(
func_graph
,
AnfAlgo
::
GetInputNode
(
node
->
cast
<
CNodePtr
>
(),
0
),
kernel_select_
,
prim
::
kPrimTranspose
->
name
(),
true
);
false
,
prim
::
kPrimTranspose
->
name
()
);
AnfAlgo
::
SetNodeAttr
(
kAttrPerm
,
MakeValue
(
std
::
vector
<
int
>
{
2
,
3
,
1
,
0
}),
new_transpose_node
);
AnfAlgo
::
SetNodeAttr
(
kAttrPerm
,
MakeValue
(
std
::
vector
<
int
>
{
2
,
3
,
1
,
0
}),
new_transpose_node
);
RefreshKernelBuildInfo
(
input_format
,
kOpFormat_HWCN
,
AnfAlgo
::
GetOutputDeviceDataType
(
new_transpose_node
,
0
),
new_transpose_node
);
// trans hwcn to output_format
// trans hwcn to output_format
new_transdata_node
=
AddTransOpNodeToGraph
(
func_graph
,
new_transpose_node
,
kernel_select_
,
0
,
kOpFormat_HWCN
,
new_transdata_node
=
output_format
,
kTransDataOpName
,
false
);
NewTransOpNode
(
func_graph
,
new_transpose_node
,
kernel_select_
,
false
,
prim
::
KPrimTransData
->
name
());
RefreshKernelBuildInfo
(
kOpFormat_HWCN
,
output_format
,
AnfAlgo
::
GetOutputDeviceDataType
(
new_transdata_node
,
0
),
new_transpose_node
);
new_replace_node
=
new_transdata_node
;
new_replace_node
=
new_transdata_node
;
}
}
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
FuncGraphManagerPtr
manager
=
func_graph
->
manager
();
...
...
mindspore/ccsrc/session/kernel_graph.cc
浏览文件 @
338d7c1a
...
@@ -196,10 +196,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
...
@@ -196,10 +196,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
}
}
if
(
inputs
.
size
()
==
1
||
!
feature_map_input_indexs
.
empty
())
{
if
(
inputs
.
size
()
==
1
||
!
feature_map_input_indexs
.
empty
())
{
kernel_info
->
SetFeatureMapFlag
(
true
);
kernel_info
->
SetFeatureMapFlag
(
true
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
true
),
cnode
);
}
if
(
AnfAlgo
::
IsRealCNodeKernel
(
cnode
))
{
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
kernel_info
->
is_feature_map
()),
cnode
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
cnode
);
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapInputList
,
MakeValue
(
feature_map_input_indexs
),
cnode
);
}
else
{
AnfAlgo
::
SetNodeAttr
(
kIsFeatureMapOutput
,
MakeValue
(
false
),
cnode
);
}
}
cnode
->
set_kernel_info
(
kernel_info
);
cnode
->
set_kernel_info
(
kernel_info
);
AnfAlgo
::
SetGraphId
(
graph_id_
,
cnode
.
get
());
AnfAlgo
::
SetGraphId
(
graph_id_
,
cnode
.
get
());
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
338d7c1a
...
@@ -151,7 +151,7 @@ constexpr auto kSquareSumAllOpName = "SquareSumAll";
...
@@ -151,7 +151,7 @@ constexpr auto kSquareSumAllOpName = "SquareSumAll";
// attr key name
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
constexpr
auto
kAttrInputNames
=
"input_names"
;
constexpr
auto
kAttrIsAICPUKernel
=
"is_
ai_cpu
_kernel"
;
constexpr
auto
kAttrIsAICPUKernel
=
"is_
AICPU
_kernel"
;
constexpr
auto
kIsBackendCast
=
"is_backed_cast"
;
constexpr
auto
kIsBackendCast
=
"is_backed_cast"
;
constexpr
auto
kAttrOutputNames
=
"output_names"
;
constexpr
auto
kAttrOutputNames
=
"output_names"
;
constexpr
auto
kAttrVisited
=
"visited"
;
constexpr
auto
kAttrVisited
=
"visited"
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录