Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
44df45c8
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
44df45c8
编写于
5月 13, 2020
作者:
W
WilliamLian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add attr op_pattern to kernel build info
上级
074a2f34
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
67 addition
and
25 deletion
+67
-25
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+3
-2
mindspore/ccsrc/kernel/kernel.h
mindspore/ccsrc/kernel/kernel.h
+7
-0
mindspore/ccsrc/kernel/kernel_build_info.cc
mindspore/ccsrc/kernel/kernel_build_info.cc
+5
-0
mindspore/ccsrc/kernel/kernel_build_info.h
mindspore/ccsrc/kernel/kernel_build_info.h
+6
-0
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+1
-1
mindspore/ccsrc/kernel/oplib/opinfo.h
mindspore/ccsrc/kernel/oplib/opinfo.h
+4
-3
mindspore/ccsrc/kernel/oplib/oplib.cc
mindspore/ccsrc/kernel/oplib/oplib.cc
+11
-1
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+11
-14
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
...end/format_type/convert_unsupported_transnode_to_aicpu.cc
+1
-0
mindspore/ccsrc/session/anf_runtime_algorithm.cc
mindspore/ccsrc/session/anf_runtime_algorithm.cc
+10
-0
mindspore/ccsrc/session/anf_runtime_algorithm.h
mindspore/ccsrc/session/anf_runtime_algorithm.h
+2
-0
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+6
-4
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
44df45c8
...
...
@@ -425,7 +425,7 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
return
kernel_info_list
[
selected_index
];
}
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
GetAllMatchedFilteredKernelInfo
(
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
FilteredKernelInfoByDtype
(
const
CNodePtr
&
cnode
,
const
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
&
kernel_info_list
)
{
std
::
vector
<
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>>
result
;
for
(
const
auto
&
kernel_build_info
:
kernel_info_list
)
{
...
...
@@ -474,7 +474,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
>
selected_kernel_info
=
nullptr
;
// Matched kernel info
// Filter kernel info matched with me infered type
auto
filtered_kernel_info_list
=
GetAllMatchedFilteredKernelInfo
(
kernel_node
,
kernel_info_list
);
auto
filtered_kernel_info_list
=
FilteredKernelInfoByDtype
(
kernel_node
,
kernel_info_list
);
if
(
!
filtered_kernel_info_list
.
empty
())
{
selected_kernel_info
=
ChooseMatchedKernelInfo
(
kernel_node
,
filtered_kernel_info_list
);
select_status
=
kStatusAllMatched
;
...
...
@@ -508,6 +508,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
<<
"] cannot find valid TBE kernel info, try to get aicpu kernel info"
;
kernel
::
AICpuQuery
(
kernel_node
,
&
kernel_info_list
);
select_status
=
SetMatchedKernelInfo
(
kernel_node
,
kernel_info_list
);
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
kernel_node
);
}
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
if
(
select_status
==
kNoMatched
)
{
...
...
mindspore/ccsrc/kernel/kernel.h
浏览文件 @
44df45c8
...
...
@@ -47,6 +47,13 @@ enum FusionType {
OPAQUE
,
UNKNOWN_FUSION_TYPE
=
-
1
,
};
enum
OpPattern
{
kCommonPattern
=
0
,
kFormatAgnosticPattern
=
1
,
kBroadcastPattern
=
2
,
kReducePattern
=
3
,
kDynamicFormatPattern
=
4
,
};
// Backend processor
enum
Processor
{
...
...
mindspore/ccsrc/kernel/kernel_build_info.cc
浏览文件 @
44df45c8
...
...
@@ -162,5 +162,10 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
output_reshape_type_
=
output_reshape_type
;
}
void
KernelBuildInfo
::
KernelBuildInfoBuilder
::
SetOpPattern
(
OpPattern
pattern
)
{
MS_EXCEPTION_IF_NULL
(
kernel_build_info_
);
kernel_build_info_
->
op_pattern_
=
pattern
;
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/kernel_build_info.h
浏览文件 @
44df45c8
...
...
@@ -34,6 +34,7 @@ class KernelBuildInfo {
kernel_type_
=
AUTO_DIFF_KERNEL
;
fusion_type_
=
OPAQUE
;
processor_
=
AICORE
;
op_pattern_
=
kCommonPattern
;
input_reshape_type_
=
{};
output_reshape_type_
=
{};
inputs_format_
=
{};
...
...
@@ -70,6 +71,8 @@ class KernelBuildInfo {
std
::
vector
<
TypeId
>
GetAllOutputDeviceTypes
()
const
;
OpPattern
op_pattern
()
const
{
return
op_pattern_
;
}
FusionType
fusion_type
()
const
{
return
fusion_type_
;
}
Processor
processor
()
const
{
return
processor_
;
}
...
...
@@ -88,6 +91,7 @@ class KernelBuildInfo {
private:
KernelType
kernel_type_
;
std
::
vector
<
std
::
string
>
inputs_format_
;
OpPattern
op_pattern_
;
std
::
vector
<
std
::
string
>
outputs_format_
;
std
::
vector
<
std
::
vector
<
Axis
>>
input_reshape_type_
;
std
::
vector
<
std
::
vector
<
Axis
>>
output_reshape_type_
;
...
...
@@ -125,6 +129,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void
SetProcessor
(
Processor
processor
);
void
SetOpPattern
(
OpPattern
pattern
);
std
::
shared_ptr
<
KernelBuildInfo
>
Build
();
private:
...
...
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
44df45c8
...
...
@@ -40,7 +40,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
(
void
)
std
::
copy
(
filtered_list
.
begin
(),
filtered_list
.
end
(),
std
::
back_inserter
(
*
kernel_info_list
));
}
else
{
MS_LOG
(
WARNING
)
<<
"All kernel Info list does not match any kernel info "
;
for
(
size_t
index
;
index
<
kernel_info_list
->
size
();
++
index
)
{
for
(
size_t
index
=
0
;
index
<
kernel_info_list
->
size
();
++
index
)
{
MS_EXCEPTION_IF_NULL
(
kernel_info_list
->
at
(
index
));
MS_LOG
(
WARNING
)
<<
"kernel [ "
<<
index
<<
" ] :"
<<
kernel_info_list
->
at
(
index
)
->
ToString
();
}
...
...
mindspore/ccsrc/kernel/oplib/opinfo.h
浏览文件 @
44df45c8
...
...
@@ -21,6 +21,7 @@
#include <memory>
#include <unordered_map>
#include "ir/dtype.h"
#include "kernel/kernel.h"
namespace
mindspore
{
namespace
kernel
{
...
...
@@ -100,7 +101,7 @@ class OpInfo {
std
::
string
kernel_name
()
const
{
return
kernel_name_
;
}
bool
partial_flag
()
const
{
return
partial_flag_
;
}
bool
dynamic_format
()
const
{
return
dynamic_format_
;
}
std
::
string
op_pattern
()
const
{
return
op_pattern_
;
}
OpPattern
op_pattern
()
const
{
return
op_pattern_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr
()
const
{
return
attrs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr
()
const
{
return
inputs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr
()
const
{
return
outputs_ptr_
;
}
...
...
@@ -116,7 +117,7 @@ class OpInfo {
void
set_kernel_name
(
const
std
::
string
&
kernel_name
)
{
kernel_name_
=
kernel_name
;
}
void
set_partial_flag
(
const
bool
partial_flag
)
{
partial_flag_
=
partial_flag
;
}
void
set_dynamic_format
(
const
bool
dynamic_format
)
{
dynamic_format_
=
dynamic_format
;
}
void
set_op_pattern
(
const
std
::
string
op_pattern
)
{
op_pattern_
=
op_pattern
;
}
void
set_op_pattern
(
const
OpPattern
op_pattern
)
{
op_pattern_
=
op_pattern
;
}
void
add_attrs_ptr
(
const
std
::
shared_ptr
<
OpAttr
>
&
attr
)
{
attrs_ptr_
.
push_back
(
attr
);
}
void
add_inputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>
&
input
)
{
inputs_ptr_
.
push_back
(
input
);
}
void
add_outputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>
&
output
)
{
outputs_ptr_
.
push_back
(
output
);
}
...
...
@@ -137,7 +138,7 @@ class OpInfo {
std
::
string
kernel_name_
;
bool
partial_flag_
=
false
;
bool
dynamic_format_
=
false
;
std
::
string
op_pattern_
;
OpPattern
op_pattern_
=
kCommonPattern
;
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr_
;
...
...
mindspore/ccsrc/kernel/oplib/oplib.cc
浏览文件 @
44df45c8
...
...
@@ -18,6 +18,7 @@
#include <pybind11/pybind11.h>
#include <unordered_map>
#include <memory>
#include <map>
#include "utils/log_adapter.h"
#include "utils/overload.h"
#include "utils/context/ms_context.h"
...
...
@@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag";
constexpr
auto
kReshapeType
=
"reshape_type"
;
constexpr
auto
kOpPattern
=
"op_pattern"
;
constexpr
auto
kDynamicFormat
=
"dynamic_format"
;
constexpr
auto
kFormatAgnostic
=
"formatAgnostic"
;
constexpr
auto
kBroadcast
=
"broadcast"
;
constexpr
auto
kReduce
=
"reduce"
;
constexpr
auto
kDtypeFormat
=
"dtype_format"
;
constexpr
auto
kAttr
=
"attr"
;
constexpr
auto
kIputs
=
"inputs"
;
...
...
@@ -95,13 +99,19 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
}
void
OpLib
::
DecodeTBESpecificInfo
(
const
nlohmann
::
json
&
obj
,
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
)
{
const
std
::
map
<
std
::
string
,
kernel
::
OpPattern
>
kOpPatternMap
=
{{
kFormatAgnostic
,
kFormatAgnosticPattern
},
{
kFormatAgnostic
,
kBroadcastPattern
},
{
kReduce
,
kReducePattern
},
{
kDynamicFormat
,
kDynamicFormatPattern
}};
op_info
->
set_async_flag
(
obj
.
at
(
kAsyncFlag
));
op_info
->
set_binfile_name
(
obj
.
at
(
kBinfileName
));
op_info
->
set_compute_cost
(
obj
.
at
(
kComputeCost
));
op_info
->
set_kernel_name
(
obj
.
at
(
kKernelName
));
op_info
->
set_partial_flag
(
obj
.
at
(
kPartialFlag
));
if
(
obj
.
find
(
kOpPattern
)
!=
obj
.
end
())
{
op_info
->
set_op_pattern
(
obj
.
at
(
kOpPattern
));
if
(
kOpPatternMap
.
find
(
obj
.
at
(
kOpPattern
))
!=
kOpPatternMap
.
end
())
{
op_info
->
set_op_pattern
(
obj
.
at
(
kOpPattern
));
}
}
if
(
obj
.
find
(
kDynamicFormat
)
!=
obj
.
end
())
{
op_info
->
set_dynamic_format
(
obj
.
at
(
kDynamicFormat
));
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
浏览文件 @
44df45c8
...
...
@@ -492,6 +492,7 @@ void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuild
if
(
tbe
::
GetFusionType
(
fusion_type
)
!=
UNKNOWN_FUSION_TYPE
)
{
builder
->
SetFusionType
(
tbe
::
GetFusionType
(
fusion_type
));
}
builder
->
SetOpPattern
(
op_info_ptr
->
op_pattern
());
builder
->
SetKernelType
(
TBE_KERNEL
);
}
...
...
@@ -509,7 +510,7 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
if
(
primitive
->
GetAttr
(
"dyn_input_sizes"
)
!=
nullptr
)
{
dyn_input_sizes
=
GetValue
<
std
::
vector
<
int
>>
(
primitive
->
GetAttr
(
"dyn_input_sizes"
));
}
if
(
inputs
.
size
()
>
0
)
{
if
(
!
inputs
.
empty
()
)
{
MS_EXCEPTION_IF_NULL
(
inputs
[
0
]);
size_t
kernel_info_cnt
=
inputs
[
0
]
->
dtypes
().
size
();
for
(
size_t
j
=
0
;
j
<
kernel_info_cnt
;
j
++
)
{
...
...
@@ -624,21 +625,17 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
for
(
auto
parse_info
:
parse_info_list
)
{
if
(
context_ptr
->
execution_mode
()
==
kPynativeMode
)
{
kernel_info_list
->
push_back
(
parse_info
);
}
else
{
if
(
IsValidKernelInfo
(
kernel_node
,
*
(
parse_info
)))
{
if
(
CheckSupported
(
kernel_node
,
parse_info
))
{
kernel_info_list
->
push_back
(
parse_info
);
}
else
{
MS_LOG
(
INFO
)
<<
"CheckSupported Failed for TBE op"
<<
op_name
<<
" kernel info."
;
}
for
(
const
auto
&
parse_info
:
parse_info_list
)
{
if
(
IsValidKernelInfo
(
kernel_node
,
*
(
parse_info
)))
{
if
(
CheckSupported
(
kernel_node
,
parse_info
))
{
kernel_info_list
->
push_back
(
parse_info
);
}
else
{
MS_LOG
(
INFO
)
<<
"CheckSupported Failed for TBE op"
<<
op_name
<<
" kernel info."
;
}
}
}
if
(
kernel_info_list
->
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Tbe dose not have op ["
<<
op_name
<<
"]."
;
if
(
kernel_info_list
->
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Tbe dose not have op ["
<<
op_name
<<
"]."
;
}
}
}
}
// namespace kernel
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
浏览文件 @
44df45c8
...
...
@@ -44,6 +44,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
auto
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
(
kernel_builder_info
);
builder
->
SetKernelType
(
AICPU_KERNEL
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
node
.
get
());
AnfAlgo
::
SetNodeAttr
(
kAttrIsAICPUKernel
,
MakeValue
(
true
),
node
);
}
else
{
MS_LOG
(
EXCEPTION
)
<<
" kernel "
<<
kernel_builder_info
->
ToString
()
<<
"is not supported in AiCPU & AiCore : node ["
<<
node
->
DebugString
()
<<
"]"
;
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
44df45c8
...
...
@@ -657,6 +657,16 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_
to_node
->
set_abstract
(
from_node
->
abstract
());
}
kernel
::
OpPattern
AnfRuntimeAlgorithm
::
GetOpPattern
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
// select_kernel_build_info() has checked whether return pointer is null
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
return
build_info
->
op_pattern
();
}
// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType
AnfRuntimeAlgorithm
::
GetKernelType
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
...
...
mindspore/ccsrc/session/anf_runtime_algorithm.h
浏览文件 @
44df45c8
...
...
@@ -138,6 +138,8 @@ class AnfRuntimeAlgorithm {
static
void
SetOutputInferTypeAndShape
(
const
std
::
vector
<
TypeId
>
&
types
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
shapes
,
AnfNode
*
node
);
static
void
CopyAbstract
(
const
AnfNodePtr
&
from_node
,
AnfNode
*
to_node
);
// get op pattern of the node
static
kernel
::
OpPattern
GetOpPattern
(
const
AnfNodePtr
&
node
);
// get KernelBuildType of node ,such as ATT,RT,FWK and so on
static
KernelType
GetKernelType
(
const
AnfNodePtr
&
node
);
// get processor type:AICORE,AICPU...
...
...
mindspore/ccsrc/utils/utils.h
浏览文件 @
44df45c8
...
...
@@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
// attr key name
constexpr
auto
kAttrInputNames
=
"input_names"
;
constexpr
auto
kAttrIsAICPUKernel
=
"is_ai_cpu_kernel"
;
constexpr
auto
kIsBackendCast
=
"is_backed_cast"
;
constexpr
auto
kAttrOutputNames
=
"output_names"
;
constexpr
auto
kAttrVisited
=
"visited"
;
...
...
@@ -215,10 +216,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr
auto
kOpFormat_C1HWNCoC0
=
"C1HWNCoC0"
;
constexpr
auto
kOpFormat_NC1HWC0_C04
=
"NC1HWC0_C04"
;
constexpr
auto
kOpFormat_FRACTAL_Z_C04
=
"FRACTAL_Z_C04"
;
const
std
::
set
<
std
::
string
>
kOpFormatList
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_Z
,
kOpFormat_C1HWNCoC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
};
constexpr
auto
kOpFormat_NDHWC
=
"NDHWC"
;
const
std
::
set
<
std
::
string
>
kOpFormatList
=
{
kOpFormat_DEFAULT
,
kOpFormat_NC1KHKWHWC0
,
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
,
kOpFormat_NC1HWC0
,
kOpFormat_FRAC_Z
,
kOpFormat_C1HWNCoC0
,
kOpFormat_FRAC_NZ
,
kOpFormat_NC1HWC0_C04
,
kOpFormat_FRACTAL_Z_C04
,
kOpFormat_NDHWC
};
const
std
::
set
<
std
::
string
>
kDefaultCompatibleFormat
=
{
kOpFormat_ND
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_HWCN
};
const
std
::
set
<
std
::
string
>
kOptOperatorSet
=
{
kMomentumOpName
,
kApplyMomentumOpName
,
kApplyAdadeltaOpName
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录