Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5d25bf7c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5d25bf7c
编写于
6月 17, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add more transform format insert transdata
上级
b9e59f9d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
11 addition
and
12 deletion
+11
-12
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+1
-1
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+4
-7
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
...ctivate/ascend/format_type/rectify_do_mask_kernel_info.cc
+3
-3
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-1
mindspore/ops/_op_impl/tbe/trans_data.py
mindspore/ops/_op_impl/tbe/trans_data.py
+2
-0
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
5d25bf7c
...
...
@@ -70,7 +70,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
for
(
size_t
index
=
0
;
index
<
AnfAlgo
::
GetInputTensorNum
(
cnode
);
++
index
)
{
auto
pre_output_format
=
AnfAlgo
::
GetPrevNodeOutputFormat
(
cnode
,
index
);
if
(
AnfAlgo
::
IsFeatureMapInput
(
cnode
,
index
)
&&
k
NeedTransFormatSet
.
find
(
pre_output_format
)
!=
kNeedTrans
FormatSet
.
end
())
{
k
HWSpecialFormatSet
.
find
(
pre_output_format
)
!=
kHWSpecial
FormatSet
.
end
())
{
priority_matched_format
=
!
is_init
?
pre_output_format
:
priority_matched_format
;
is_init
=
true
;
}
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
5d25bf7c
...
...
@@ -31,6 +31,7 @@ namespace mindspore {
namespace
opt
{
using
KernelBuildInfoBuilder
=
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
;
namespace
{
const
std
::
set
<
std
::
string
>
kCommonFormatSet
=
{
kOpFormat_DEFAULT
,
kOpFormat_ND
,
kOpFormat_NCHW
};
AnfNodePtr
CreateReshapeNode
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
input_node
,
const
KernelSelectPtr
&
kernel_select
,
const
std
::
vector
<
size_t
>
&
dst_shape
)
{
std
::
vector
<
AnfNodePtr
>
trans_inputs
;
...
...
@@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL
(
input_node
);
AnfAlgo
::
SetNodeInput
(
node
,
input_node
,
index
);
}
if
(
AnfAlgo
::
GetInputFormat
(
node
,
index
)
==
kOpFormat_NC1KHKWHWC0
)
{
MS_LOG
(
EXCEPTION
)
<<
"got the format "
<<
AnfAlgo
::
GetInputFormat
(
node
,
index
)
<<
"when inserting the transdata node "
<<
node
->
DebugString
();
}
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
node
,
index
);
std
::
string
dest_format
=
AnfAlgo
::
GetInputFormat
(
node
,
index
);
if
(
k
NeedTransFormatSet
.
find
(
dest_format
)
!=
kNeedTrans
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
k
CommonFormatSet
.
find
(
dest_format
)
==
kCommon
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
MS_LOG
(
DEBUG
)
<<
node
->
DebugString
()
<<
"Insert transdata "
<<
AnfAlgo
::
GetInputFormat
(
node
,
index
)
<<
" To DefaultFormat , index: "
<<
index
;
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
index
,
true
);
...
...
@@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
MS_LOG
(
EXCEPTION
)
<<
"got the hw format "
<<
output_format
<<
"when insert the transdata node "
<<
node
->
DebugString
();
}
if
(
k
NeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTrans
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
k
CommonFormatSet
.
find
(
output_format
)
==
kCommon
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
MS_LOG
(
DEBUG
)
<<
"Inserted Transdata "
<<
output_format
<<
" To default , index :0"
;
return
AddTransOpNodeToGraph
(
func_graph
,
node
,
kernel_select
,
0
,
false
);
}
...
...
@@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
auto
tuple_getitem
=
CreatTupleGetItemNode
(
func_graph
,
node
,
output_idx
);
std
::
vector
<
size_t
>
origin_shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
output_idx
);
if
(
k
NeedTransFormatSet
.
find
(
output_format
)
!=
kNeedTrans
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
if
(
k
CommonFormatSet
.
find
(
output_format
)
==
kCommon
FormatSet
.
end
()
&&
origin_shape
.
size
()
>
1
)
{
make_tuple_inputs
.
emplace_back
(
AddTransOpNodeToGraph
(
func_graph
,
tuple_getitem
,
kernel_select
,
0
,
false
));
}
else
{
// No need insert trans op.
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc
浏览文件 @
5d25bf7c
...
...
@@ -97,7 +97,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
std
::
string
convert_format
;
for
(
const
auto
&
do_mask
:
do_mask_node_list
)
{
auto
do_mask_data_format
=
AnfAlgo
::
GetInputFormat
(
do_mask
,
0
);
if
(
special_format
.
empty
()
&&
k
NeedTransFormatSet
.
find
(
do_mask_data_format
)
!=
kNeedTrans
FormatSet
.
end
())
{
if
(
special_format
.
empty
()
&&
k
HWSpecialFormatSet
.
find
(
do_mask_data_format
)
!=
kHWSpecial
FormatSet
.
end
())
{
special_format
=
do_mask_data_format
;
}
if
(
format_counter
.
find
(
do_mask_data_format
)
==
format_counter
.
end
())
{
...
...
@@ -111,7 +111,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
convert_format
=
kOpFormat_DEFAULT
;
break
;
}
if
(
k
NeedTransFormatSet
.
find
(
do_mask_data_format
)
!=
kNeedTrans
FormatSet
.
end
()
&&
if
(
k
HWSpecialFormatSet
.
find
(
do_mask_data_format
)
!=
kHWSpecial
FormatSet
.
end
()
&&
special_format
!=
do_mask_data_format
)
{
convert_format
=
kOpFormat_DEFAULT
;
break
;
...
...
@@ -133,7 +133,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
if
(
counter
<
iter
.
second
)
{
convert_format
=
iter
.
first
;
}
if
(
counter
==
iter
.
second
&&
k
NeedTransFormatSet
.
find
(
convert_format
)
==
kNeedTrans
FormatSet
.
end
())
{
if
(
counter
==
iter
.
second
&&
k
HWSpecialFormatSet
.
find
(
convert_format
)
==
kHWSpecial
FormatSet
.
end
())
{
convert_format
=
iter
.
first
;
}
}
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
5d25bf7c
...
...
@@ -265,7 +265,7 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName
,
};
const
std
::
set
<
std
::
string
>
k
NeedTrans
FormatSet
=
{
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_NC1HWC0
,
const
std
::
set
<
std
::
string
>
k
HWSpecial
FormatSet
=
{
kOpFormat_FRAC_Z
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_C1HWNCoC0
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
...
...
mindspore/ops/_op_impl/tbe/trans_data.py
浏览文件 @
5d25bf7c
...
...
@@ -58,6 +58,8 @@ trans_data_op_info = TBERegOp("TransData") \
.
dtype_format
(
DataType
.
F32_HWCN
,
DataType
.
F32_FracZ
)
\
.
dtype_format
(
DataType
.
F32_HWCN
,
DataType
.
F32_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_HWCN
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_NCHW
)
\
.
dtype_format
(
DataType
.
F32_HWCN
,
DataType
.
F32_Default
)
\
.
get_op_info
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录