Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
caab25e0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
caab25e0
编写于
6月 12, 2020
作者:
J
jjfeing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
tbe select broadcast reduce dynamic
上级
553432c9
变更
54
展开全部
隐藏空白更改
内联
并排
Showing
54 changed files
with
1519 additions
and
832 deletions
+1519
-832
mindspore/ccsrc/kernel/kernel_query.cc
mindspore/ccsrc/kernel/kernel_query.cc
+1
-3
mindspore/ccsrc/kernel/oplib/opinfo.h
mindspore/ccsrc/kernel/oplib/opinfo.h
+2
-0
mindspore/ccsrc/kernel/oplib/oplib.cc
mindspore/ccsrc/kernel/oplib/oplib.cc
+12
-7
mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
+2
-2
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
+94
-55
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
+8
-1
mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
+1
-1
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+0
-664
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h
+9
-11
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
...el/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
+319
-0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
...nel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
+57
-0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
...ernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
+180
-0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
...kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
+52
-0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
...e/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
+633
-0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h
...re/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h
+77
-0
mindspore/ccsrc/parallel/ops_info/ops_utils.h
mindspore/ccsrc/parallel/ops_info/ops_utils.h
+7
-0
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
+0
-1
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
...end/format_type/convert_unsupported_transnode_to_aicpu.cc
+1
-1
mindspore/ops/_op_impl/tbe/abs.py
mindspore/ops/_op_impl/tbe/abs.py
+3
-6
mindspore/ops/_op_impl/tbe/abs_grad.py
mindspore/ops/_op_impl/tbe/abs_grad.py
+0
-1
mindspore/ops/_op_impl/tbe/add.py
mindspore/ops/_op_impl/tbe/add.py
+1
-0
mindspore/ops/_op_impl/tbe/add_n.py
mindspore/ops/_op_impl/tbe/add_n.py
+4
-11
mindspore/ops/_op_impl/tbe/batch_matmul.py
mindspore/ops/_op_impl/tbe/batch_matmul.py
+1
-0
mindspore/ops/_op_impl/tbe/bias_add.py
mindspore/ops/_op_impl/tbe/bias_add.py
+1
-0
mindspore/ops/_op_impl/tbe/bn_training_reduce.py
mindspore/ops/_op_impl/tbe/bn_training_reduce.py
+1
-0
mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
+1
-0
mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
+1
-0
mindspore/ops/_op_impl/tbe/bn_training_update_v2.py
mindspore/ops/_op_impl/tbe/bn_training_update_v2.py
+1
-0
mindspore/ops/_op_impl/tbe/cast.py
mindspore/ops/_op_impl/tbe/cast.py
+21
-26
mindspore/ops/_op_impl/tbe/concat.py
mindspore/ops/_op_impl/tbe/concat.py
+1
-0
mindspore/ops/_op_impl/tbe/conv2d.py
mindspore/ops/_op_impl/tbe/conv2d.py
+2
-2
mindspore/ops/_op_impl/tbe/dropout_do_mask.py
mindspore/ops/_op_impl/tbe/dropout_do_mask.py
+1
-0
mindspore/ops/_op_impl/tbe/elu.py
mindspore/ops/_op_impl/tbe/elu.py
+0
-2
mindspore/ops/_op_impl/tbe/erf.py
mindspore/ops/_op_impl/tbe/erf.py
+0
-2
mindspore/ops/_op_impl/tbe/erfc.py
mindspore/ops/_op_impl/tbe/erfc.py
+0
-2
mindspore/ops/_op_impl/tbe/expm1.py
mindspore/ops/_op_impl/tbe/expm1.py
+0
-2
mindspore/ops/_op_impl/tbe/fused_mul_add.py
mindspore/ops/_op_impl/tbe/fused_mul_add.py
+1
-0
mindspore/ops/_op_impl/tbe/layer_norm.py
mindspore/ops/_op_impl/tbe/layer_norm.py
+1
-0
mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
+1
-0
mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
+1
-0
mindspore/ops/_op_impl/tbe/mul.py
mindspore/ops/_op_impl/tbe/mul.py
+2
-15
mindspore/ops/_op_impl/tbe/real_div.py
mindspore/ops/_op_impl/tbe/real_div.py
+3
-4
mindspore/ops/_op_impl/tbe/reciprocal.py
mindspore/ops/_op_impl/tbe/reciprocal.py
+1
-0
mindspore/ops/_op_impl/tbe/reduce_mean.py
mindspore/ops/_op_impl/tbe/reduce_mean.py
+5
-5
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
+1
-1
mindspore/ops/_op_impl/tbe/select.py
mindspore/ops/_op_impl/tbe/select.py
+1
-0
mindspore/ops/_op_impl/tbe/sign.py
mindspore/ops/_op_impl/tbe/sign.py
+0
-3
mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
+1
-0
mindspore/ops/_op_impl/tbe/softplus.py
mindspore/ops/_op_impl/tbe/softplus.py
+0
-2
mindspore/ops/_op_impl/tbe/softplus_grad.py
mindspore/ops/_op_impl/tbe/softplus_grad.py
+0
-2
mindspore/ops/_op_impl/tbe/split_d.py
mindspore/ops/_op_impl/tbe/split_d.py
+1
-0
mindspore/ops/_op_impl/tbe/tensor_add.py
mindspore/ops/_op_impl/tbe/tensor_add.py
+1
-0
mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
+1
-0
mindspore/ops/op_info_register.py
mindspore/ops/op_info_register.py
+4
-0
未找到文件。
mindspore/ccsrc/kernel/kernel_query.cc
浏览文件 @
caab25e0
...
...
@@ -20,7 +20,7 @@
#include "kernel/aicpu/aicpu_kernel_metadata.h"
#include "kernel/rts/rt_kernel_info.h"
#include "kernel/hccl/hccl_kernel_metadata.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/tbe/tbe_kernel_select
/tbe_kernel_select
.h"
#include "session/anf_runtime_algorithm.h"
namespace
mindspore
{
...
...
@@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL
(
kernel_node
);
MS_EXCEPTION_IF_NULL
(
kernel_info_list
);
TbeMetadataInfo
(
kernel_node
,
kernel_info_list
);
FilterInvalidKernelInfo
(
kernel_node
,
kernel_info_list
);
if
(
kernel_info_list
->
empty
())
{
AicpuMetadataInfo
(
kernel_node
,
kernel_info_list
);
if
(
!
kernel_info_list
->
empty
())
{
...
...
@@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
auto
cnode
=
kernel_node
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
cnode
);
TbeMetadataInfo
(
cnode
,
&
kernel_info_list
);
FilterInvalidKernelInfo
(
cnode
,
&
kernel_info_list
);
return
std
::
any_of
(
kernel_info_list
.
begin
(),
kernel_info_list
.
end
(),
[
&
select_kernel_build_info
](
const
kernel
::
KernelBuildInfoPtr
item
)
{
MS_EXCEPTION_IF_NULL
(
item
);
...
...
mindspore/ccsrc/kernel/oplib/opinfo.h
浏览文件 @
caab25e0
...
...
@@ -126,6 +126,8 @@ class OpInfo {
bool
is_ref
()
const
{
return
!
ref_infos_
.
empty
();
}
bool
has_ref_index
(
size_t
out_index
)
const
{
return
ref_infos_
.
find
(
out_index
)
!=
ref_infos_
.
end
();
}
void
add_ref_pair
(
size_t
out_index
,
size_t
in_index
)
{
(
void
)
ref_infos_
.
emplace
(
out_index
,
in_index
);
}
void
ClearInputs
()
{
(
void
)
inputs_ptr_
.
clear
();
}
void
ClearOutputs
()
{
(
void
)
outputs_ptr_
.
clear
();
}
private:
std
::
string
op_name_
;
...
...
mindspore/ccsrc/kernel/oplib/oplib.cc
浏览文件 @
caab25e0
...
...
@@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
constexpr
auto
kPartialFlag
=
"partial_flag"
;
constexpr
auto
kReshapeType
=
"reshape_type"
;
constexpr
auto
kOpPattern
=
"op_pattern"
;
constexpr
auto
kDynamicFormat
=
"dynamic
_f
ormat"
;
constexpr
auto
kDynamicFormat
=
"dynamic
F
ormat"
;
constexpr
auto
kFormatAgnostic
=
"formatAgnostic"
;
constexpr
auto
kBroadcast
=
"broadcast"
;
constexpr
auto
kReduce
=
"reduce"
;
...
...
@@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
void
OpLib
::
DecodeTBESpecificInfo
(
const
nlohmann
::
json
&
obj
,
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
)
{
const
std
::
map
<
std
::
string
,
kernel
::
OpPattern
>
kOpPatternMap
=
{{
kFormatAgnostic
,
kFormatAgnosticPattern
},
{
k
FormatAgnostic
,
kBroadcastPattern
},
{
k
Broadcast
,
kBroadcastPattern
},
{
kReduce
,
kReducePattern
},
{
kDynamicFormat
,
kDynamicFormatPattern
}};
op_info
->
set_async_flag
(
obj
.
at
(
kAsyncFlag
));
...
...
@@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info
->
set_compute_cost
(
obj
.
at
(
kComputeCost
));
op_info
->
set_kernel_name
(
obj
.
at
(
kKernelName
));
op_info
->
set_partial_flag
(
obj
.
at
(
kPartialFlag
));
if
(
obj
.
find
(
kOpPattern
)
!=
obj
.
end
())
{
if
(
kOpPatternMap
.
find
(
obj
.
at
(
kOpPattern
))
!=
kOpPatternMap
.
end
())
{
op_info
->
set_op_pattern
(
obj
.
at
(
kOpPattern
));
std
::
string
op_pattern
=
obj
.
at
(
kOpPattern
);
auto
find_iter
=
kOpPatternMap
.
find
(
op_pattern
);
if
(
find_iter
==
kOpPatternMap
.
end
())
{
if
(
!
op_pattern
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"Op pattern set value error: "
<<
op_pattern
;
}
op_info
->
set_op_pattern
(
kCommonPattern
);
}
else
{
op_info
->
set_op_pattern
(
find_iter
->
second
);
}
}
if
(
obj
.
find
(
kDynamicFormat
)
!=
obj
.
end
())
{
op_info
->
set_dynamic_format
(
obj
.
at
(
kDynamicFormat
));
}
}
bool
OpLib
::
DecodeOpInfo
(
const
nlohmann
::
json
&
obj
,
const
mindspore
::
kernel
::
OpImplyType
imply_type
,
...
...
mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
浏览文件 @
caab25e0
...
...
@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
{
TypeId
::
kNumberTypeInt64
,
"int64"
},
{
TypeId
::
kNumberTypeUInt
,
"uint"
},
{
TypeId
::
kNumberTypeUInt8
,
"uint8"
},
{
TypeId
::
kNumberTypeUInt16
,
"uint16"
},
{
TypeId
::
kNumberTypeUInt32
,
"uint32"
},
{
TypeId
::
kNumberTypeUInt64
,
"uint64"
},
{
TypeId
::
kNumberTypeBool
,
"
bool
"
},
{
TypeId
::
kNumberTypeBool
,
"
int8
"
},
};
const
std
::
map
<
std
::
string
,
std
::
string
>
type_str_maps
=
{
...
...
@@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
std
::
string
TypeIdToString
(
TypeId
type_id
)
{
auto
iter
=
type_id_str_maps
.
find
(
type_id
);
if
(
iter
==
type_id_str_maps
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal input dtype
.
"
<<
TypeIdLabel
(
type_id
);
MS_LOG
(
EXCEPTION
)
<<
"Illegal input dtype
:
"
<<
TypeIdLabel
(
type_id
);
}
return
iter
->
second
;
}
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
浏览文件 @
caab25e0
...
...
@@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
if
(
input_ptr
->
name
()
==
"input_indices"
&&
op_name
==
kTopKOpName
)
{
TbeAdapter
::
GenTopKV2IndicesTensorInfo
(
anf_node
,
real_input_index
,
input_list
,
creater_type_
);
}
else
{
// dtype : float16
auto
tensor_dtype
=
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
AnfAlgo
::
GetInputDeviceDataType
(
anf_node
,
real_input_index
)));
MS_EXCEPTION_IF_NULL
(
tensor_dtype
);
std
::
string
dtype
=
tensor_dtype
->
element
()
->
ToString
();
dtype
=
tbe
::
DtypeToString
(
dtype
);
// format
std
::
string
format
=
AnfAlgo
::
GetInputFormat
(
anf_node
,
real_input_index
);
if
(
format
==
kOpFormat_DEFAULT
)
{
format
=
kOpFormat_NCHW
;
}
else
if
(
format
==
kOpFormat_FRAC_Z
)
{
format
=
kOpFormat_FRACTAL_Z
;
}
nlohmann
::
json
input_desc_json
;
input_desc_json
[
"dtype"
]
=
dtype
;
input_desc_json
[
"name"
]
=
op_input_name
+
std
::
to_string
(
input_i
);
auto
dtype
=
GetDeviceInputType
(
anf_node
,
real_input_index
);
auto
format
=
GetDeviceInputFormat
(
anf_node
,
real_input_index
);
auto
shape
=
GetDeviceInputShape
(
anf_node
,
real_input_index
);
auto
ori_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
anf_node
,
real_input_index
);
if
(
ori_shape
.
empty
())
{
ori_shape
.
emplace_back
(
1
);
}
nlohmann
::
json
input_desc_json
;
input_desc_json
[
"dtype"
]
=
dtype
;
input_desc_json
[
"name"
]
=
op_input_name
+
std
::
to_string
(
input_i
);
input_desc_json
[
"ori_shape"
]
=
ori_shape
;
input_desc_json
[
"ori_format"
]
=
kOpFormat_NCHW
;
auto
shape
=
AnfAlgo
::
GetInputDeviceShape
(
anf_node
,
real_input_index
);
if
(
shape
.
empty
())
{
shape
.
emplace_back
(
1
);
}
if
(
creater_type_
==
OP_SELECT_FORMAT
||
creater_type_
==
CHECK_SUPPORTED
)
{
input_desc_json
[
"shape"
]
=
ori_shape
;
input_desc_json
[
"format"
]
=
kOpFormat_NCHW
;
}
else
{
input_desc_json
[
"shape"
]
=
shape
;
input_desc_json
[
"format"
]
=
format
;
}
input_desc_json
[
"shape"
]
=
shape
;
input_desc_json
[
"format"
]
=
format
;
input_desc_json
[
"valid"
]
=
value
;
input_desc_json
[
"param_type"
]
=
input_ptr
->
param_type
();
input_list
->
emplace_back
(
input_desc_json
);
...
...
@@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
MS_EXCEPTION_IF_NULL
(
output_idx
);
MS_EXCEPTION_IF_NULL
(
output_list
);
for
(
size_t
i
=
0
;
i
<
output_obj_num
;
i
++
)
{
nlohmann
::
json
output_obj
;
auto
type_ptr
=
std
::
make_shared
<
TensorType
>
(
TypeIdToType
(
AnfAlgo
::
GetOutputDeviceDataType
(
anf_node
,
*
output_idx
)));
std
::
string
dtype
=
type_ptr
->
element
()
->
ToString
();
dtype
=
tbe
::
DtypeToString
(
dtype
);
std
::
string
format
=
AnfAlgo
::
GetOutputFormat
(
anf_node
,
*
output_idx
);
if
(
format
==
kOpFormat_DEFAULT
)
{
format
=
kOpFormat_NCHW
;
}
else
if
(
format
==
kOpFormat_FRAC_Z
)
{
format
=
kOpFormat_FRACTAL_Z
;
}
std
::
vector
<
size_t
>
ori_shape
;
if
(
AnfAlgo
::
GetOutputInferShape
(
anf_node
,
*
output_idx
).
empty
())
{
auto
dtype
=
GetDeviceOutputType
(
anf_node
,
*
output_idx
);
auto
format
=
GetDeviceOutputFormat
(
anf_node
,
*
output_idx
);
auto
shape
=
GetDeviceOutputShape
(
anf_node
,
*
output_idx
);
std
::
vector
<
size_t
>
ori_shape
=
AnfAlgo
::
GetOutputInferShape
(
anf_node
,
*
output_idx
);
if
(
ori_shape
.
empty
())
{
ori_shape
.
emplace_back
(
1
);
}
else
{
ori_shape
=
AnfAlgo
::
GetOutputInferShape
(
anf_node
,
*
output_idx
);
}
nlohmann
::
json
output_obj
;
output_obj
[
"dtype"
]
=
dtype
;
auto
shape
=
AnfAlgo
::
GetOutputDeviceShape
(
anf_node
,
*
output_idx
);
if
(
shape
.
empty
())
{
shape
.
emplace_back
(
1
);
}
if
(
creater_type_
==
OP_SELECT_FORMAT
||
creater_type_
==
CHECK_SUPPORTED
)
{
output_obj
[
"shape"
]
=
ori_shape
;
output_obj
[
"format"
]
=
kOpFormat_NCHW
;
}
else
{
output_obj
[
"shape"
]
=
shape
;
output_obj
[
"format"
]
=
format
;
}
output_obj
[
"shape"
]
=
shape
;
output_obj
[
"format"
]
=
format
;
output_obj
[
"ori_shape"
]
=
ori_shape
;
output_obj
[
"ori_format"
]
=
kOpFormat_NCHW
;
output_obj
[
"name"
]
=
output_ptr
->
name
();
output_obj
[
"valid"
]
=
true
;
output_obj
[
"param_type"
]
=
output_ptr
->
param_type
();
output_list
->
emplace_back
(
output_obj
);
(
*
output_idx
)
++
;
}
...
...
@@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
}
}
std
::
vector
<
size_t
>
TbeKernelJsonCreator
::
GetDeviceInputShape
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
std
::
vector
<
size_t
>
shape
;
if
(
creater_type_
==
OP_SELECT_FORMAT
||
creater_type_
==
CHECK_SUPPORTED
)
{
shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
anf_node
,
real_index
);
}
else
{
shape
=
AnfAlgo
::
GetInputDeviceShape
(
anf_node
,
real_index
);
}
if
(
shape
.
empty
())
{
shape
.
emplace_back
(
1
);
}
return
shape
;
}
std
::
string
TbeKernelJsonCreator
::
GetDeviceInputType
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
TypeId
type_id
;
if
(
creater_type_
==
OP_SELECT_FORMAT
)
{
type_id
=
AnfAlgo
::
GetPrevNodeOutputInferDataType
(
anf_node
,
real_index
);
}
else
{
type_id
=
AnfAlgo
::
GetInputDeviceDataType
(
anf_node
,
real_index
);
}
return
tbe
::
TypeIdToString
(
type_id
);
}
std
::
string
TbeKernelJsonCreator
::
GetDeviceInputFormat
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
std
::
string
format
=
kOpFormat_NCHW
;
if
(
creater_type_
!=
OP_SELECT_FORMAT
&&
creater_type_
!=
CHECK_SUPPORTED
)
{
format
=
AnfAlgo
::
GetInputFormat
(
anf_node
,
real_index
);
if
(
format
==
kOpFormat_FRAC_Z
)
{
format
=
kOpFormat_FRACTAL_Z
;
}
else
if
(
format
==
kOpFormat_DEFAULT
)
{
format
=
kOpFormat_NCHW
;
}
}
return
format
;
}
std
::
vector
<
size_t
>
TbeKernelJsonCreator
::
GetDeviceOutputShape
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
std
::
vector
<
size_t
>
shape
;
if
(
creater_type_
==
OP_SELECT_FORMAT
||
creater_type_
==
CHECK_SUPPORTED
)
{
shape
=
AnfAlgo
::
GetOutputInferShape
(
anf_node
,
real_index
);
}
else
{
shape
=
AnfAlgo
::
GetOutputDeviceShape
(
anf_node
,
real_index
);
}
if
(
shape
.
empty
())
{
shape
.
emplace_back
(
1
);
}
return
shape
;
}
std
::
string
TbeKernelJsonCreator
::
GetDeviceOutputType
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
TypeId
type_id
;
if
(
creater_type_
==
OP_SELECT_FORMAT
)
{
type_id
=
AnfAlgo
::
GetOutputInferDataType
(
anf_node
,
real_index
);
}
else
{
type_id
=
AnfAlgo
::
GetOutputDeviceDataType
(
anf_node
,
real_index
);
}
return
tbe
::
TypeIdToString
(
type_id
);
}
std
::
string
TbeKernelJsonCreator
::
GetDeviceOutputFormat
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
{
MS_EXCEPTION_IF_NULL
(
anf_node
);
std
::
string
format
=
kOpFormat_NCHW
;
if
(
creater_type_
!=
OP_SELECT_FORMAT
&&
creater_type_
!=
CHECK_SUPPORTED
)
{
format
=
AnfAlgo
::
GetOutputFormat
(
anf_node
,
real_index
);
if
(
format
==
kOpFormat_FRAC_Z
)
{
format
=
kOpFormat_FRACTAL_Z
;
}
else
if
(
format
==
kOpFormat_DEFAULT
)
{
format
=
kOpFormat_NCHW
;
}
}
return
format
;
}
bool
TbeKernelBuild
::
GetIOSize
(
const
nlohmann
::
json
&
kernel_json
,
std
::
vector
<
size_t
>
*
input_size_list
,
std
::
vector
<
size_t
>
*
output_size_list
)
{
if
(
input_size_list
==
nullptr
||
output_size_list
==
nullptr
)
{
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
浏览文件 @
caab25e0
...
...
@@ -93,7 +93,7 @@ class TbeKernelJsonCreator {
nlohmann
::
json
*
outputs_json
);
bool
GenTbeAttrJson
(
const
std
::
shared_ptr
<
AnfNode
>
&
anf_node
,
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
,
nlohmann
::
json
*
attrs_json
);
void
ParseAttrValue
(
const
std
::
string
&
type
,
const
ValuePtr
&
value
,
nlohmann
::
json
*
attr_obj
);
static
void
ParseAttrValue
(
const
std
::
string
&
type
,
const
ValuePtr
&
value
,
nlohmann
::
json
*
attr_obj
);
bool
GenInputDescJson
(
const
std
::
shared_ptr
<
AnfNode
>
&
anf_node
,
size_t
real_input_index
,
bool
value
,
const
std
::
shared_ptr
<
OpIOInfo
>
&
input_ptr
,
const
string
&
op_input_name
,
size_t
input_i
,
std
::
vector
<
nlohmann
::
json
>
*
input_list
);
...
...
@@ -105,6 +105,13 @@ class TbeKernelJsonCreator {
void
GenOutputList
(
const
std
::
shared_ptr
<
AnfNode
>
&
anf_node
,
const
size_t
&
output_obj_num
,
const
std
::
shared_ptr
<
OpIOInfo
>
&
output_ptr
,
size_t
*
output_idx
,
std
::
vector
<
nlohmann
::
json
>
*
output_list
);
std
::
vector
<
size_t
>
GetDeviceInputShape
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
std
::
string
GetDeviceInputType
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
std
::
string
GetDeviceInputFormat
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
std
::
vector
<
size_t
>
GetDeviceOutputShape
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
std
::
string
GetDeviceOutputType
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
std
::
string
GetDeviceOutputFormat
(
const
AnfNodePtr
&
anf_node
,
size_t
real_index
)
const
;
kCreaterType
creater_type_
;
std
::
string
json_name_
;
std
::
string
json_info_
;
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
浏览文件 @
caab25e0
...
...
@@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
task_iter
->
second
.
output_size_list
,
kernel_pack
);
MS_EXCEPTION_IF_NULL
(
kernel_mod
);
if
(
set_kernel_mod
)
{
AnfAlgo
::
SetKernelMod
(
kernel_mod
,
task_iter
->
second
.
node
);
AnfAlgo
::
SetKernelMod
(
kernel_mod
,
task_iter
->
second
.
node
);
}
auto
ret
=
std
::
make_pair
(
task_iter
->
second
.
scope_id
,
kernel_mod
);
(
void
)
task_map_
.
erase
(
task_iter
);
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
已删除
100644 → 0
浏览文件 @
553432c9
此差异已折叠。
点击以展开。
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
→
mindspore/ccsrc/kernel/tbe/tbe_kernel_select
/common_utils
.h
浏览文件 @
caab25e0
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -13,20 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
namespace
mindspore
{
namespace
kernel
{
void
TbeMetadataInfo
(
const
CNodePtr
&
kernel_node
,
std
::
vector
<
std
::
shared_ptr
<
KernelBuildInfo
>>
*
kernel_info_list
);
struct
SupportFormat
{
std
::
vector
<
std
::
vector
<
std
::
string
>>
input_format
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
output_format
;
};
using
SupportFormatItem
=
std
::
vector
<
std
::
string
>
;
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_
TBE_KERNEL_SELECT_H
#endif // MINDSPORE_
CCSRC_KERNEL_TBE_COMMON_UTILS_H_
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
0 → 100644
浏览文件 @
caab25e0
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace
mindspore
{
namespace
kernel
{
constexpr
char
kDynInputKey
[]
=
"dyn_input_sizes"
;
constexpr
size_t
kInputIndex_0
=
0
;
constexpr
size_t
kChannelN
=
0
;
constexpr
size_t
kChannelC
=
1
;
constexpr
size_t
kAlignmented16
=
16
;
// 1. all shape no scalar and same
// 2. part scalar : no_scalar (shape size > xxx && alig xxx)
// 3. all no_scalar and not same (broad cast xxx dim)
bool
TbeKernelBroadCastSelecter
::
GetShapeInfo
(
SupportFormat
*
support_format
)
{
MS_EXCEPTION_IF_NULL
(
support_format
);
input_num_
=
0
;
output_num_
=
0
;
input_shapes_
.
clear
();
output_shapes_
.
clear
();
if
(
AnfAlgo
::
HasNodeAttr
(
kDynInputKey
,
cnode_ptr_
))
{
MS_LOG
(
INFO
)
<<
"This broadcast node has dynamic input."
;
auto
dynamic_size_vec
=
AnfAlgo
::
GetNodeAttr
<
std
::
vector
<
int
>>
(
cnode_ptr_
,
kDynInputKey
);
if
(
dynamic_size_vec
.
empty
()
||
dynamic_size_vec
[
0
]
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"dynamic attr set error, please check."
;
}
auto
dynamic_input_shape0_
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
cnode_ptr_
,
kInputIndex_0
);
PadScalarShape
(
&
dynamic_input_shape0_
);
input_shapes_
.
emplace_back
(
dynamic_input_shape0_
);
input_num_
=
1
;
}
else
{
input_num_
=
AnfAlgo
::
GetInputTensorNum
(
cnode_ptr_
);
for
(
size_t
i
=
0
;
i
<
input_num_
;
++
i
)
{
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
cnode_ptr_
,
i
);
PadScalarShape
(
&
input_shape
);
input_shapes_
.
emplace_back
(
input_shape
);
}
}
output_num_
=
AnfAlgo
::
GetOutputTensorNum
(
cnode_ptr_
);
for
(
size_t
i
=
0
;
i
<
output_num_
;
++
i
)
{
auto
output
=
AnfAlgo
::
GetOutputInferShape
(
cnode_ptr_
,
i
);
PadScalarShape
(
&
output
);
output_shapes_
.
emplace_back
(
output
);
}
AssignSupportFormat
(
kOpFormat_DEFAULT
,
support_format
);
return
true
;
}
bool
TbeKernelBroadCastSelecter
::
IsBroadCastSupport5HD
(
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
if
(
IsSameShape
())
{
if
(
!
HasScalarInput
())
{
AssignSupportFormat
(
kOpFormat_NC1HWC0
,
support_format
);
return
true
;
}
else
{
return
false
;
}
}
SupportFormatItem
input_support_format
;
SupportFormatItem
output_support_format
;
if
(
HasScalarInput
())
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
input_support_format
.
emplace_back
(
kOpFormat_DEFAULT
);
}
else
{
if
(
!
Is4DShape
(
shape
))
{
return
false
;
}
if
(
shape
[
kChannelC
]
%
kAlignmented16
!=
0
)
{
return
false
;
}
input_support_format
.
emplace_back
(
kOpFormat_NC1HWC0
);
}
}
}
else
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
!
Is4DShape
(
shape
))
{
return
false
;
}
}
auto
shape_tmp
=
input_shapes_
[
0
];
auto
broadcast_c_axis
=
std
::
any_of
(
input_shapes_
.
begin
(),
input_shapes_
.
end
(),
[
&
shape_tmp
](
const
std
::
vector
<
size_t
>
&
elem
)
{
return
shape_tmp
.
at
(
kChannelC
)
!=
elem
.
at
(
kChannelC
);
});
if
(
broadcast_c_axis
)
{
MS_LOG
(
INFO
)
<<
"This node broadcast c channel."
;
return
false
;
}
input_support_format
.
assign
(
input_num_
,
kOpFormat_NC1HWC0
);
}
GenOutputSupportFormat
(
kOpFormat_NC1HWC0
,
&
output_support_format
);
support_format
->
input_format
.
emplace_back
(
input_support_format
);
support_format
->
output_format
.
emplace_back
(
output_support_format
);
return
true
;
}
bool
TbeKernelBroadCastSelecter
::
IsBroadCastSupportFracZ
(
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
if
(
IsSameShape
())
{
if
(
!
HasScalarInput
())
{
AssignSupportFormat
(
kOpFormat_FRAC_Z
,
support_format
);
return
true
;
}
else
{
return
false
;
}
}
SupportFormatItem
input_support_format
;
SupportFormatItem
output_support_format
;
if
(
HasScalarInput
())
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
input_support_format
.
emplace_back
(
kOpFormat_DEFAULT
);
}
else
{
if
(
!
Is4DShape
(
shape
))
{
return
false
;
}
if
(
shape
[
kChannelN
]
%
kAlignmented16
!=
0
||
shape
[
kChannelC
]
%
kAlignmented16
!=
0
)
{
return
false
;
}
input_support_format
.
emplace_back
(
kOpFormat_FRAC_Z
);
}
}
}
else
{
return
false
;
}
GenOutputSupportFormat
(
kOpFormat_FRAC_Z
,
&
output_support_format
);
support_format
->
input_format
.
emplace_back
(
input_support_format
);
support_format
->
output_format
.
emplace_back
(
output_support_format
);
return
true
;
}
bool
TbeKernelBroadCastSelecter
::
IsBroadCastSupportC1HWNCoC0
(
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
if
(
IsSameShape
())
{
if
(
!
HasScalarInput
())
{
AssignSupportFormat
(
kOpFormat_C1HWNCoC0
,
support_format
);
return
true
;
}
else
{
return
false
;
}
}
SupportFormatItem
input_support_format
;
SupportFormatItem
output_support_format
;
if
(
HasScalarInput
())
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
input_support_format
.
emplace_back
(
kOpFormat_DEFAULT
);
}
else
{
if
(
!
Is4DShape
(
shape
))
{
return
false
;
}
if
(
shape
[
kChannelN
]
%
kAlignmented16
!=
0
)
{
return
false
;
}
input_support_format
.
emplace_back
(
kOpFormat_C1HWNCoC0
);
}
}
}
else
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
!
Is4DShape
(
shape
))
{
return
false
;
}
}
auto
shape_tmp
=
input_shapes_
[
0
];
auto
broadcast_nc_axis
=
std
::
any_of
(
input_shapes_
.
begin
(),
input_shapes_
.
end
(),
[
&
shape_tmp
](
const
std
::
vector
<
size_t
>
&
elem
)
{
return
(
shape_tmp
.
at
(
kChannelC
)
!=
elem
.
at
(
kChannelC
)
||
shape_tmp
.
at
(
kChannelN
)
!=
elem
.
at
(
kChannelN
));
});
if
(
broadcast_nc_axis
)
{
MS_LOG
(
INFO
)
<<
"This node broadcast n || c channel."
;
return
false
;
}
input_support_format
.
assign
(
input_num_
,
kOpFormat_C1HWNCoC0
);
}
GenOutputSupportFormat
(
kOpFormat_C1HWNCoC0
,
&
output_support_format
);
support_format
->
input_format
.
emplace_back
(
input_support_format
);
support_format
->
output_format
.
emplace_back
(
output_support_format
);
return
true
;
}
bool
TbeKernelBroadCastSelecter
::
IsBroadCastSupportFracNZ
(
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
if
(
IsSameShape
())
{
if
(
!
HasScalarInput
())
{
AssignSupportFormat
(
kOpFormat_FRAC_NZ
,
support_format
);
return
true
;
}
else
{
return
false
;
}
}
SupportFormatItem
input_support_format
;
SupportFormatItem
output_support_format
;
if
(
HasScalarInput
())
{
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
input_support_format
.
emplace_back
(
kOpFormat_DEFAULT
);
}
else
{
if
(
shape
.
size
()
<
kShape2dDims
)
{
return
false
;
}
if
(
shape
[
shape
.
size
()
-
1
]
%
kAlignmented16
!=
0
||
shape
[
shape
.
size
()
-
2
]
%
kAlignmented16
!=
0
)
{
return
false
;
}
input_support_format
.
emplace_back
(
kOpFormat_FRAC_NZ
);
}
}
}
else
{
auto
less_2dims
=
std
::
any_of
(
input_shapes_
.
begin
(),
input_shapes_
.
end
(),
[](
const
std
::
vector
<
size_t
>
&
elem
)
{
return
elem
.
size
()
<
kShape2dDims
;
});
if
(
less_2dims
)
{
MS_LOG
(
INFO
)
<<
"This node dim less 2."
;
return
false
;
}
auto
shape_tmp
=
input_shapes_
[
0
];
auto
broadcast_last_dim
=
std
::
any_of
(
input_shapes_
.
begin
(),
input_shapes_
.
end
(),
[
&
shape_tmp
](
const
std
::
vector
<
size_t
>
&
elem
)
{
return
(
shape_tmp
.
at
(
shape_tmp
.
size
()
-
1
)
!=
elem
.
at
(
elem
.
size
()
-
1
))
||
(
shape_tmp
.
at
(
shape_tmp
.
size
()
-
2
)
!=
elem
.
at
(
elem
.
size
()
-
2
));
});
if
(
broadcast_last_dim
)
{
MS_LOG
(
INFO
)
<<
"This node broadcast last channel."
;
return
false
;
}
input_support_format
.
assign
(
input_num_
,
kOpFormat_FRAC_NZ
);
}
GenOutputSupportFormat
(
kOpFormat_FRAC_NZ
,
&
output_support_format
);
support_format
->
input_format
.
emplace_back
(
input_support_format
);
support_format
->
output_format
.
emplace_back
(
output_support_format
);
return
true
;
}
bool
TbeKernelBroadCastSelecter
::
IsBroadCastSupportNDC1HWC0
(
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
return
false
;
}
bool
TbeKernelBroadCastSelecter
::
Is4DShape
(
const
std
::
vector
<
size_t
>
&
shape
)
const
{
return
shape
.
size
()
==
kShape4dDims
;
}
bool
TbeKernelBroadCastSelecter
::
IsSameShape
()
const
{
auto
shape
=
input_shapes_
.
begin
();
for
(
const
auto
&
item
:
input_shapes_
)
{
if
(
shape
->
size
()
!=
item
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
shape
->
size
();
++
i
)
{
if
(
shape
->
at
(
i
)
!=
item
.
at
(
i
))
{
return
false
;
}
}
}
return
true
;
}
void
TbeKernelBroadCastSelecter
::
PadScalarShape
(
std
::
vector
<
size_t
>
*
shape
)
const
{
MS_EXCEPTION_IF_NULL
(
shape
);
if
(
shape
->
empty
())
{
shape
->
emplace_back
(
1
);
}
}
bool
TbeKernelBroadCastSelecter
::
IsScalarShape
(
const
std
::
vector
<
size_t
>
&
shape
)
const
{
return
(
shape
.
size
()
==
1
&&
shape
[
0
]
==
1
);
}
bool
TbeKernelBroadCastSelecter
::
HasScalarInput
()
const
{
bool
ret
=
false
;
for
(
const
auto
&
shape
:
input_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
ret
=
true
;
break
;
}
}
return
ret
;
}
void
TbeKernelBroadCastSelecter
::
GenOutputSupportFormat
(
const
std
::
string
&
support_format
,
SupportFormatItem
*
output_support_item
)
const
{
MS_EXCEPTION_IF_NULL
(
output_support_item
);
for
(
const
auto
&
shape
:
output_shapes_
)
{
if
(
IsScalarShape
(
shape
))
{
output_support_item
->
emplace_back
(
kOpFormat_DEFAULT
);
}
else
{
output_support_item
->
emplace_back
(
support_format
);
}
}
}
void
TbeKernelBroadCastSelecter
::
AssignSupportFormat
(
const
std
::
string
&
support_format_str
,
mindspore
::
kernel
::
SupportFormat
*
support_format
)
const
{
MS_EXCEPTION_IF_NULL
(
support_format
);
SupportFormatItem
input_support_format
;
SupportFormatItem
output_support_format
;
input_support_format
.
assign
(
input_num_
,
support_format_str
);
output_support_format
.
assign
(
output_num_
,
support_format_str
);
support_format
->
input_format
.
emplace_back
(
input_support_format
);
support_format
->
output_format
.
emplace_back
(
output_support_format
);
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
0 → 100644
浏览文件 @
caab25e0
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Decides which special device formats (5HD, FracZ, C1HWNCoC0, FracNZ,
// NDC1HWC0) a broadcast-pattern TBE kernel can support, based on the
// input/output shapes of the given CNode.
class TbeKernelBroadCastSelecter {
 public:
  // Takes shared ownership of the node whose shapes will be inspected.
  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadCastSelecter() = default;
  // Caches input/output shape info from the node and seeds *support_format.
  bool GetShapeInfo(SupportFormat *support_format);
  // Each IsBroadCastSupportXxx appends a format row to *support_format and
  // returns true when the cached shapes are compatible with that format.
  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;

 private:
  // True when all cached input shapes are identical.
  bool IsSameShape() const;
  // Replaces an empty (rank-0) shape with {1}.
  void PadScalarShape(std::vector<size_t> *shape) const;
  // True when the shape has exactly 4 dimensions.
  bool Is4DShape(const std::vector<size_t> &shape) const;
  // True when the shape is the padded scalar {1}.
  bool IsScalarShape(const std::vector<size_t> &shape) const;
  // True when at least one input shape is a scalar.
  bool HasScalarInput() const;
  // Per-output format list: scalars get the default format, others get
  // `support_format`.
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  // Appends one uniform format row for all inputs and all outputs.
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  // broadcast
  CNodePtr cnode_ptr_;           // node under selection
  size_t input_num_{};           // number of input tensors
  size_t output_num_{};          // number of output tensors
  std::vector<std::vector<size_t>> input_shapes_;   // cached input infer shapes
  std::vector<std::vector<size_t>> output_shapes_;  // cached output infer shapes
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_TBE_KERNEL_BROADCAST_SELECTER_HELPER_H
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
0 → 100644
浏览文件 @
caab25e0
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include <string>
#include <vector>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace
mindspore
{
namespace
kernel
{
// Attribute names read from the reduce primitive.
constexpr char kKeepDims[] = "keep_dims";
constexpr char kAxis[] = "axis";
// Type string used to detect a scalar (single-int) axis attribute.
constexpr char kTypeInt32[] = "Int32";
// Reduce ops handled here have exactly one input and one output.
constexpr size_t kInputIndex_0 = 0;
constexpr size_t kOutputIndex_0 = 0;
// Axis positions of the N and C channels in a 4D (NCHW) shape.
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
// Minimum rank required for the FracNZ format check.
constexpr size_t kReduceNZMinDim = 3;
// Collect the reduce node's shape and attribute info into the cached members
// (input_shape_, output_shape_, axis_, keep_dims_) and seed *support_format
// with a default-format row. Raises via MS_LOG(EXCEPTION) when the node does
// not have exactly one input and one output.
// NOTE: input_shape_ must be populated before GetReduceAttrAxis(), which uses
// it to normalize negative axis values.
bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  input_shape_.clear();
  output_shape_.clear();
  axis_.clear();
  auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  if (input_num != 1 || output_num != 1) {
    MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num
                      << ", output num: " << output_num;
  }
  // get input/output shape (scalars are padded to {1})
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
  PadScalarShape(&input_shape_);
  output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0);
  PadScalarShape(&output_shape_);
  // get keep dim attr
  GetReduceAttrKeepDim();
  // get axis attr
  GetReduceAttrAxis();
  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
// NC1HWC0 (5HD) is supported only for a 4D input, with keep_dims enabled,
// a non-empty axis list, and no reduction over the C axis.
bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!Is4DShape(input_shape_) || !keep_dims_ || axis_.empty()) {
    return false;
  }
  // Reducing over C would break the C1/C0 split of 5HD.
  for (const auto axis_value : axis_) {
    if (axis_value == kChannelC) {
      return false;
    }
  }
  AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
  return true;
}
// NDC1HWC0 support check — currently a stub that always reports "not
// supported". The intended rule would mirror the 5HD check ("like to 5HD").
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // like to 5HD
  return false;
}
// FracZ shares the "4D, keep_dims, no N/C-axis reduction" rule with
// C1HWNCoC0; delegate to the common check with the FracZ format string.
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format);
}
// C1HWNCoC0 uses the same eligibility rule as FracZ; delegate to the common
// check with the C1HWNCoC0 format string.
bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format);
}
// FRACTAL_NZ is supported only when the input rank is at least
// kReduceNZMinDim, the axis list is non-empty, and neither of the last two
// dimensions is reduced (those are the ones FracNZ re-tiles).
bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (input_shape_.size() < kReduceNZMinDim || axis_.empty()) {
    return false;
  }
  // Rank >= 3 here, so size()-2 cannot underflow.
  const size_t last_dim = input_shape_.size() - 1;
  const size_t second_last_dim = input_shape_.size() - 2;
  for (const auto axis_value : axis_) {
    if (axis_value == last_dim || axis_value == second_last_dim) {
      return false;
    }
  }
  AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
  return true;
}
// Shared eligibility check for FracZ and C1HWNCoC0: the input must be 4D,
// keep_dims must be set, the axis list must be non-empty, and neither the
// N nor the C axis may be reduced. On success, appends `format` rows.
bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format,
                                                        mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!Is4DShape(input_shape_) || !keep_dims_ || axis_.empty()) {
    return false;
  }
  // Both formats tile N and C, so reducing either axis is disallowed.
  for (const auto axis_value : axis_) {
    if (axis_value == kChannelN || axis_value == kChannelC) {
      return false;
    }
  }
  AssignSupportFormat(format, support_format);
  return true;
}
// Read the "axis" attribute from the node's primitive into axis_,
// normalizing negative indices against the cached input rank.
// Accepts either a scalar Int32 attr or a vector-of-int attr; leaves axis_
// untouched (empty) when the node carries no axis attr.
// Precondition: input_shape_ is already populated (negative-axis folding
// uses input_shape_.size()).
void TbeKernelReduceSelecter::GetReduceAttrAxis() {
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(primitive);
  auto axis = primitive->GetAttr(kAxis);
  if (axis == nullptr) {
    // Fixed log text: was "This node does't have axie attr."
    MS_LOG(INFO) << "This node doesn't have axis attr.";
    return;
  }
  auto type = axis->type();
  MS_EXCEPTION_IF_NULL(type);
  std::vector<int> axis_list;
  if (type->ToString() == kTypeInt32) {
    // Scalar attr: a single Int32 axis.
    axis_list.emplace_back(GetValue<int>(axis));
  } else {
    axis_list = GetValue<std::vector<int>>(axis);
  }
  for (const auto &elem : axis_list) {
    if (elem < 0) {
      // Fold negative axis to its positive equivalent.
      axis_.emplace_back(input_shape_.size() + elem);
    } else {
      axis_.emplace_back(IntToSize(elem));
    }
  }
}
// Read the "keep_dims" attribute from the node into keep_dims_,
// defaulting to false when the attribute is absent.
void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
  if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) {
    // Fixed log text: was "This node does't have keep_attr." — the attr
    // being probed is keep_dims.
    MS_LOG(INFO) << "This node doesn't have keep_dims attr.";
    keep_dims_ = false;
    return;
  }
  keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims);
}
// Append a single-entry format row (the reduce op has exactly one input and
// one output) for both inputs and outputs of *support_format.
void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
                                                  mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  SupportFormatItem input_support_format{support_format_str};
  SupportFormatItem output_support_format{support_format_str};
  support_format->input_format.push_back(input_support_format);
  support_format->output_format.push_back(output_support_format);
}
// Return true when the given shape has exactly 4 dimensions.
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const {
  return kShape4dDims == shape.size();
}
void
TbeKernelReduceSelecter
::
PadScalarShape
(
std
::
vector
<
size_t
>
*
shape
)
const
{
MS_EXCEPTION_IF_NULL
(
shape
);
if
(
shape
->
empty
())
{
shape
->
emplace_back
(
1
);
}
}
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
0 → 100644
浏览文件 @
caab25e0
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Decides which special device formats (5HD, NDC1HWC0, FracZ, C1HWNCoC0,
// FracNZ) a reduce-pattern TBE kernel can support, based on the node's
// input/output shapes and its axis / keep_dims attributes.
class TbeKernelReduceSelecter {
 public:
  // Takes shared ownership of the reduce node to inspect.
  explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelReduceSelecter() = default;
  // Caches shape/attr info from the node and seeds *support_format.
  bool GetShapeInfo(SupportFormat *support_format);
  // Each IsReduceSupportXxx appends a format row to *support_format and
  // returns true when the cached shape/attrs allow that format.
  bool IsReduceSupport5HD(SupportFormat *support_format) const;
  bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracZ(SupportFormat *support_format) const;
  bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracNZ(SupportFormat *support_format) const;

 private:
  // Shared eligibility rule for FracZ and C1HWNCoC0.
  bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
  // Reads the "axis" attr into axis_ (negative axes normalized).
  void GetReduceAttrAxis();
  // Reads the "keep_dims" attr into keep_dims_ (defaults to false).
  void GetReduceAttrKeepDim();
  // Appends one single-entry format row for the input and the output.
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  // True when the shape has exactly 4 dimensions.
  bool Is4DShape(const std::vector<size_t> &shape) const;
  // Replaces an empty (rank-0) shape with {1}.
  void PadScalarShape(std::vector<size_t> *shape) const;
  CNodePtr cnode_ptr_;                  // reduce node under selection
  std::vector<size_t> input_shape_{};   // cached input infer shape
  std::vector<size_t> output_shape_{};  // cached output infer shape
  std::vector<size_t> axis_{};          // normalized (non-negative) reduce axes
  bool keep_dims_ = false;              // value of the keep_dims attr
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_TBE_KERNEL_REDUCE_SELECTER_H
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
0 → 100644
浏览文件 @
caab25e0
此差异已折叠。
点击以展开。
mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h
0 → 100644
浏览文件 @
caab25e0
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Entry point: fills *kernel_info_list with the KernelBuildInfo candidates
// supported by the TBE backend for the given node.
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
// Builds the list of supported kernel build infos for one CNode, dispatching
// on the op's registered pattern (common / dynamicFormat / formatAgnostic /
// broadcast / reduce) and filtering out invalid candidates.
class TbeKernelSelect {
  using OpInfoPtr = std::shared_ptr<OpInfo>;
  using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;

 public:
  TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
  ~TbeKernelSelect() = default;
  // Runs the selection and appends results to the list passed at construction.
  void TbeMetadataInfoEx();

 private:
  // One handler per registered op pattern.
  void GetCommonPatternKernelInfo(const OpInfo &op_info);
  void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info);
  void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
  void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
  void GetReducePatternKernelInfo(const OpInfo &op_info);
  // Drops candidates whose shape/format combination is invalid.
  void FilterInVaildKernelInfo();
  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
  static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
  // Queries TBE whether a candidate build info is actually supported.
  bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
  static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
  // Assembles formats/dtypes/reshape-types for either the input side
  // (is_input == true) or the output side of one candidate.
  bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                      const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                      const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
                      std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types);
  // Parses a reshape-type string (e.g. "NC") into Axis values.
  static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
  // Rebuilds an OpInfo restricted to the formats in support_format.
  static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
  static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
                                const std::vector<std::vector<std::string>> &support_format_item, size_t index,
                                OpIOInfo *op_io_info_new);
  // op select(dynamic)
  void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new);
  static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype,
                                const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new);
  // Splits a comma-separated item from the op-select JSON into tokens.
  static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item);
  std::string OpSelectFormat();
  static void PrintSupportedFormat(const SupportFormat &support_format);

 private:
  CNodePtr cnode_ptr_;                                            // node under selection
  std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;  // output list (not owned)
  std::string node_name_;                                         // cached op name
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_TBE_KERNEL_SELECT_H
mindspore/ccsrc/parallel/ops_info/ops_utils.h
浏览文件 @
caab25e0
...
...
@@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg";
constexpr
char
BATCH_MATMUL
[]
=
"BatchMatMul"
;
constexpr
char
EXPAND_DIMS
[]
=
"ExpandDims"
;
constexpr
char
SQUARE
[]
=
"Square"
;
constexpr
char
BATCHMATMUL
[]
=
"BatchMatMul"
;
constexpr
char
TOPK
[]
=
"TopK"
;
constexpr
char
IN_TOPK
[]
=
"InTopK"
;
constexpr
char
PACK
[]
=
"Pack"
;
constexpr
char
GATHER_ND
[]
=
"GatherNd"
;
constexpr
char
UNSORTEF_SEGMENT_MIND
[]
=
"UnsortedSegmentMinD"
;
constexpr
char
UNSORTEF_SEGMENT_PRODD
[]
=
"UnsortedSegmentProdD"
;
// Parallel don't care
constexpr
char
TUPLE_GETITEM
[]
=
"tuple_getitem"
;
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
浏览文件 @
caab25e0
...
...
@@ -21,7 +21,6 @@
#include <vector>
#include "device/ascend/kernel_select_ascend.h"
#include "kernel/kernel_query.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
...
...
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
浏览文件 @
caab25e0
...
...
@@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return
nullptr
;
}
auto
node_name
=
AnfAlgo
::
GetCNodeName
(
node
);
if
(
node_name
!=
prim
::
KPrimTransData
->
name
()
||
node_name
!=
prim
::
kPrimCast
->
name
())
{
if
(
node_name
!=
prim
::
KPrimTransData
->
name
()
&&
node_name
!=
prim
::
kPrimCast
->
name
())
{
return
nullptr
;
}
auto
kernel_builder_info
=
AnfAlgo
::
GetSelectKernelBuildInfo
(
node
);
...
...
mindspore/ops/_op_impl/tbe/abs.py
浏览文件 @
caab25e0
...
...
@@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \
.
op_pattern
(
"formatAgnostic"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"y"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
F32_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
I32_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/abs_grad.py
浏览文件 @
caab25e0
...
...
@@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \
.
compute_cost
(
10
)
\
.
kernel_name
(
"abs_grad"
)
\
.
partial_flag
(
True
)
\
.
op_pattern
(
"formatAgnostic"
)
\
.
input
(
0
,
"y"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"dy"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"z"
,
False
,
"required"
,
"all"
)
\
...
...
mindspore/ops/_op_impl/tbe/add.py
浏览文件 @
caab25e0
...
...
@@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \
.
input
(
0
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/add_n.py
浏览文件 @
caab25e0
...
...
@@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \
.
attr
(
"n"
,
"required"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
False
,
"dynamic"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
)
\
.
op_pattern
(
"broadcast"
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
F32_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
I32_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/batch_matmul.py
浏览文件 @
caab25e0
...
...
@@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"bias"
,
False
,
"optional"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_Default
,
DataType
.
F16_FracNZ
)
\
...
...
mindspore/ops/_op_impl/tbe/bias_add.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"bias"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/bn_training_reduce.py
浏览文件 @
caab25e0
...
...
@@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
,
reshape_type
=
"NC"
)
\
.
output
(
0
,
"sum"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"square_sum"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
浏览文件 @
caab25e0
...
...
@@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.
input
(
5
,
"batch_mean"
,
False
,
"required"
,
"all"
)
\
.
input
(
6
,
"batch_variance"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
,
reshape_type
=
"NC"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
...
...
mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
浏览文件 @
caab25e0
...
...
@@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.
input
(
3
,
"batch_variance"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"diff_scale"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"diff_offset"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
...
...
mindspore/ops/_op_impl/tbe/bn_training_update_v2.py
浏览文件 @
caab25e0
...
...
@@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
,
reshape_type
=
"NC"
)
\
.
output
(
1
,
"batch_mean"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"batch_variance"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F16_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
...
...
mindspore/ops/_op_impl/tbe/cast.py
浏览文件 @
caab25e0
...
...
@@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \
.
attr
(
"dst_type"
,
"required"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F16_FracZ
,
DataType
.
F32_FracZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_FracZ
,
DataType
.
F16_FracZ
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
I32_Default
)
\
.
op_pattern
(
"formatAgnostic"
)
\
.
dtype_format
(
DataType
.
BOOL_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
BOOL_None
,
DataType
.
U8_None
)
\
.
dtype_format
(
DataType
.
BOOL_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
BOOL_None
,
DataType
.
I32_None
)
\
.
dtype_format
(
DataType
.
I8_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
I8_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
I8_None
,
DataType
.
I32_None
)
\
.
dtype_format
(
DataType
.
U8_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
U8_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
U8_None
,
DataType
.
I32_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
BOOL_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
I8_None
)
\
.
dtype_format
(
DataType
.
I32_None
,
DataType
.
U8_None
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
U8_None
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
F32_None
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
I32_None
)
\
.
dtype_format
(
DataType
.
F32_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
F32_None
,
DataType
.
I32_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/concat.py
浏览文件 @
caab25e0
...
...
@@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \
.
attr
(
"axis"
,
"required"
,
"int"
,
"all"
)
\
.
input
(
0
,
"input_values"
,
False
,
"dynamic"
,
"all"
)
\
.
output
(
0
,
"output_data"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
BOOL_5HD
,
DataType
.
BOOL_5HD
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/conv2d.py
浏览文件 @
caab25e0
...
...
@@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.
compute_cost
(
10
)
\
.
kernel_name
(
"conv2d"
)
\
.
partial_flag
(
True
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
attr
(
"stride"
,
"required"
,
"listInt"
,
"all"
)
\
.
attr
(
"pad_list"
,
"required"
,
"listInt"
,
"all"
)
\
.
attr
(
"dilation"
,
"required"
,
"listInt"
,
"all"
)
\
...
...
@@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.
input
(
2
,
"bias"
,
False
,
"optional"
,
"all"
)
\
.
input
(
3
,
"offset_w"
,
False
,
"optional"
,
"all"
)
\
.
output
(
0
,
"y"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
,
DataType
.
I8_Default
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
F16_None
,
DataType
.
F16_None
,
DataType
.
I8_None
,
DataType
.
F16_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/dropout_do_mask.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
.
input
(
1
,
"mask"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"keep_prob"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
U8_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
U8_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/elu.py
浏览文件 @
caab25e0
...
...
@@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/erf.py
浏览文件 @
caab25e0
...
...
@@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \
.
op_pattern
(
"formatAgnostic"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/erfc.py
浏览文件 @
caab25e0
...
...
@@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \
.
op_pattern
(
"formatAgnostic"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/expm1.py
浏览文件 @
caab25e0
...
...
@@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/fused_mul_add.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"x3"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
)
\
...
...
mindspore/ops/_op_impl/tbe/layer_norm.py
浏览文件 @
caab25e0
...
...
@@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"mean"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"variance"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
...
...
mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
浏览文件 @
caab25e0
...
...
@@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
.
input
(
3
,
"mean"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"pd_gamma"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"pd_beta"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
...
...
mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
浏览文件 @
caab25e0
...
...
@@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
.
input
(
3
,
"mean"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"gamma"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"pd_x"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
...
...
mindspore/ops/_op_impl/tbe/mul.py
浏览文件 @
caab25e0
...
...
@@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"y"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"output"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
dtype_format
(
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
,
DataType
.
I32_FracZ
)
\
.
dtype_format
(
DataType
.
I32_FracNZ
,
DataType
.
I32_FracNZ
,
DataType
.
I32_FracNZ
)
\
.
dtype_format
(
DataType
.
I32_C1HWNCoC0
,
DataType
.
I32_C1HWNCoC0
,
DataType
.
I32_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
)
\
.
dtype_format
(
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
,
DataType
.
F16_FracNZ
)
\
.
dtype_format
(
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
)
\
.
dtype_format
(
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
,
DataType
.
F32_FracNZ
)
\
.
dtype_format
(
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
None_None
,
DataType
.
None_None
,
DataType
.
None_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/real_div.py
浏览文件 @
caab25e0
...
...
@@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"y"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"z"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
op_pattern
(
"broadcast"
)
\
.
dtype_format
(
DataType
.
F16_None
,
DataType
.
F16_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
F32_None
,
DataType
.
F32_None
,
DataType
.
F32_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/reciprocal.py
浏览文件 @
caab25e0
...
...
@@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
.
partial_flag
(
True
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_NHWC
,
DataType
.
F16_NHWC
)
\
...
...
mindspore/ops/_op_impl/tbe/reduce_mean.py
浏览文件 @
caab25e0
...
...
@@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
.
attr
(
"keep_dims"
,
"optional"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F
32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F
16_5HD
,
DataType
.
F16_5HD
)
\
.
op_pattern
(
"reduce"
)
\
.
dtype_format
(
DataType
.
I8_None
,
DataType
.
I8_None
)
\
.
dtype_format
(
DataType
.
U8_None
,
DataType
.
U8_None
)
\
.
dtype_format
(
DataType
.
F
16_None
,
DataType
.
F16_None
)
\
.
dtype_format
(
DataType
.
F
32_None
,
DataType
.
F32_None
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/relu_grad_v2.py
浏览文件 @
caab25e0
...
...
@@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
.
kernel_name
(
"relu_grad_v2"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"gradients"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"mask"
,
False
,
"re
re
quired"
,
"all"
)
\
.
input
(
1
,
"mask"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"backprops"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
U8_Default
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
U8_Default
,
DataType
.
F32_5HD
)
\
...
...
mindspore/ops/_op_impl/tbe/select.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \
.
input
(
1
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
,
DataType
.
U8_Default
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/sign.py
浏览文件 @
caab25e0
...
...
@@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"y"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
I32_5HD
,
DataType
.
I32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
浏览文件 @
caab25e0
...
...
@@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
.
input
(
1
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
True
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
...
...
mindspore/ops/_op_impl/tbe/softplus.py
浏览文件 @
caab25e0
...
...
@@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/softplus_grad.py
浏览文件 @
caab25e0
...
...
@@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \
.
input
(
1
,
"features"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"backprops"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
mindspore/ops/_op_impl/tbe/split_d.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \
.
attr
(
"output_num"
,
"required"
,
"int"
,
"all"
)
\
.
input
(
0
,
"value"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"output"
,
False
,
"dynamic"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
BOOL_Default
,
DataType
.
BOOL_Default
)
\
.
dtype_format
(
DataType
.
BOOL_NHWC
,
DataType
.
BOOL_NHWC
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I8_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/tensor_add.py
浏览文件 @
caab25e0
...
...
@@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \
.
input
(
0
,
"x1"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"x2"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I32_Default
,
DataType
.
I32_Default
,
DataType
.
I32_Default
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
...
...
mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
浏览文件 @
caab25e0
...
...
@@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
.
input
(
0
,
"x"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"segment_ids"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"y"
,
False
,
"required"
,
"all"
)
\
.
op_pattern
(
"dynamicFormat"
)
\
.
dtype_format
(
DataType
.
I8_Default
,
DataType
.
I32_Default
,
DataType
.
I8_Default
)
\
.
dtype_format
(
DataType
.
I8_5HD
,
DataType
.
I32_5HD
,
DataType
.
I8_5HD
)
\
.
dtype_format
(
DataType
.
U8_Default
,
DataType
.
I32_Default
,
DataType
.
U8_Default
)
\
...
...
mindspore/ops/op_info_register.py
浏览文件 @
caab25e0
...
...
@@ -97,6 +97,7 @@ class RegOp:
"""
if
not
isinstance
(
value
,
str
):
raise
TypeError
(
"%s value must be str"
%
str
(
value
))
return
True
def
_is_int
(
self
,
value
):
"""
...
...
@@ -110,6 +111,7 @@ class RegOp:
"""
if
not
isinstance
(
value
,
int
):
raise
TypeError
(
"%s value must be int"
%
str
(
value
))
return
True
def
_is_bool
(
self
,
value
):
"""
...
...
@@ -123,6 +125,7 @@ class RegOp:
"""
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"%s value must be bool"
%
str
(
value
))
return
True
def
_check_param
(
self
,
param_list
,
key_list
,
fn_list
,
kwargs
):
"""
...
...
@@ -494,6 +497,7 @@ class DataType:
The current list below maybe not completed. If necessary, please add it.
"""
None_None
=
(
""
,
""
)
BOOL_None
=
(
"bool"
,
""
)
BOOL_Default
=
(
"bool"
,
"DefaultFormat"
)
BOOL_5HD
=
(
"bool"
,
"NC1HWC0"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录