提交 caab25e0 编写于 作者: J jjfeing

tbe select broadcast reduce dynamic

上级 553432c9
......@@ -20,7 +20,7 @@
#include "kernel/aicpu/aicpu_kernel_metadata.h"
#include "kernel/rts/rt_kernel_info.h"
#include "kernel/hccl/hccl_kernel_metadata.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
......@@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) {
......@@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TbeMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
......
......@@ -126,6 +126,8 @@ class OpInfo {
bool is_ref() const { return !ref_infos_.empty(); }
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
// Empties the inputs_ptr_ container.
void ClearInputs() {
  inputs_ptr_.clear();
}
// Empties the outputs_ptr_ container.
void ClearOutputs() {
  outputs_ptr_.clear();
}
private:
std::string op_name_;
......
......@@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamic_format";
constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
......@@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
{kFormatAgnostic, kBroadcastPattern},
{kBroadcast, kBroadcastPattern},
{kReduce, kReducePattern},
{kDynamicFormat, kDynamicFormatPattern}};
op_info->set_async_flag(obj.at(kAsyncFlag));
......@@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
std::string op_pattern = obj.at(kOpPattern);
auto find_iter = kOpPatternMap.find(op_pattern);
if (find_iter == kOpPatternMap.end()) {
if (!op_pattern.empty()) {
MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
}
op_info->set_op_pattern(kCommonPattern);
} else {
op_info->set_op_pattern(find_iter->second);
}
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
......
......@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "bool"},
{TypeId::kNumberTypeBool, "int8"},
};
const std::map<std::string, std::string> type_str_maps = {
......@@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
std::string TypeIdToString(TypeId type_id) {
auto iter = type_id_str_maps.find(type_id);
if (iter == type_id_str_maps.end()) {
MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
}
return iter->second;
}
......
......@@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
} else {
// dtype : float16
auto tensor_dtype =
std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)));
MS_EXCEPTION_IF_NULL(tensor_dtype);
std::string dtype = tensor_dtype->element()->ToString();
dtype = tbe::DtypeToString(dtype);
// format
std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
} else if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
}
nlohmann::json input_desc_json;
input_desc_json["dtype"] = dtype;
input_desc_json["name"] = op_input_name + std::to_string(input_i);
auto dtype = GetDeviceInputType(anf_node, real_input_index);
auto format = GetDeviceInputFormat(anf_node, real_input_index);
auto shape = GetDeviceInputShape(anf_node, real_input_index);
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
nlohmann::json input_desc_json;
input_desc_json["dtype"] = dtype;
input_desc_json["name"] = op_input_name + std::to_string(input_i);
input_desc_json["ori_shape"] = ori_shape;
input_desc_json["ori_format"] = kOpFormat_NCHW;
auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
if (shape.empty()) {
shape.emplace_back(1);
}
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
input_desc_json["shape"] = ori_shape;
input_desc_json["format"] = kOpFormat_NCHW;
} else {
input_desc_json["shape"] = shape;
input_desc_json["format"] = format;
}
input_desc_json["shape"] = shape;
input_desc_json["format"] = format;
input_desc_json["valid"] = value;
input_desc_json["param_type"] = input_ptr->param_type();
input_list->emplace_back(input_desc_json);
......@@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
MS_EXCEPTION_IF_NULL(output_idx);
MS_EXCEPTION_IF_NULL(output_list);
for (size_t i = 0; i < output_obj_num; i++) {
nlohmann::json output_obj;
auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx)));
std::string dtype = type_ptr->element()->ToString();
dtype = tbe::DtypeToString(dtype);
std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx);
if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
} else if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
}
std::vector<size_t> ori_shape;
if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) {
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
auto shape = GetDeviceOutputShape(anf_node, *output_idx);
std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
} else {
ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
}
nlohmann::json output_obj;
output_obj["dtype"] = dtype;
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx);
if (shape.empty()) {
shape.emplace_back(1);
}
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
output_obj["shape"] = ori_shape;
output_obj["format"] = kOpFormat_NCHW;
} else {
output_obj["shape"] = shape;
output_obj["format"] = format;
}
output_obj["shape"] = shape;
output_obj["format"] = format;
output_obj["ori_shape"] = ori_shape;
output_obj["ori_format"] = kOpFormat_NCHW;
output_obj["name"] = output_ptr->name();
output_obj["valid"] = true;
output_obj["param_type"] = output_ptr->param_type();
output_list->emplace_back(output_obj);
(*output_idx)++;
}
......@@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
}
}
// Returns the shape to put into the json description of input `real_index`.
// For OP_SELECT_FORMAT / CHECK_SUPPORTED the inferred shape of the producing node
// is used; for every other creater type the device shape of the input is used.
// An empty shape is reported as {1}.
std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::vector<size_t> result;
  if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
    result = AnfAlgo::GetInputDeviceShape(anf_node, real_index);
  } else {
    result = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index);
  }
  if (result.empty()) {
    result.emplace_back(1);
  }
  return result;
}
// Returns the dtype string for input `real_index`.
// Only OP_SELECT_FORMAT uses the inferred data type of the producing node; every
// other creater type (CHECK_SUPPORTED included) uses the input's device data type.
std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (creater_type_ == OP_SELECT_FORMAT) {
    return tbe::TypeIdToString(AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index));
  }
  return tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_index));
}
// Returns the format string for input `real_index`.
// OP_SELECT_FORMAT / CHECK_SUPPORTED always report kOpFormat_NCHW. Otherwise the
// input's format is used, with kOpFormat_FRAC_Z mapped to kOpFormat_FRACTAL_Z and
// kOpFormat_DEFAULT mapped to kOpFormat_NCHW.
std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
    return kOpFormat_NCHW;
  }
  std::string device_format = AnfAlgo::GetInputFormat(anf_node, real_index);
  if (device_format == kOpFormat_FRAC_Z) {
    return kOpFormat_FRACTAL_Z;
  }
  if (device_format == kOpFormat_DEFAULT) {
    return kOpFormat_NCHW;
  }
  return device_format;
}
// Returns the shape to put into the json description of output `real_index`.
// For OP_SELECT_FORMAT / CHECK_SUPPORTED the inferred output shape is used; for
// every other creater type the device shape is used. An empty shape is reported as {1}.
std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::vector<size_t> result;
  if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
    result = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
  } else {
    result = AnfAlgo::GetOutputInferShape(anf_node, real_index);
  }
  if (result.empty()) {
    result.emplace_back(1);
  }
  return result;
}
// Returns the dtype string for output `real_index`.
// Only OP_SELECT_FORMAT uses the inferred output data type; every other creater
// type (CHECK_SUPPORTED included) uses the output's device data type.
std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (creater_type_ == OP_SELECT_FORMAT) {
    return tbe::TypeIdToString(AnfAlgo::GetOutputInferDataType(anf_node, real_index));
  }
  return tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, real_index));
}
// Returns the format string for output `real_index`.
// OP_SELECT_FORMAT / CHECK_SUPPORTED always report kOpFormat_NCHW. Otherwise the
// output's format is used, with kOpFormat_FRAC_Z mapped to kOpFormat_FRACTAL_Z and
// kOpFormat_DEFAULT mapped to kOpFormat_NCHW.
std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
    return kOpFormat_NCHW;
  }
  std::string device_format = AnfAlgo::GetOutputFormat(anf_node, real_index);
  if (device_format == kOpFormat_FRAC_Z) {
    return kOpFormat_FRACTAL_Z;
  }
  if (device_format == kOpFormat_DEFAULT) {
    return kOpFormat_NCHW;
  }
  return device_format;
}
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list) {
if (input_size_list == nullptr || output_size_list == nullptr) {
......
......@@ -93,7 +93,7 @@ class TbeKernelJsonCreator {
nlohmann::json *outputs_json);
bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *attrs_json);
void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list);
......@@ -105,6 +105,13 @@ class TbeKernelJsonCreator {
void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
std::vector<nlohmann::json> *output_list);
std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
kCreaterType creater_type_;
std::string json_name_;
std::string json_info_;
......
......@@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
task_iter->second.output_size_list, kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod);
if (set_kernel_mod) {
AnfAlgo ::SetKernelMod(kernel_mod, task_iter->second.node);
AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
}
auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
(void)task_map_.erase(task_iter);
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -13,20 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
struct SupportFormat {
std::vector<std::vector<std::string>> input_format;
std::vector<std::vector<std::string>> output_format;
};
using SupportFormatItem = std::vector<std::string>;
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Name of the node attribute that GetShapeInfo checks to detect dynamic inputs (read as std::vector<int>).
constexpr char kDynInputKey[] = "dyn_input_sizes";
// Index of the first input of the node.
constexpr size_t kInputIndex_0 = 0;
// Positions of the N and C axes inside a 4-D shape, as used by the format checks in this file.
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
// Dimensions checked for alignment must be a multiple of this value.
constexpr size_t kAlignmented16 = 16;
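// Broadcast inputs fall into three cases:
// 1. every input has the same, non-scalar shape: the format is applied to all inputs and outputs;
// 2. some inputs are scalar: each non-scalar input must satisfy the rank and 16-alignment rules of the format;
// 3. no input is scalar but the shapes differ: only some formats accept this, and only when the axes
//    the format depends on are not broadcast.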
// Collects the inferred input/output shapes of the node into input_shapes_ /
// output_shapes_ (empty shapes are padded to {1}) and registers the default format
// as the first supported combination in `support_format`.
// When the node carries the dynamic-input attribute only the shape of input 0 is
// recorded and input_num_ is set to 1. Always returns true; raises through
// MS_LOG(EXCEPTION) when the dynamic-input attribute is empty or its first entry is < 2.
bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  // Reset everything gathered by a previous call.
  input_num_ = 0;
  output_num_ = 0;
  input_shapes_.clear();
  output_shapes_.clear();

  if (!AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
    for (size_t in_idx = 0; in_idx < input_num_; ++in_idx) {
      auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, in_idx);
      PadScalarShape(&in_shape);
      input_shapes_.emplace_back(in_shape);
    }
  } else {
    MS_LOG(INFO) << "This broadcast node has dynamic input.";
    auto dyn_sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
    if (dyn_sizes.empty() || dyn_sizes[0] < 2) {
      MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
    }
    auto first_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
    PadScalarShape(&first_input_shape);
    input_shapes_.emplace_back(first_input_shape);
    input_num_ = 1;
  }

  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  for (size_t out_idx = 0; out_idx < output_num_; ++out_idx) {
    auto out_shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, out_idx);
    PadScalarShape(&out_shape);
    output_shapes_.emplace_back(out_shape);
  }

  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
// Decides whether the node can use kOpFormat_NC1HWC0. On success one input/output
// format combination is appended to `support_format` and true is returned.
bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  const bool has_scalar = HasScalarInput();
  if (IsSameShape()) {
    // Identical input shapes: supported unless that shape is the scalar shape.
    if (has_scalar) {
      return false;
    }
    AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
    return true;
  }

  SupportFormatItem input_formats;
  if (has_scalar) {
    for (const auto &in_shape : input_shapes_) {
      if (IsScalarShape(in_shape)) {
        input_formats.emplace_back(kOpFormat_DEFAULT);
        continue;
      }
      // A non-scalar input has to be 4-D with its C axis a multiple of 16.
      if (!Is4DShape(in_shape) || in_shape[kChannelC] % kAlignmented16 != 0) {
        return false;
      }
      input_formats.emplace_back(kOpFormat_NC1HWC0);
    }
  } else {
    for (const auto &in_shape : input_shapes_) {
      if (!Is4DShape(in_shape)) {
        return false;
      }
    }
    // The C axis must be identical across inputs, i.e. it must not be broadcast.
    const auto &first_shape = input_shapes_[0];
    for (const auto &in_shape : input_shapes_) {
      if (first_shape.at(kChannelC) != in_shape.at(kChannelC)) {
        MS_LOG(INFO) << "This node broadcast c channel.";
        return false;
      }
    }
    input_formats.assign(input_num_, kOpFormat_NC1HWC0);
  }

  SupportFormatItem output_formats;
  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_formats);
  support_format->input_format.emplace_back(input_formats);
  support_format->output_format.emplace_back(output_formats);
  return true;
}
// Decides whether the node can use kOpFormat_FRAC_Z. On success one input/output
// format combination is appended to `support_format` and true is returned.
// Differing shapes are only accepted when at least one input is scalar.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  const bool has_scalar = HasScalarInput();
  if (IsSameShape()) {
    if (has_scalar) {
      return false;
    }
    AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
    return true;
  }
  if (!has_scalar) {
    return false;
  }

  SupportFormatItem input_formats;
  for (const auto &in_shape : input_shapes_) {
    if (IsScalarShape(in_shape)) {
      input_formats.emplace_back(kOpFormat_DEFAULT);
      continue;
    }
    if (!Is4DShape(in_shape)) {
      return false;
    }
    // Both the N and the C axis have to be multiples of 16.
    const bool n_aligned = (in_shape[kChannelN] % kAlignmented16 == 0);
    const bool c_aligned = (in_shape[kChannelC] % kAlignmented16 == 0);
    if (!(n_aligned && c_aligned)) {
      return false;
    }
    input_formats.emplace_back(kOpFormat_FRAC_Z);
  }

  SupportFormatItem output_formats;
  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_formats);
  support_format->input_format.emplace_back(input_formats);
  support_format->output_format.emplace_back(output_formats);
  return true;
}
// Decides whether the node can use kOpFormat_C1HWNCoC0. On success one input/output
// format combination is appended to `support_format` and true is returned.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  const bool has_scalar = HasScalarInput();
  if (IsSameShape()) {
    if (has_scalar) {
      return false;
    }
    AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
    return true;
  }

  SupportFormatItem input_formats;
  if (has_scalar) {
    for (const auto &in_shape : input_shapes_) {
      if (IsScalarShape(in_shape)) {
        input_formats.emplace_back(kOpFormat_DEFAULT);
        continue;
      }
      // A non-scalar input has to be 4-D with its N axis a multiple of 16.
      if (!Is4DShape(in_shape) || in_shape[kChannelN] % kAlignmented16 != 0) {
        return false;
      }
      input_formats.emplace_back(kOpFormat_C1HWNCoC0);
    }
  } else {
    for (const auto &in_shape : input_shapes_) {
      if (!Is4DShape(in_shape)) {
        return false;
      }
    }
    // Neither the N nor the C axis may differ between inputs.
    const auto &first_shape = input_shapes_[0];
    for (const auto &in_shape : input_shapes_) {
      const bool c_differs = first_shape.at(kChannelC) != in_shape.at(kChannelC);
      if (c_differs || first_shape.at(kChannelN) != in_shape.at(kChannelN)) {
        MS_LOG(INFO) << "This node broadcast n || c channel.";
        return false;
      }
    }
    input_formats.assign(input_num_, kOpFormat_C1HWNCoC0);
  }

  SupportFormatItem output_formats;
  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_formats);
  support_format->input_format.emplace_back(input_formats);
  support_format->output_format.emplace_back(output_formats);
  return true;
}
// Decides whether the node can use kOpFormat_FRAC_NZ. On success one input/output
// format combination is appended to `support_format` and true is returned.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  const bool has_scalar = HasScalarInput();
  if (IsSameShape()) {
    if (has_scalar) {
      return false;
    }
    AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
    return true;
  }

  SupportFormatItem input_formats;
  if (has_scalar) {
    for (const auto &in_shape : input_shapes_) {
      if (IsScalarShape(in_shape)) {
        input_formats.emplace_back(kOpFormat_DEFAULT);
        continue;
      }
      const size_t rank = in_shape.size();
      if (rank < kShape2dDims) {
        return false;
      }
      // The two innermost dimensions have to be multiples of 16.
      if (in_shape[rank - 1] % kAlignmented16 != 0 || in_shape[rank - 2] % kAlignmented16 != 0) {
        return false;
      }
      input_formats.emplace_back(kOpFormat_FRAC_NZ);
    }
  } else {
    for (const auto &in_shape : input_shapes_) {
      if (in_shape.size() < kShape2dDims) {
        MS_LOG(INFO) << "This node dim less 2.";
        return false;
      }
    }
    // The two innermost dimensions must be identical across inputs.
    const auto &first_shape = input_shapes_[0];
    for (const auto &in_shape : input_shapes_) {
      const bool last_differs = first_shape.at(first_shape.size() - 1) != in_shape.at(in_shape.size() - 1);
      if (last_differs || first_shape.at(first_shape.size() - 2) != in_shape.at(in_shape.size() - 2)) {
        MS_LOG(INFO) << "This node broadcast last channel.";
        return false;
      }
    }
    input_formats.assign(input_num_, kOpFormat_FRAC_NZ);
  }

  SupportFormatItem output_formats;
  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_formats);
  support_format->input_format.emplace_back(input_formats);
  support_format->output_format.emplace_back(output_formats);
  return true;
}
// NDC1HWC0 is never reported as supported: after the null check on
// `support_format` this always returns false and adds nothing.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  constexpr bool kSupported = false;
  return kSupported;
}
// True when `shape` has exactly kShape4dDims dimensions.
bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
  const size_t rank = shape.size();
  return rank == kShape4dDims;
}
// True when every recorded input shape equals the first one (same rank and same
// value in every dimension). Also true when no input shape has been recorded.
bool TbeKernelBroadCastSelecter::IsSameShape() const {
  if (input_shapes_.empty()) {
    return true;
  }
  const auto &reference_shape = input_shapes_.front();
  for (const auto &candidate : input_shapes_) {
    // std::vector comparison checks size first, then element-wise equality.
    if (candidate != reference_shape) {
      return false;
    }
  }
  return true;
}
// Turns an empty shape into {1}; non-empty shapes are left untouched.
void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const {
  MS_EXCEPTION_IF_NULL(shape);
  if (!shape->empty()) {
    return;
  }
  shape->emplace_back(1);
}
// A shape counts as scalar when it is exactly {1}.
bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const {
  if (shape.size() != 1) {
    return false;
  }
  return shape.front() == 1;
}
bool TbeKernelBroadCastSelecter::HasScalarInput() const {
bool ret = false;
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
ret = true;
break;
}
}
return ret;
}
void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format,
SupportFormatItem *output_support_item) const {
MS_EXCEPTION_IF_NULL(output_support_item);
for (const auto &shape : output_shapes_) {
if (IsScalarShape(shape)) {
output_support_item->emplace_back(kOpFormat_DEFAULT);
} else {
output_support_item->emplace_back(support_format);
}
}
}
// Appends one combination to `support_format` in which all input_num_ inputs and
// all output_num_ outputs use `support_format_str`.
void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str,
                                                     mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // Build each row in place: a vector holding N copies of the format string.
  support_format->input_format.emplace_back(input_num_, support_format_str);
  support_format->output_format.emplace_back(output_num_, support_format_str);
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Works out which data formats a broadcast node can use, based on the inferred
// shapes of its inputs and outputs. Call GetShapeInfo() first; it fills the cached
// shapes that the IsBroadCastSupport* queries read.
class TbeKernelBroadCastSelecter {
 public:
  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadCastSelecter() = default;

  // Caches the node's shapes and registers the default format in `support_format`.
  bool GetShapeInfo(SupportFormat *support_format);

  // Each query returns true and appends one input/output format combination to
  // `support_format` when the corresponding format is usable for this node.
  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;

 private:
  // Shape predicates and helpers.
  bool IsSameShape() const;
  bool HasScalarInput() const;
  bool Is4DShape(const std::vector<size_t> &shape) const;
  bool IsScalarShape(const std::vector<size_t> &shape) const;
  void PadScalarShape(std::vector<size_t> *shape) const;

  // Helpers that fill format combinations.
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;

  // Node being analysed and the shape information cached by GetShapeInfo().
  CNodePtr cnode_ptr_;
  size_t input_num_ = 0;
  size_t output_num_ = 0;
  std::vector<std::vector<size_t>> input_shapes_;
  std::vector<std::vector<size_t>> output_shapes_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include <string>
#include <vector>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Names of the node / primitive attributes read by the reduce selecter.
constexpr char kKeepDims[] = "keep_dims";
constexpr char kAxis[] = "axis";
// Type string that marks the "axis" attribute as a single int; any other type is read as std::vector<int>.
constexpr char kTypeInt32[] = "Int32";
// Index of the only input / output a reduce node is expected to have.
constexpr size_t kInputIndex_0 = 0;
constexpr size_t kOutputIndex_0 = 0;
// Axis indices compared against the reduce axes in the 4-D format checks.
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
// Minimum input rank required before FRAC_NZ is considered for a reduce node.
constexpr size_t kReduceNZMinDim = 3;
// Caches the inferred input/output shape, the keep_dims flag and the reduce axes
// of the node, then registers the default format in `support_format`.
// Always returns true; raises through MS_LOG(EXCEPTION) unless the node has
// exactly one input and one output.
bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  // Reset everything gathered by a previous call.
  input_shape_.clear();
  output_shape_.clear();
  axis_.clear();

  const auto in_count = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  const auto out_count = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  const bool single_io = (in_count == 1 && out_count == 1);
  if (!single_io) {
    MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << in_count
                      << ", output num: " << out_count;
  }

  // Shapes: an empty (scalar) shape is stored as {1}.
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
  PadScalarShape(&input_shape_);
  output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0);
  PadScalarShape(&output_shape_);

  // Attributes: the axis lookup relies on input_shape_ being set already.
  GetReduceAttrKeepDim();
  GetReduceAttrAxis();

  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
// kOpFormat_NC1HWC0 is usable when the input is 4-D, keep_dims is set, at least
// one axis is given and the C axis is not reduced. On success the format is
// appended to `support_format`.
bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!Is4DShape(input_shape_)) {
    return false;
  }
  const bool has_usable_attrs = keep_dims_ && !axis_.empty();
  if (!has_usable_attrs) {
    return false;
  }
  for (const size_t &reduce_axis : axis_) {
    if (reduce_axis == kChannelC) {
      return false;
    }
  }
  AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
  return true;
}
// NDC1HWC0 is never reported as supported: after the null check on
// `support_format` this always returns false and adds nothing.
// The original note says the rule would be similar to the 5HD one.
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  constexpr bool kSupported = false;
  return kSupported;
}
// kOpFormat_FRAC_Z uses the rule shared with C1HWNCoC0.
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
  const bool supported = IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format);
  return supported;
}
// kOpFormat_C1HWNCoC0 uses the rule shared with FRAC_Z.
bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  const bool supported = IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format);
  return supported;
}
// kOpFormat_FRAC_NZ is usable when the input rank is at least kReduceNZMinDim, at
// least one axis is given and neither of the two innermost axes is reduced.
// On success the format is appended to `support_format`.
bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  const size_t rank = input_shape_.size();
  if (rank < kReduceNZMinDim || axis_.empty()) {
    return false;
  }
  for (const size_t &reduce_axis : axis_) {
    if (reduce_axis == (rank - 1) || reduce_axis == (rank - 2)) {
      return false;
    }
  }
  AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
  return true;
}
// Shared rule for FRAC_Z and C1HWNCoC0: the input must be 4-D, keep_dims must be
// set, at least one axis must be given and neither the N nor the C axis may be
// reduced. On success `format` is appended to `support_format`.
bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format,
                                                        mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!Is4DShape(input_shape_)) {
    return false;
  }
  const bool has_usable_attrs = keep_dims_ && !axis_.empty();
  if (!has_usable_attrs) {
    return false;
  }
  for (const size_t &reduce_axis : axis_) {
    if (reduce_axis == kChannelC || reduce_axis == kChannelN) {
      return false;
    }
  }
  AssignSupportFormat(format, support_format);
  return true;
}
// Reads the "axis" attribute of the node's primitive and appends the axes to axis_.
// The attribute is read as a single int when its type string is kTypeInt32 and as
// std::vector<int> otherwise. Negative axes are offset by the input rank, so
// input_shape_ must already be set. Leaves axis_ untouched when the attribute is missing.
void TbeKernelReduceSelecter::GetReduceAttrAxis() {
  auto prim = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(prim);
  auto axis_value = prim->GetAttr(kAxis);
  if (axis_value == nullptr) {
    MS_LOG(INFO) << "This node does't have axie attr.";
    return;
  }
  auto axis_type = axis_value->type();
  MS_EXCEPTION_IF_NULL(axis_type);

  std::vector<int> raw_axes;
  const bool is_single_int = (axis_type->ToString() == kTypeInt32);
  if (is_single_int) {
    raw_axes.emplace_back(GetValue<int>(axis_value));
  } else {
    raw_axes = GetValue<std::vector<int>>(axis_value);
  }

  for (const auto &raw_axis : raw_axes) {
    // Negative values count from the back of the input shape.
    axis_.emplace_back(raw_axis < 0 ? input_shape_.size() + raw_axis : IntToSize(raw_axis));
  }
}
// Caches the node's "keep_dims" attribute in keep_dims_; a node without the
// attribute is treated as keep_dims == false.
void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
  if (AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) {
    keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims);
    return;
  }
  MS_LOG(INFO) << "This node does't have keep_attr.";
  keep_dims_ = false;
}
void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
mindspore::kernel::SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
input_support_format.emplace_back(support_format_str);
output_support_format.emplace_back(support_format_str);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
}
// True when `shape` has exactly kShape4dDims dimensions.
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const {
  const size_t rank = shape.size();
  return rank == kShape4dDims;
}
// Turns an empty shape into the one-element shape {1}; a non-empty shape is
// left unchanged.
void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
  MS_EXCEPTION_IF_NULL(shape);
  if (!shape->empty()) {
    return;
  }
  shape->emplace_back(1);
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Helper that inspects one reduce CNode and records, in a caller-provided
// SupportFormat, which input/output formats the node can use.
// NOTE(review): the public Is*Support* checks and GetShapeInfo are defined
// outside this declaration; the per-method notes on them are assumptions
// based on their names and should be confirmed against the .cc file.
class TbeKernelReduceSelecter {
public:
// Stores the CNode to inspect; the pointer argument is moved into cnode_ptr_.
explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
~TbeKernelReduceSelecter() = default;
// NOTE(review): presumably fills input_shape_/output_shape_/axis_/keep_dims_
// from the node before the format checks below are used -- confirm.
bool GetShapeInfo(SupportFormat *support_format);
// Per-format checks. NOTE(review): each presumably returns true and appends
// the format to support_format when the node supports it -- confirm.
bool IsReduceSupport5HD(SupportFormat *support_format) const;
bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
bool IsReduceSupportFracZ(SupportFormat *support_format) const;
bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
bool IsReduceSupportFracNZ(SupportFormat *support_format) const;
private:
// Records `format` and returns true only for a 4-D input with keep_dims_ set,
// a non-empty axis_, and no reduce axis equal to kChannelN or kChannelC.
bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
// Appends the node's kAxis attribute values to axis_; negative values are
// offset by input_shape_.size(). Does nothing if the attribute is missing.
void GetReduceAttrAxis();
// Sets keep_dims_ from the node's kKeepDims attribute, or false if missing.
void GetReduceAttrKeepDim();
// Appends a single-entry format item holding support_format_str to both the
// input_format and output_format lists of support_format.
void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
// True when shape.size() equals kShape4dDims.
bool Is4DShape(const std::vector<size_t> &shape) const;
// Appends 1 to an empty shape; leaves a non-empty shape unchanged.
void PadScalarShape(std::vector<size_t> *shape) const;
// Node under inspection.
CNodePtr cnode_ptr_;
// Shapes of the node's input/output. NOTE(review): which input/output index
// is used is not visible here -- confirm in GetShapeInfo.
std::vector<size_t> input_shape_{};
std::vector<size_t> output_shape_{};
// Non-negative reduce axes collected by GetReduceAttrAxis.
std::vector<size_t> axis_{};
// Value of the kKeepDims attribute; false when the node lacks it.
bool keep_dims_ = false;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
// Entry point for TBE kernel selection: fills `kernel_info_list` with the
// KernelBuildInfo candidates available for `kernel_node`. Callers treat an
// empty list afterwards as "no TBE kernel" and fall back to other backends.
// NOTE(review): presumably delegates to TbeKernelSelect::TbeMetadataInfoEx --
// confirm in the .cc file.
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
// Collects TBE KernelBuildInfo candidates for a single CNode into a
// caller-owned list.
// NOTE(review): no member definitions are visible from this declaration; the
// comments below describe what the signatures suggest and should be checked
// against the .cc file.
class TbeKernelSelect {
using OpInfoPtr = std::shared_ptr<OpInfo>;
using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;
public:
// `kernel_info_list` is a non-owning output pointer held for the lifetime of
// this object (stored in kernel_info_list_).
TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
~TbeKernelSelect() = default;
// Main driver. NOTE(review): presumably dispatches on the op's pattern to
// one of the Get*PatternKernelInfo methods and then filters the result.
void TbeMetadataInfoEx();
private:
// One handler per op pattern. NOTE(review): the pattern names presumably map
// to the op_pattern strings used in op registration ("formatAgnostic",
// "dynamicFormat", "broadcast", "reduce") -- confirm.
void GetCommonPatternKernelInfo(const OpInfo &op_info);
void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info);
void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
void GetReducePatternKernelInfo(const OpInfo &op_info);
// Filtering helpers that take an iterator into the kernel info list.
// NOTE(review): presumably drop entries whose formats do not fit the node's
// shapes or that fail the TBE support check -- confirm.
void FilterInVaildKernelInfo();
bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
// Builder helpers; results are written through the pointer parameters.
static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, const std::vector<int> &dyn_input_sizes,
std::vector<std::string> *formats, std::vector<TypeId> *device_types,
std::vector<std::vector<Axis>> *reshape_types);
static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
// Overloads taking a SupportFormat: derive a new OpInfo/OpIOInfo from an
// existing one plus a set of supported formats; output goes to *_new.
static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,
OpIOInfo *op_io_info_new);
// op select(dynamic)
// NOTE(review): these overloads presumably build the new info from the
// result of OpSelectFormat() rather than from a SupportFormat -- confirm.
void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new);
static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype,
const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new);
static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item);
std::string OpSelectFormat();
static void PrintSupportedFormat(const SupportFormat &support_format);
private:
// Node whose kernels are being selected.
CNodePtr cnode_ptr_;
// Non-owning pointer to the caller's output list.
std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;
// NOTE(review): presumably the operator name of cnode_ptr_ -- confirm.
std::string node_name_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
......@@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg";
constexpr char BATCH_MATMUL[] = "BatchMatMul";
constexpr char EXPAND_DIMS[] = "ExpandDims";
constexpr char SQUARE[] = "Square";
// NOTE(review): BATCHMATMUL carries the same value as BATCH_MATMUL above;
// consider consolidating once all users are known.
constexpr char BATCHMATMUL[] = "BatchMatMul";
constexpr char TOPK[] = "TopK";
constexpr char IN_TOPK[] = "InTopK";
constexpr char PACK[] = "Pack";
constexpr char GATHER_ND[] = "GatherNd";
// NOTE(review): "UNSORTEF" in the two identifiers below looks like a
// misspelling of "UNSORTED"; the names are left as is because their users are
// not visible here.
constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
// Parallel don't care
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
......
......@@ -21,7 +21,6 @@
#include <vector>
#include "device/ascend/kernel_select_ascend.h"
#include "kernel/kernel_query.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
......
......@@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) {
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
......
......@@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \
.compute_cost(10) \
.kernel_name("abs_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "y", None, "required", None) \
.input(1, "dy", None, "required", None) \
.output(0, "z", False, "required", "all") \
......
......@@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
......
......@@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \
.attr("n", "required", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
......
......@@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
.input(0, "x", False, "required", "all") \
.input(1, "bias", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
......
......@@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.output(0, "sum", False, "required", "all") \
.output(1, "square_sum", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.input(5, "batch_mean", False, "required", "all") \
.input(6, "batch_variance", False, "required", "all") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
......@@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.input(3, "batch_variance", False, "required", "all") \
.output(0, "diff_scale", False, "required", "all") \
.output(1, "diff_offset", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
......@@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.output(1, "batch_mean", False, "required", "all") \
.output(2, "batch_variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
......
......@@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \
.attr("dst_type", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.op_pattern("formatAgnostic") \
.dtype_format(DataType.BOOL_None, DataType.F16_None) \
.dtype_format(DataType.BOOL_None, DataType.U8_None) \
.dtype_format(DataType.BOOL_None, DataType.F32_None) \
.dtype_format(DataType.BOOL_None, DataType.I32_None) \
.dtype_format(DataType.I8_None, DataType.F16_None) \
.dtype_format(DataType.I8_None, DataType.F32_None) \
.dtype_format(DataType.I8_None, DataType.I32_None) \
.dtype_format(DataType.U8_None, DataType.F16_None) \
.dtype_format(DataType.U8_None, DataType.F32_None) \
.dtype_format(DataType.U8_None, DataType.I32_None) \
.dtype_format(DataType.I32_None, DataType.BOOL_None) \
.dtype_format(DataType.I32_None, DataType.F16_None) \
.dtype_format(DataType.I32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I8_None) \
.dtype_format(DataType.I32_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F32_None) \
.dtype_format(DataType.F16_None, DataType.I32_None) \
.dtype_format(DataType.F32_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
......
......@@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.compute_cost(10) \
.kernel_name("conv2d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
......@@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default,
DataType.F16_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
.input(1, "mask", False, "required", "all") \
.input(2, "keep_prob", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
......
......@@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
.output(0, "y", False, "required", "all") \
.output(1, "mean", False, "required", "all") \
.output(2, "variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
.input(3, "mean", False, "required", "all") \
.output(0, "pd_gamma", False, "required", "all") \
.output(1, "pd_beta", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
.input(3, "mean", False, "required", "all") \
.input(4, "gamma", False, "required", "all") \
.output(0, "pd_x", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
......
......@@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "z", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
......
......@@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.op_pattern("reduce") \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
.kernel_name("relu_grad_v2") \
.partial_flag(True) \
.input(0, "gradients", False, "required", "all") \
.input(1, "mask", False, "rerequired", "all") \
.input(1, "mask", False, "required", "all") \
.output(0, "backprops", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
......
......@@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
......
......@@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \
.input(0, "x", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.get_op_info()
......
......@@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \
.input(1, "features", False, "required", "all") \
.output(0, "backprops", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \
.attr("output_num", "required", "int", "all") \
.input(0, "value", False, "required", "all") \
.output(0, "output", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
......
......@@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
......
......@@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
.input(0, "x", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
......
......@@ -97,6 +97,7 @@ class RegOp:
"""
if not isinstance(value, str):
raise TypeError("%s value must be str" % str(value))
return True
def _is_int(self, value):
"""
......@@ -110,6 +111,7 @@ class RegOp:
"""
if not isinstance(value, int):
raise TypeError("%s value must be int" % str(value))
return True
def _is_bool(self, value):
"""
......@@ -123,6 +125,7 @@ class RegOp:
"""
if not isinstance(value, bool):
raise TypeError("%s value must be bool" % str(value))
return True
def _check_param(self, param_list, key_list, fn_list, kwargs):
"""
......@@ -494,6 +497,7 @@ class DataType:
The current list below maybe not completed. If necessary, please add it.
"""
None_None = ("", "")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册