Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
51ca769d
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
51ca769d
编写于
4月 01, 2020
作者:
Z
zjun
提交者:
高东海
4月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add new mode for operator info register
上级
8f74aef6
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
535 addition
and
215 deletion
+535
-215
mindspore/ccsrc/kernel/oplib/opinfo.h
mindspore/ccsrc/kernel/oplib/opinfo.h
+9
-0
mindspore/ccsrc/kernel/oplib/oplib.cc
mindspore/ccsrc/kernel/oplib/oplib.cc
+52
-9
mindspore/ccsrc/kernel/oplib/oplib.h
mindspore/ccsrc/kernel/oplib/oplib.h
+3
-1
mindspore/ops/__init__.py
mindspore/ops/__init__.py
+2
-2
mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py
mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py
+32
-199
mindspore/ops/op_info_register.py
mindspore/ops/op_info_register.py
+437
-4
未找到文件。
mindspore/ccsrc/kernel/oplib/opinfo.h
浏览文件 @
51ca769d
...
@@ -61,6 +61,7 @@ class OpIOInfo {
...
@@ -61,6 +61,7 @@ class OpIOInfo {
std
::
string
name
()
const
{
return
name_
;
}
std
::
string
name
()
const
{
return
name_
;
}
bool
need_compile
()
const
{
return
need_compile_
;
}
bool
need_compile
()
const
{
return
need_compile_
;
}
std
::
string
param_type
()
const
{
return
param_type_
;
}
std
::
string
param_type
()
const
{
return
param_type_
;
}
std
::
string
reshape_type
()
const
{
return
reshape_type_
;
}
std
::
string
shape
()
const
{
return
shape_
;
}
std
::
string
shape
()
const
{
return
shape_
;
}
std
::
vector
<
std
::
string
>
dtypes
()
const
{
return
dtypes_
;
}
std
::
vector
<
std
::
string
>
dtypes
()
const
{
return
dtypes_
;
}
std
::
vector
<
std
::
string
>
formats
()
const
{
return
formats_
;
}
std
::
vector
<
std
::
string
>
formats
()
const
{
return
formats_
;
}
...
@@ -69,6 +70,7 @@ class OpIOInfo {
...
@@ -69,6 +70,7 @@ class OpIOInfo {
void
set_name
(
const
std
::
string
&
name
)
{
name_
=
name
;
}
void
set_name
(
const
std
::
string
&
name
)
{
name_
=
name
;
}
void
set_need_compile
(
const
bool
need_compile
)
{
need_compile_
=
need_compile
;
}
void
set_need_compile
(
const
bool
need_compile
)
{
need_compile_
=
need_compile
;
}
void
set_param_type
(
const
std
::
string
&
param_type
)
{
param_type_
=
param_type
;
}
void
set_param_type
(
const
std
::
string
&
param_type
)
{
param_type_
=
param_type
;
}
void
set_reshape_type
(
const
std
::
string
&
reshape_type
)
{
reshape_type_
=
reshape_type
;
}
void
set_shape
(
const
std
::
string
&
shape
)
{
shape_
=
shape
;
}
void
set_shape
(
const
std
::
string
&
shape
)
{
shape_
=
shape
;
}
void
set_dtypes
(
const
std
::
vector
<
std
::
string
>&
dtype
)
{
dtypes_
=
dtype
;
}
void
set_dtypes
(
const
std
::
vector
<
std
::
string
>&
dtype
)
{
dtypes_
=
dtype
;
}
void
set_formats
(
const
std
::
vector
<
std
::
string
>&
formats
)
{
formats_
=
formats
;
}
void
set_formats
(
const
std
::
vector
<
std
::
string
>&
formats
)
{
formats_
=
formats
;
}
...
@@ -78,6 +80,7 @@ class OpIOInfo {
...
@@ -78,6 +80,7 @@ class OpIOInfo {
std
::
string
name_
;
std
::
string
name_
;
bool
need_compile_
=
false
;
bool
need_compile_
=
false
;
std
::
string
param_type_
;
std
::
string
param_type_
;
std
::
string
reshape_type_
;
std
::
string
shape_
;
std
::
string
shape_
;
std
::
vector
<
std
::
string
>
dtypes_
;
std
::
vector
<
std
::
string
>
dtypes_
;
std
::
vector
<
std
::
string
>
formats_
;
std
::
vector
<
std
::
string
>
formats_
;
...
@@ -96,6 +99,8 @@ class OpInfo {
...
@@ -96,6 +99,8 @@ class OpInfo {
int
compute_cost
()
const
{
return
compute_cost_
;
}
int
compute_cost
()
const
{
return
compute_cost_
;
}
std
::
string
kernel_name
()
const
{
return
kernel_name_
;
}
std
::
string
kernel_name
()
const
{
return
kernel_name_
;
}
bool
partial_flag
()
const
{
return
partial_flag_
;
}
bool
partial_flag
()
const
{
return
partial_flag_
;
}
bool
dynamic_format
()
const
{
return
dynamic_format_
;
}
std
::
string
op_pattern
()
const
{
return
op_pattern_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr
()
const
{
return
attrs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr
()
const
{
return
attrs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr
()
const
{
return
inputs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr
()
const
{
return
inputs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr
()
const
{
return
outputs_ptr_
;
}
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr
()
const
{
return
outputs_ptr_
;
}
...
@@ -110,6 +115,8 @@ class OpInfo {
...
@@ -110,6 +115,8 @@ class OpInfo {
void
set_compute_cost
(
const
int
compute_cost
)
{
compute_cost_
=
compute_cost
;
}
void
set_compute_cost
(
const
int
compute_cost
)
{
compute_cost_
=
compute_cost
;
}
void
set_kernel_name
(
const
std
::
string
&
kernel_name
)
{
kernel_name_
=
kernel_name
;
}
void
set_kernel_name
(
const
std
::
string
&
kernel_name
)
{
kernel_name_
=
kernel_name
;
}
void
set_partial_flag
(
const
bool
partial_flag
)
{
partial_flag_
=
partial_flag
;
}
void
set_partial_flag
(
const
bool
partial_flag
)
{
partial_flag_
=
partial_flag
;
}
void
set_dynamic_format
(
const
bool
dynamic_format
)
{
dynamic_format_
=
dynamic_format
;
}
void
set_op_pattern
(
const
std
::
string
op_pattern
)
{
op_pattern_
=
op_pattern
;
}
void
add_attrs_ptr
(
const
std
::
shared_ptr
<
OpAttr
>&
attr
)
{
attrs_ptr_
.
push_back
(
attr
);
}
void
add_attrs_ptr
(
const
std
::
shared_ptr
<
OpAttr
>&
attr
)
{
attrs_ptr_
.
push_back
(
attr
);
}
void
add_inputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>&
input
)
{
inputs_ptr_
.
push_back
(
input
);
}
void
add_inputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>&
input
)
{
inputs_ptr_
.
push_back
(
input
);
}
void
add_outputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>&
output
)
{
outputs_ptr_
.
push_back
(
output
);
}
void
add_outputs_ptr
(
const
std
::
shared_ptr
<
OpIOInfo
>&
output
)
{
outputs_ptr_
.
push_back
(
output
);
}
...
@@ -129,6 +136,8 @@ class OpInfo {
...
@@ -129,6 +136,8 @@ class OpInfo {
int
compute_cost_
=
0
;
int
compute_cost_
=
0
;
std
::
string
kernel_name_
;
std
::
string
kernel_name_
;
bool
partial_flag_
=
false
;
bool
partial_flag_
=
false
;
bool
dynamic_format_
=
false
;
std
::
string
op_pattern_
;
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpAttr
>>
attrs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
inputs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr_
;
std
::
vector
<
std
::
shared_ptr
<
OpIOInfo
>>
outputs_ptr_
;
...
...
mindspore/ccsrc/kernel/oplib/oplib.cc
浏览文件 @
51ca769d
...
@@ -26,18 +26,22 @@ namespace mindspore {
...
@@ -26,18 +26,22 @@ namespace mindspore {
namespace
kernel
{
namespace
kernel
{
constexpr
auto
kImplyType
=
"imply_type"
;
constexpr
auto
kImplyType
=
"imply_type"
;
constexpr
auto
kOpName
=
"op_name"
;
constexpr
auto
kOpName
=
"op_name"
;
constexpr
auto
kTbe
=
"TBE"
;
constexpr
auto
kAkg
=
"akg"
;
constexpr
auto
kAutodiff
=
"AutoDiff"
;
constexpr
auto
kFusionType
=
"fusion_type"
;
constexpr
auto
kFusionType
=
"fusion_type"
;
constexpr
auto
kAsyncFlag
=
"async_flag"
;
constexpr
auto
kAsyncFlag
=
"async_flag"
;
constexpr
auto
kBinfileName
=
"binfile_name"
;
constexpr
auto
kBinfileName
=
"binfile_name"
;
constexpr
auto
kComputeCost
=
"compute_cost"
;
constexpr
auto
kComputeCost
=
"compute_cost"
;
constexpr
auto
kKernelName
=
"kernel_name"
;
constexpr
auto
kKernelName
=
"kernel_name"
;
constexpr
auto
kPartialFlag
=
"partial_flag"
;
constexpr
auto
kPartialFlag
=
"partial_flag"
;
constexpr
auto
kReshapeType
=
"reshape_type"
;
constexpr
auto
kOpPattern
=
"op_pattern"
;
constexpr
auto
kDynamicFormat
=
"dynamic_format"
;
constexpr
auto
kDtypeFormat
=
"dtype_format"
;
constexpr
auto
kAttr
=
"attr"
;
constexpr
auto
kAttr
=
"attr"
;
constexpr
auto
kIputs
=
"inputs"
;
constexpr
auto
kIputs
=
"inputs"
;
constexpr
auto
kOutputs
=
"outputs"
;
constexpr
auto
kOutputs
=
"outputs"
;
constexpr
auto
kTbe
=
"TBE"
;
constexpr
auto
kAkg
=
"akg"
;
constexpr
auto
kAutodiff
=
"AutoDiff"
;
constexpr
auto
kName
=
"name"
;
constexpr
auto
kName
=
"name"
;
constexpr
auto
kParamType
=
"param_type"
;
constexpr
auto
kParamType
=
"param_type"
;
constexpr
auto
kDtype
=
"dtype"
;
constexpr
auto
kDtype
=
"dtype"
;
...
@@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
...
@@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
std
::
shared_ptr
<
OpInfo
>
op_info
=
std
::
make_shared
<
OpInfo
>
();
std
::
shared_ptr
<
OpInfo
>
op_info
=
std
::
make_shared
<
OpInfo
>
();
MS_EXCEPTION_IF_NULL
(
op_info
);
MS_EXCEPTION_IF_NULL
(
op_info
);
op_info
->
set_op_name
(
obj
.
at
(
kOpName
));
op_info
->
set_op_name
(
obj
.
at
(
kOpName
));
op_info
->
set_imply_type
(
imply_type
);
op_info
->
set_impl_path
(
impl_path
);
op_info
->
set_impl_path
(
impl_path
);
op_info
->
set_imply_type
(
imply_type
);
op_info
->
set_fusion_type
(
obj
.
at
(
kFusionType
));
op_info
->
set_fusion_type
(
obj
.
at
(
kFusionType
));
if
(
imply_type
==
kTBE
)
{
if
(
imply_type
==
kTBE
)
{
op_info
->
set_async_flag
(
obj
.
at
(
kAsyncFlag
));
op_info
->
set_async_flag
(
obj
.
at
(
kAsyncFlag
));
...
@@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
...
@@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
op_info
->
set_compute_cost
(
obj
.
at
(
kComputeCost
));
op_info
->
set_compute_cost
(
obj
.
at
(
kComputeCost
));
op_info
->
set_kernel_name
(
obj
.
at
(
kKernelName
));
op_info
->
set_kernel_name
(
obj
.
at
(
kKernelName
));
op_info
->
set_partial_flag
(
obj
.
at
(
kPartialFlag
));
op_info
->
set_partial_flag
(
obj
.
at
(
kPartialFlag
));
if
(
obj
.
find
(
kOpPattern
)
!=
obj
.
end
())
{
op_info
->
set_op_pattern
(
obj
.
at
(
kOpPattern
));
}
if
(
obj
.
find
(
kDynamicFormat
)
!=
obj
.
end
())
{
op_info
->
set_dynamic_format
(
obj
.
at
(
kDynamicFormat
));
}
}
}
auto
attrs
=
obj
.
at
(
kAttr
);
auto
attrs
=
obj
.
at
(
kAttr
);
for
(
const
auto
&
attr
:
attrs
)
{
for
(
const
auto
&
attr
:
attrs
)
{
...
@@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
...
@@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
return
false
;
return
false
;
}
}
}
}
nlohmann
::
json
dtype_format
;
if
(
obj
.
find
(
kDtypeFormat
)
!=
obj
.
end
())
{
dtype_format
=
obj
.
at
(
kDtypeFormat
);
}
auto
inputs
=
obj
.
at
(
kIputs
);
auto
inputs
=
obj
.
at
(
kIputs
);
for
(
const
auto
&
input
:
inputs
)
{
for
(
const
auto
&
input
:
inputs
)
{
if
(
!
DecodeInputOutput
(
input
,
imply_type
,
kInput
,
op_info
))
{
if
(
!
DecodeInputOutput
(
input
,
imply_type
,
kInput
,
op_info
,
dtype_format
))
{
MS_LOG
(
DEBUG
)
<<
"DecodeInputOutput Failed"
;
MS_LOG
(
DEBUG
)
<<
"DecodeInputOutput Failed"
;
return
false
;
return
false
;
}
}
}
}
auto
outputs
=
obj
.
at
(
kOutputs
);
auto
outputs
=
obj
.
at
(
kOutputs
);
for
(
const
auto
&
output
:
outputs
)
{
for
(
const
auto
&
output
:
outputs
)
{
if
(
!
DecodeInputOutput
(
output
,
imply_type
,
kOutput
,
op_info
))
{
if
(
!
DecodeInputOutput
(
output
,
imply_type
,
kOutput
,
op_info
,
dtype_format
))
{
MS_LOG
(
DEBUG
)
<<
"DecodeInputOutput Failed"
;
MS_LOG
(
DEBUG
)
<<
"DecodeInputOutput Failed"
;
return
false
;
return
false
;
}
}
...
@@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
...
@@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
return
ret
;
return
ret
;
}
}
bool
OpLib
::
DecodeDtypeFormat
(
const
nlohmann
::
json
&
dtype_format
,
const
std
::
shared_ptr
<
OpIOInfo
>&
op_io
,
size_t
index
)
{
bool
ret
=
true
;
try
{
std
::
vector
<
std
::
string
>
dtype
;
std
::
vector
<
std
::
string
>
format
;
for
(
const
auto
&
it
:
dtype_format
)
{
dtype
.
emplace_back
(
it
[
index
][
0
]);
format
.
emplace_back
(
it
[
index
][
1
]);
}
op_io
->
set_dtypes
(
dtype
);
op_io
->
set_formats
(
format
);
}
catch
(
const
std
::
exception
&
e
)
{
MS_LOG
(
ERROR
)
<<
"DecodeDtypeFormat falied"
<<
e
.
what
();
ret
=
false
;
}
return
ret
;
}
bool
OpLib
::
DecodeInputOutput
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
OpIOType
io_type
,
bool
OpLib
::
DecodeInputOutput
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
OpIOType
io_type
,
const
std
::
shared_ptr
<
OpInfo
>&
op_info
)
{
const
std
::
shared_ptr
<
OpInfo
>&
op_info
,
const
nlohmann
::
json
&
dtype_format
)
{
bool
ret
=
true
;
bool
ret
=
true
;
try
{
try
{
std
::
shared_ptr
<
OpIOInfo
>
op_io
=
std
::
make_shared
<
OpIOInfo
>
();
std
::
shared_ptr
<
OpIOInfo
>
op_io
=
std
::
make_shared
<
OpIOInfo
>
();
MS_EXCEPTION_IF_NULL
(
op_io
);
MS_EXCEPTION_IF_NULL
(
op_io
);
op_io
->
set_index
(
obj
.
at
(
kIndex
));
op_io
->
set_index
(
obj
.
at
(
kIndex
));
op_io
->
set_name
(
obj
.
at
(
kName
));
op_io
->
set_name
(
obj
.
at
(
kName
));
op_io
->
set_dtypes
(
obj
.
at
(
kDtype
));
if
(
!
dtype_format
.
empty
())
{
op_io
->
set_formats
(
obj
.
at
(
kFormat
));
if
(
!
DecodeDtypeFormat
(
dtype_format
,
op_io
,
op_info
->
inputs_ptr
().
size
()
+
op_info
->
outputs_ptr
().
size
()))
{
MS_LOG
(
ERROR
)
<<
"Decode dtype format failed"
;
return
false
;
}
}
else
{
op_io
->
set_dtypes
(
obj
.
at
(
kDtype
));
op_io
->
set_formats
(
obj
.
at
(
kFormat
));
}
if
(
op_io
->
dtypes
().
size
()
!=
op_io
->
formats
().
size
())
{
if
(
op_io
->
dtypes
().
size
()
!=
op_io
->
formats
().
size
())
{
MS_LOG
(
DEBUG
)
<<
"op"
<<
op_io
->
name
()
<<
"dtype size:"
<<
op_io
->
dtypes
()
MS_LOG
(
DEBUG
)
<<
"op"
<<
op_io
->
name
()
<<
"dtype size:"
<<
op_io
->
dtypes
()
<<
"is not equal to format size:"
<<
op_io
->
formats
();
<<
"is not equal to format size:"
<<
op_io
->
formats
();
...
@@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
...
@@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
if
(
obj
.
find
(
kShape
)
!=
obj
.
end
())
{
if
(
obj
.
find
(
kShape
)
!=
obj
.
end
())
{
op_io
->
set_shape
(
obj
.
at
(
kShape
));
op_io
->
set_shape
(
obj
.
at
(
kShape
));
}
}
if
(
obj
.
find
(
kReshapeType
)
!=
obj
.
end
())
{
op_io
->
set_reshape_type
(
obj
.
at
(
kReshapeType
));
}
}
}
if
(
io_type
==
kInput
)
{
if
(
io_type
==
kInput
)
{
...
...
mindspore/ccsrc/kernel/oplib/oplib.h
浏览文件 @
51ca769d
...
@@ -38,8 +38,10 @@ class OpLib {
...
@@ -38,8 +38,10 @@ class OpLib {
static
bool
DecodeOpInfo
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
std
::
string
&
impl_path
);
static
bool
DecodeOpInfo
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
std
::
string
&
impl_path
);
static
bool
DecodeAttr
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
static
bool
DecodeAttr
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
static
bool
DecodeDtypeFormat
(
const
nlohmann
::
json
&
dtype_format
,
const
std
::
shared_ptr
<
OpIOInfo
>&
op_io
,
size_t
index
);
static
bool
DecodeInputOutput
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
OpIOType
io_type
,
static
bool
DecodeInputOutput
(
const
nlohmann
::
json
&
obj
,
const
OpImplyType
imply_type
,
const
OpIOType
io_type
,
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
const
std
::
shared_ptr
<
OpInfo
>&
op_info
,
const
nlohmann
::
json
&
dtype_format
);
static
bool
GetRefInfo
(
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
static
bool
GetRefInfo
(
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
static
bool
CheckRepetition
(
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
static
bool
CheckRepetition
(
const
std
::
shared_ptr
<
OpInfo
>&
op_info
);
};
};
...
...
mindspore/ops/__init__.py
浏览文件 @
51ca769d
...
@@ -30,7 +30,7 @@ Note:
...
@@ -30,7 +30,7 @@ Note:
from
.primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
.primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
.vm_impl_registry
import
get_vm_impl_fn
,
vm_impl_registry
from
.vm_impl_registry
import
get_vm_impl_fn
,
vm_impl_registry
from
.op_info_register
import
op_info_register
from
.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
from
.primitive
import
constexpr
from
.primitive
import
constexpr
from
.._c_expression
import
signature_rw
,
signature_kind
from
.._c_expression
import
signature_rw
,
signature_kind
...
@@ -40,6 +40,6 @@ __primitive__ = [
...
@@ -40,6 +40,6 @@ __primitive__ = [
]
]
__all__
=
[
"get_vm_impl_fn"
,
"vm_impl_registry"
,
__all__
=
[
"get_vm_impl_fn"
,
"vm_impl_registry"
,
"op_info_register"
,
"op_info_register"
,
"TBERegOp"
,
"DataType"
,
"constexpr"
]
"constexpr"
]
__all__
.
extend
(
__primitive__
)
__all__
.
extend
(
__primitive__
)
mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py
浏览文件 @
51ca769d
...
@@ -14,208 +14,41 @@
...
@@ -14,208 +14,41 @@
# ============================================================================
# ============================================================================
"""AdamApplyOneWithDecay op"""
"""AdamApplyOneWithDecay op"""
from
mindspore.ops.op_info_register
import
op_info_register
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
adam_apply_one_with_decay_op_info
=
TBERegOp
(
"AdamApplyOneWithDecay"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"adam_apply_one_with_decay.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"adam_apply_one_with_decay"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"input0"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"input1"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"input2"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"input3"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"input4"
,
False
,
"required"
,
"all"
)
\
.
input
(
5
,
"mul0_x"
,
False
,
"required"
,
"all"
)
\
.
input
(
6
,
"mul1_x"
,
False
,
"required"
,
"all"
)
\
.
input
(
7
,
"mul2_x"
,
False
,
"required"
,
"all"
)
\
.
input
(
8
,
"mul3_x"
,
False
,
"required"
,
"all"
)
\
.
input
(
9
,
"mul4_x"
,
False
,
"required"
,
"all"
)
\
.
input
(
10
,
"add2_y"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"output0"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"output1"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"output2"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
"""{
"op_name": "AdamApplyOneWithDecay",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "adam_apply_one_with_decay.so",
"compute_cost": 10,
"kernel_name": "adam_apply_one_with_decay",
"partial_flag": true,
"attr": [
],
@
op_info_register
(
adam_apply_one_with_decay_op_info
)
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input0",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input3",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input4",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul0_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 6,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul1_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 7,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul2_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 8,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul3_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 9,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "mul4_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 10,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "add2_y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output0",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output1",
"need_compile": true,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output2",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}"""
)
def
_adam_apply_one_with_decay_tbe
():
def
_adam_apply_one_with_decay_tbe
():
"""AdamApplyOneWithDecay TBE register"""
"""AdamApplyOneWithDecay TBE register"""
return
return
mindspore/ops/op_info_register.py
浏览文件 @
51ca769d
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
"""Operators info register."""
"""Operators info register."""
import
os
import
os
import
json
import
inspect
import
inspect
from
mindspore._c_expression
import
Oplib
from
mindspore._c_expression
import
Oplib
from
mindspore._checkparam
import
ParamValidator
as
validator
from
mindspore._checkparam
import
ParamValidator
as
validator
...
@@ -32,21 +33,453 @@ def op_info_register(op_info):
...
@@ -32,21 +33,453 @@ def op_info_register(op_info):
'op_info' must be a str of json format represent the op info, the op info will be added into oplib.
'op_info' must be a str of json format represent the op info, the op info will be added into oplib.
Args:
Args:
op_info (str): op info of json format.
op_info (str
or dict
): op info of json format.
Returns:
Returns:
Function, returns a decorator for op info register.
Function, returns a decorator for op info register.
"""
"""
def
register_decorator
(
func
):
def
register_decorator
(
func
):
validator
.
check_type
(
"op_info"
,
op_info
,
[
str
])
if
isinstance
(
op_info
,
dict
):
op_info_real
=
json
.
dumps
(
op_info
)
else
:
op_info_real
=
op_info
validator
.
check_type
(
"op_info"
,
op_info_real
,
[
str
])
op_lib
=
Oplib
()
op_lib
=
Oplib
()
file_path
=
os
.
path
.
realpath
(
inspect
.
getfile
(
func
))
file_path
=
os
.
path
.
realpath
(
inspect
.
getfile
(
func
))
# keep the path custom ops implementation.
# keep the path custom ops implementation.
imply_path
=
""
if
BUILT_IN_OPS_REGISTER_PATH
in
file_path
else
file_path
imply_path
=
""
if
BUILT_IN_OPS_REGISTER_PATH
in
file_path
else
file_path
if
not
op_lib
.
reg_op
(
op_info
,
imply_path
):
if
not
op_lib
.
reg_op
(
op_info
_real
,
imply_path
):
raise
ValueError
(
'Invalid op info {}:
\n
{}
\n
'
.
format
(
file_path
,
op_info
))
raise
ValueError
(
'Invalid op info {}:
\n
{}
\n
'
.
format
(
file_path
,
op_info
_real
))
def
wrapped_function
(
*
args
,
**
kwargs
):
def
wrapped_function
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
kwargs
)
return
func
(
*
args
,
**
kwargs
)
return
wrapped_function
return
wrapped_function
return
register_decorator
return
register_decorator
class
RegOp
():
"""
Base class for op info register.
Args:
op_name (str): Name of op.
inputs (list): Inputs inoformation of the op.
outputs (list): Outputs information of the op.
attr_ (list): Attribute information of the op.
dtype_format_ (list): Dtype and format information of the op.
"""
def
__init__
(
self
,
op_name
=
""
):
if
not
isinstance
(
op_name
,
str
):
raise
ValueError
(
"op name value must be string"
)
if
not
op_name
.
strip
():
raise
ValueError
(
"op name is empty"
)
self
.
op_name
=
op_name
self
.
inputs
=
[]
self
.
outputs
=
[]
self
.
attr_
=
[]
self
.
dtype_format_
=
[]
def
is_string
(
self
,
value
):
"""
Check if the value is a str type.
Args:
value: Parameter to to check.
Raises:
TypeError: If the type of value is not a str.
"""
if
not
isinstance
(
value
,
str
):
raise
TypeError
(
"%s value must be str"
%
str
(
value
))
def
is_int
(
self
,
value
):
"""
Check if the value is a int.
Args:
value: Parameter to to check.
Raises:
TypeError: If the type of value is not a int.
"""
if
not
isinstance
(
value
,
int
):
raise
TypeError
(
"%s value must be int"
%
str
(
value
))
def
is_bool
(
self
,
value
):
"""
Check if the value is a bool.
Args:
value: Parameter to to check.
Raises:
TypeError: If the type of value is not a bool.
"""
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"%s value must be bool"
%
str
(
value
))
def
dtype_format
(
self
,
*
args
):
"""
Register dtype and format.
Args:
args (tuple): Value of dtype and format.
Raises:
ValueError: If the size of args not equal to input size add output size.
TypeError: If the type of args is not tuple.
"""
if
len
(
self
.
inputs
)
+
len
(
self
.
outputs
)
!=
len
(
args
):
raise
ValueError
(
"input size add output size must be equal to detype format size"
)
dtype_format
=
[]
for
arg
in
args
:
if
not
isinstance
(
arg
,
tuple
)
or
len
(
arg
)
!=
2
:
raise
ValueError
(
"dtype and format value must be tuple of two elements"
)
self
.
is_string
(
arg
[
0
])
self
.
is_string
(
arg
[
1
])
dtype_format
.
append
(
arg
)
self
.
dtype_format_
.
append
(
tuple
(
dtype_format
))
return
self
def
get_op_info
(
self
):
"""
Return all registration information for this instance.
The '_' character ending the key is removed here for compatibility with previous version.
Key will be unified into an underlined form later.
"""
op_info
=
{}
for
key
,
value
in
self
.
__dict__
.
items
():
if
isinstance
(
key
,
str
)
and
key
.
endswith
(
'_'
):
op_info
[
key
.
rstrip
(
'_'
)]
=
value
else
:
op_info
[
key
]
=
value
return
op_info
class
TBERegOp
(
RegOp
):
"""Class for TBE op info register."""
def
__init__
(
self
,
op_name
=
""
):
super
(
TBERegOp
,
self
).
__init__
(
op_name
)
self
.
imply_type
=
"TBE"
self
.
fusion_type_
=
''
self
.
async_flag_
=
False
self
.
binfile_name_
=
''
self
.
compute_cost_
=
10
self
.
kernel_name_
=
''
self
.
partial_flag_
=
False
self
.
reshape_type_
=
''
self
.
dynamic_format_
=
False
self
.
op_pattern_
=
""
def
fusion_type
(
self
,
fusion_type
):
"""
Register fusion type.
Args:
fusion_type (str): Value of fusion type.
"""
self
.
is_string
(
fusion_type
)
self
.
fusion_type_
=
fusion_type
return
self
def
async_flag
(
self
,
async_flag
):
"""
Register async flag.
Args:
async_flag (bool): Value of async flag.
"""
self
.
is_bool
(
async_flag
)
self
.
async_flag_
=
async_flag
return
self
def
binfile_name
(
self
,
binfile_name
):
"""
Register binfile name.
Args:
binfile_name (str): Name of op binfile.
"""
self
.
is_string
(
binfile_name
)
self
.
binfile_name_
=
binfile_name
return
self
def
compute_cost
(
self
,
compute_cost
):
"""
Register compute cost.
Args:
compute_cost (int): Value of compute cost.
"""
self
.
is_int
(
compute_cost
)
self
.
compute_cost_
=
compute_cost
return
self
def
kernel_name
(
self
,
kernel_name
):
"""
Register kernel name.
Args:
kernel_name (str): Name of op kernel.
"""
self
.
is_string
(
kernel_name
)
self
.
kernel_name_
=
kernel_name
return
self
def
partial_flag
(
self
,
partial_flag
):
"""
Register partial flag.
Args:
partial_flag (bool): Value of partial flag.
"""
self
.
is_bool
(
partial_flag
)
self
.
partial_flag_
=
partial_flag
return
self
def
reshape_type
(
self
,
reshape_type
):
"""
Register reshape type.
Args:
reshape_type (str): Value of reshape type.
"""
self
.
is_string
(
reshape_type
)
self
.
reshape_type_
=
reshape_type
return
self
def
dynamic_format
(
self
,
dynamic_format
):
"""
Register dynamic format.
Args:
reshape_type (bool): Value of dynamic format.
"""
self
.
is_bool
(
dynamic_format
)
self
.
dynamic_format_
=
dynamic_format
return
self
def
op_pattern
(
self
,
pattern
=
None
):
"""
Register op pattern information.
Args:
pattern (str): Value of op pattern.
"""
if
pattern
is
not
None
and
self
.
istring
(
pattern
):
self
.
op_pattern_
=
pattern
return
self
def
attr
(
self
,
name
=
None
,
param_type
=
None
,
value_type
=
None
,
value
=
None
,
default_value
=
None
,
**
kwargs
):
"""
Register op attribute information.
Args:
name (str): Name of the attribute. Default: None.
param_type (str): Param type of the attribute. Default: None.
type (str): Type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
default_value (str): Default value of attribute. Default: None.
kwargs (dict): Other information for the attribute.
"""
param_list
=
[
name
,
param_type
,
value_type
,
value
,
default_value
]
attr_dict
=
{}
for
index
,
element
in
enumerate
(
param_list
):
if
element
is
not
None
:
self
.
is_string
(
element
)
if
index
==
0
:
attr_dict
[
"name"
]
=
element
elif
index
==
1
:
attr_dict
[
"param_type"
]
=
element
elif
index
==
2
:
attr_dict
[
"type"
]
=
element
elif
index
==
3
:
attr_dict
[
"value"
]
=
element
elif
index
==
4
:
attr_dict
[
"default_value"
]
=
element
if
kwargs
:
attr_dict
=
dict
(
attr_dict
,
**
kwargs
)
self
.
attr_
.
append
(
attr_dict
)
return
self
def
input
(
self
,
index
=
None
,
name
=
None
,
need_compile
=
None
,
param_type
=
None
,
shape
=
None
,
**
kwargs
):
"""
Register op input information.
Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
need_compile (bool): The input need compile whether or not. Default: None.
param_type (str): Type of the input. Default: None.
shape (str): Shape of the input. Default: None.
kwargs (dict): Other information for the input.
"""
param_list
=
[
index
,
name
,
need_compile
,
param_type
,
shape
]
input_dict
=
{}
for
idx
,
element
in
enumerate
(
param_list
):
if
element
is
not
None
:
if
idx
==
0
:
self
.
is_int
(
element
)
input_dict
[
"index"
]
=
element
elif
idx
==
1
:
self
.
is_string
(
element
)
input_dict
[
"name"
]
=
element
elif
idx
==
2
:
self
.
is_bool
(
element
)
input_dict
[
"need_compile"
]
=
element
elif
idx
==
3
:
self
.
is_string
(
element
)
input_dict
[
"param_type"
]
=
element
elif
idx
==
4
:
self
.
is_string
(
element
)
input_dict
[
"shape"
]
=
element
if
kwargs
:
input_dict
=
dict
(
input_dict
,
**
kwargs
)
self
.
inputs
.
append
(
input_dict
)
return
self
def
output
(
self
,
index
=
None
,
name
=
None
,
need_compile
=
None
,
param_type
=
None
,
shape
=
None
,
**
kwargs
):
"""
Register op output information.
Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
need_compile (bool): The output need compile whether or not. Default: None.
param_type (str): Type of the output. Default: None.
shape (str): Shape of the output. Default: None.
kwargs (dict): Other information for the output.
"""
param_list
=
[
index
,
name
,
need_compile
,
param_type
,
shape
]
output_dict
=
{}
for
idx
,
element
in
enumerate
(
param_list
):
if
element
is
not
None
:
if
idx
==
0
:
self
.
is_int
(
element
)
output_dict
[
"index"
]
=
element
elif
idx
==
1
:
self
.
is_string
(
element
)
output_dict
[
"name"
]
=
element
elif
idx
==
2
:
self
.
is_bool
(
element
)
output_dict
[
"need_compile"
]
=
element
elif
idx
==
3
:
self
.
is_string
(
element
)
output_dict
[
"param_type"
]
=
element
elif
idx
==
4
:
self
.
is_string
(
element
)
output_dict
[
"shape"
]
=
element
if
kwargs
:
output_dict
=
dict
(
output_dict
,
**
kwargs
)
self
.
outputs
.
append
(
output_dict
)
return
self
class
DataType
():
"""
Various combinations of dtype and formatself.
The current list below maybe not completed. If necessary, please add it.
"""
BOOL_None
=
(
"bool"
,
""
)
BOOL_Default
=
(
"bool"
,
"DefaultFormat"
)
BOOL_5HD
=
(
"bool"
,
"NC1HWC0"
)
BOOL_NCHW
=
(
"bool"
,
"NCHW"
)
BOOL_NHWC
=
(
"bool"
,
"NHWC"
)
BOOL_HWCN
=
(
"bool"
,
"HWCN"
)
I8_None
=
(
"int8"
,
""
)
I8_Default
=
(
"int8"
,
"DefaultFormat"
)
I8_5HD
=
(
"int8"
,
"NC1HWC0"
)
I8_FracZ
=
(
"int8"
,
"Fracz"
)
I8_FracNZ
=
(
"int8"
,
"FRACTAL_NZ"
)
I8_NCHW
=
(
"int8"
,
"NCHW"
)
I8_NHWC
=
(
"int8"
,
"NHWC"
)
I8_HWCN
=
(
"int8"
,
"HWCN"
)
U8_None
=
(
"uint8"
,
""
)
U8_Default
=
(
"uint8"
,
"DefaultFormat"
)
U8_5HD
=
(
"uint8"
,
"NC1HWC0"
)
U8_FracZ
=
(
"uint8"
,
"Fracz"
)
U8_FracNZ
=
(
"uint8"
,
"FRACTAL_NZ"
)
U8_NCHW
=
(
"uint8"
,
"NCHW"
)
U8_NHWC
=
(
"uint8"
,
"NHWC"
)
U8_HWCN
=
(
"uint8"
,
"HWCN"
)
I16_None
=
(
"int16"
,
""
)
I16_Default
=
(
"int16"
,
"DefaultFormat"
)
I16_5HD
=
(
"int16"
,
"NC1HWC0"
)
I16_FracZ
=
(
"int16"
,
"Fracz"
)
I16_FracNZ
=
(
"int16"
,
"FRACTAL_NZ"
)
I16_NCHW
=
(
"int16"
,
"NCHW"
)
I16_NHWC
=
(
"int16"
,
"NHWC"
)
I16_HWCN
=
(
"int16"
,
"HWCN"
)
U16_None
=
(
"uint16"
,
""
)
U16_Default
=
(
"uint16"
,
"DefaultFormat"
)
U16_5HD
=
(
"uint16"
,
"NC1HWC0"
)
U16_FracZ
=
(
"uint16"
,
"Fracz"
)
U16_FracNZ
=
(
"uint16"
,
"FRACTAL_NZ"
)
U16_NCHW
=
(
"uint16"
,
"NCHW"
)
U16_NHWC
=
(
"uint16"
,
"NHWC"
)
U16_HWCN
=
(
"uint16"
,
"HWCN"
)
I32_None
=
(
"int32"
,
""
)
I32_Default
=
(
"int32"
,
"DefaultFormat"
)
I32_5HD
=
(
"int32"
,
"NC1HWC0"
)
I32_FracZ
=
(
"int32"
,
"Fracz"
)
I32_FracNZ
=
(
"int32"
,
"FRACTAL_NZ"
)
I32_NCHW
=
(
"int32"
,
"NCHW"
)
I32_NHWC
=
(
"int32"
,
"NHWC"
)
I32_HWCN
=
(
"int32"
,
"HWCN"
)
U32_None
=
(
"uint32"
,
""
)
U32_Default
=
(
"uint32"
,
"DefaultFormat"
)
U32_5HD
=
(
"uint32"
,
"NC1HWC0"
)
U32_FracZ
=
(
"uint32"
,
"Fracz"
)
U32_FracNZ
=
(
"uint32"
,
"FRACTAL_NZ"
)
U32_NCHW
=
(
"uint32"
,
"NCHW"
)
U32_NHWC
=
(
"uint32"
,
"NHWC"
)
U32_HWCN
=
(
"uint32"
,
"HWCN"
)
I64_None
=
(
"int64"
,
""
)
I64_Default
=
(
"int64"
,
"DefaultFormat"
)
I64_5HD
=
(
"int64"
,
"NC1HWC0"
)
I64_FracZ
=
(
"int64"
,
"Fracz"
)
I64_FracNZ
=
(
"int64"
,
"FRACTAL_NZ"
)
I64_NCHW
=
(
"int64"
,
"NCHW"
)
I64_NHWC
=
(
"int64"
,
"NHWC"
)
I64_HWCN
=
(
"int64"
,
"HWCN"
)
U64_None
=
(
"uint64"
,
""
)
U64_Default
=
(
"uint64"
,
"DefaultFormat"
)
U64_5HD
=
(
"uint64"
,
"NC1HWC0"
)
U64_FracZ
=
(
"uint64"
,
"Fracz"
)
U64_FracNZ
=
(
"uint64"
,
"FRACTAL_NZ"
)
U64_NCHW
=
(
"uint64"
,
"NCHW"
)
U64_NHWC
=
(
"uint64"
,
"NHWC"
)
U64_HWCN
=
(
"uint64"
,
"HWCN"
)
F16_None
=
(
"float16"
,
""
)
F16_Default
=
(
"float16"
,
"DefaultFormat"
)
F16_5HD
=
(
"float16"
,
"NC1HWC0"
)
F16_FracZ
=
(
"float16"
,
"Fracz"
)
F16_FracNZ
=
(
"float16"
,
"FRACTAL_NZ"
)
F16_C1HWNCoC0
=
(
"float16"
,
"C1HWNCoC0"
)
F16_NCHW
=
(
"float16"
,
"NCHW"
)
F16_NHWC
=
(
"float16"
,
"NHWC"
)
F16_HWCN
=
(
"float16"
,
"HWCN"
)
F32_None
=
(
"float32"
,
""
)
F32_Default
=
(
"float32"
,
"DefaultFormat"
)
F32_5HD
=
(
"float32"
,
"NC1HWC0"
)
F32_FracZ
=
(
"float32"
,
"Fracz"
)
F32_FracNZ
=
(
"float32"
,
"FRACTAL_NZ"
)
F32_C1HWNCoC0
=
(
"float32"
,
"C1HWNCoC0"
)
F32_NCHW
=
(
"float32"
,
"NCHW"
)
F32_NHWC
=
(
"float32"
,
"NHWC"
)
F32_HWCN
=
(
"float32"
,
"HWCN"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录