Commit 2005ecc2 authored by mindspore-ci-bot, committed by Gitee

!1949 Support dynamic broadcast/reduce format selection for operators

Merge pull request !1949 from jjfeing/master
......@@ -20,7 +20,7 @@
#include "kernel/aicpu/aicpu_kernel_metadata.h"
#include "kernel/rts/rt_kernel_info.h"
#include "kernel/hccl/hccl_kernel_metadata.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
......@@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) {
......@@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TbeMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
......
......@@ -126,6 +126,8 @@ class OpInfo {
bool is_ref() const { return !ref_infos_.empty(); }
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
void ClearInputs() { (void)inputs_ptr_.clear(); }
void ClearOutputs() { (void)outputs_ptr_.clear(); }
private:
std::string op_name_;
......
......@@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamic_format";
constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
......@@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
{kFormatAgnostic, kBroadcastPattern},
{kBroadcast, kBroadcastPattern},
{kReduce, kReducePattern},
{kDynamicFormat, kDynamicFormatPattern}};
op_info->set_async_flag(obj.at(kAsyncFlag));
......@@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
std::string op_pattern = obj.at(kOpPattern);
auto find_iter = kOpPatternMap.find(op_pattern);
if (find_iter == kOpPatternMap.end()) {
if (!op_pattern.empty()) {
MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
}
op_info->set_op_pattern(kCommonPattern);
} else {
op_info->set_op_pattern(find_iter->second);
}
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}
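// Illustrative sketch (not part of this change): a minimal op-info JSON
// fragment that DecodeTBESpecificInfo consumes might look like the following.
// The field names for the async/compute-cost entries are assumptions here,
// while "op_pattern" and "dynamicFormat" match the constants above.
//   {
//     "async_flag": false,
//     "compute_cost": 10,
//     "kernel_name": "add",
//     "partial_flag": true,
//     "op_pattern": "broadcast",
//     "dynamicFormat": true
//   }
// With this change, an unrecognized non-empty "op_pattern" logs a warning and
// falls back to kCommonPattern instead of being stored verbatim.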
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
......
......@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "bool"},
{TypeId::kNumberTypeBool, "int8"},
};
const std::map<std::string, std::string> type_str_maps = {
......@@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
std::string TypeIdToString(TypeId type_id) {
auto iter = type_id_str_maps.find(type_id);
if (iter == type_id_str_maps.end()) {
MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
}
return iter->second;
}
......
......@@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
} else {
// dtype : float16
auto tensor_dtype =
std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)));
MS_EXCEPTION_IF_NULL(tensor_dtype);
std::string dtype = tensor_dtype->element()->ToString();
dtype = tbe::DtypeToString(dtype);
// format
std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
} else if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
}
nlohmann::json input_desc_json;
input_desc_json["dtype"] = dtype;
input_desc_json["name"] = op_input_name + std::to_string(input_i);
auto dtype = GetDeviceInputType(anf_node, real_input_index);
auto format = GetDeviceInputFormat(anf_node, real_input_index);
auto shape = GetDeviceInputShape(anf_node, real_input_index);
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
nlohmann::json input_desc_json;
input_desc_json["dtype"] = dtype;
input_desc_json["name"] = op_input_name + std::to_string(input_i);
input_desc_json["ori_shape"] = ori_shape;
input_desc_json["ori_format"] = kOpFormat_NCHW;
auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
if (shape.empty()) {
shape.emplace_back(1);
}
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
input_desc_json["shape"] = ori_shape;
input_desc_json["format"] = kOpFormat_NCHW;
} else {
input_desc_json["shape"] = shape;
input_desc_json["format"] = format;
}
input_desc_json["shape"] = shape;
input_desc_json["format"] = format;
input_desc_json["valid"] = value;
input_desc_json["param_type"] = input_ptr->param_type();
input_list->emplace_back(input_desc_json);
......@@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
MS_EXCEPTION_IF_NULL(output_idx);
MS_EXCEPTION_IF_NULL(output_list);
for (size_t i = 0; i < output_obj_num; i++) {
nlohmann::json output_obj;
auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx)));
std::string dtype = type_ptr->element()->ToString();
dtype = tbe::DtypeToString(dtype);
std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx);
if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
} else if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
}
std::vector<size_t> ori_shape;
if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) {
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
auto shape = GetDeviceOutputShape(anf_node, *output_idx);
std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
} else {
ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
}
nlohmann::json output_obj;
output_obj["dtype"] = dtype;
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx);
if (shape.empty()) {
shape.emplace_back(1);
}
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
output_obj["shape"] = ori_shape;
output_obj["format"] = kOpFormat_NCHW;
} else {
output_obj["shape"] = shape;
output_obj["format"] = format;
}
output_obj["shape"] = shape;
output_obj["format"] = format;
output_obj["ori_shape"] = ori_shape;
output_obj["ori_format"] = kOpFormat_NCHW;
output_obj["name"] = output_ptr->name();
output_obj["valid"] = true;
output_obj["param_type"] = output_ptr->param_type();
output_list->emplace_back(output_obj);
(*output_idx)++;
}
......@@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
}
}
std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::vector<size_t> shape;
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index);
} else {
shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index);
}
if (shape.empty()) {
shape.emplace_back(1);
}
return shape;
}
std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
TypeId type_id;
if (creater_type_ == OP_SELECT_FORMAT) {
type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index);
} else {
type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
}
return tbe::TypeIdToString(type_id);
}
std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::string format = kOpFormat_NCHW;
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
format = AnfAlgo::GetInputFormat(anf_node, real_index);
if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
} else if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
}
}
return format;
}
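// Illustrative note (a sketch, not from the diff): for OP_SELECT_FORMAT and
// CHECK_SUPPORTED creators the device-side view is deliberately bypassed, so
// an input whose device format is FRAC_Z would still be reported to TBE as
// "NCHW", together with the inferred (original) shape from GetDeviceInputShape
// above rather than the device shape.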
std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::vector<size_t> shape;
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
shape = AnfAlgo::GetOutputInferShape(anf_node, real_index);
} else {
shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
}
if (shape.empty()) {
shape.emplace_back(1);
}
return shape;
}
std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
TypeId type_id;
if (creater_type_ == OP_SELECT_FORMAT) {
type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index);
} else {
type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index);
}
return tbe::TypeIdToString(type_id);
}
std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::string format = kOpFormat_NCHW;
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
format = AnfAlgo::GetOutputFormat(anf_node, real_index);
if (format == kOpFormat_FRAC_Z) {
format = kOpFormat_FRACTAL_Z;
} else if (format == kOpFormat_DEFAULT) {
format = kOpFormat_NCHW;
}
}
return format;
}
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list) {
if (input_size_list == nullptr || output_size_list == nullptr) {
......
......@@ -93,7 +93,7 @@ class TbeKernelJsonCreator {
nlohmann::json *outputs_json);
bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *attrs_json);
void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list);
......@@ -105,6 +105,13 @@ class TbeKernelJsonCreator {
void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
std::vector<nlohmann::json> *output_list);
std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const;
std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
kCreaterType creater_type_;
std::string json_name_;
std::string json_info_;
......
......@@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
task_iter->second.output_size_list, kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod);
if (set_kernel_mod) {
AnfAlgo ::SetKernelMod(kernel_mod, task_iter->second.node);
AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
}
auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
(void)task_map_.erase(task_iter);
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select.h"
#include <unordered_map>
#include <memory>
#include <map>
#include <set>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "nlohmann/json.hpp"
#include "common/utils.h"
#include "utils/context/ms_context.h"
#include "kernel/tbe/tbe_python_funcs.h"
#include "pre_activate/common/helper.h"
#include "kernel/tbe/tbe_convert_utils.h"
namespace mindspore {
namespace kernel {
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
{"NHWC", "DefaultFormat"},
{"ND", "DefaultFormat"},
{"FRACTAL_Z", "FracZ"},
{"NDHWC", "DefaultFormat"}};
static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
"MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
return true;
}
// replace kernel_info with current kernel info
auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
nlohmann::json kernel_json;
TbeKernelJsonCreator creator(CHECK_SUPPORTED);
bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
if (!ret) {
MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
return false;
}
ret = TbePythonFuncs::CheckSupported(kernel_json);
AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
return ret;
}
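// Note (descriptive, not part of the change): ops outside CHECK_SUPPORTED_OPTYPE
// are assumed supported. For the listed ops, the candidate build info is
// temporarily installed on the node so the generated kernel JSON reflects it,
// and the original build info is restored on every exit path.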
bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name,
const std::vector<std::string> &keys) {
if (!json_obj[key_name].is_object()) {
MS_LOG(DEBUG) << key_name << " is not an object!";
return false;
}
for (const auto &key : keys) {
if (json_obj[key_name].find(key) == json_obj[key_name].end()) {
MS_LOG(DEBUG) << "Key " << key << " of " << key_name << " is not found!";
return false;
}
}
return true;
}
std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) {
std::vector<std::string> result;
size_t start = 0;
size_t index = string.find(sep, start);
std::string substr;
while (index != std::string::npos) {
if (string.size() > start) {
substr = string.substr(start, index - start);
}
(void)substr.erase(0, substr.find_first_not_of(' '));
(void)substr.erase(substr.find_last_not_of(' ') + 1);
auto iter = DYNAMIC_FORMAT_MAP.find(substr);
if (iter != DYNAMIC_FORMAT_MAP.end()) {
substr = iter->second;
}
result.push_back(substr);
start = index + sep.size();
index = string.find(sep, start);
}
if (string.size() > start) {
substr = string.substr(start);
}
(void)substr.erase(0, substr.find_first_not_of(' '));
(void)substr.erase(substr.find_last_not_of(' ') + 1);
auto iter = DYNAMIC_FORMAT_MAP.find(substr);
if (iter != DYNAMIC_FORMAT_MAP.end()) {
substr = iter->second;
}
result.push_back(substr);
return result;
}
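// Usage sketch (illustrative): each comma-separated token is trimmed and then
// remapped through DYNAMIC_FORMAT_MAP, e.g.
//   SplitStr("NCHW, FRACTAL_Z, NC1HWC0", ",")
//   -> {"DefaultFormat", "FracZ", "NC1HWC0"}  // NC1HWC0 has no remap entry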
void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> &io_info) {
MS_EXCEPTION_IF_NULL(io_info);
std::vector<std::string> format_vec = SplitStr(format, ",");
std::vector<std::string> dtype_vec = SplitStr(dtype, ",");
io_info->set_formats(format_vec);
io_info->set_dtypes(dtype_vec);
}
bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
if (!json_obj.is_object()) {
MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
return false;
}
std::vector<std::string> keys = {kName, kDtype, kFormat};
for (const auto &item : json_obj.items()) {
std::string key_name = item.key();
if (key_name.empty()) {
MS_LOG(DEBUG) << "Key name is empty!";
return false;
}
if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
return false;
}
if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(input);
input->set_name(json_obj[key_name].at(kName));
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
inputs->emplace_back(input);
} else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(output);
output->set_name(json_obj[key_name].at(kName));
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
outputs->emplace_back(output);
} else {
MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
return false;
}
}
return true;
}
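// Example of a response this parser accepts (illustrative; the concrete names
// and dtype/format values are assumptions):
//   {
//     "input0":  {"name": "x", "dtype": "float16,float", "format": "NC1HWC0,NCHW"},
//     "output0": {"name": "y", "dtype": "float16,float", "format": "NC1HWC0,NCHW"}
//   }
// Keys must start with "input" or "output"; "dtype" and "format" are
// comma-separated lists that ConvertFormatDtype splits into parallel vectors.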
std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) {
nlohmann::json kernel_json;
std::string res_json_str;
TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
if (!ret) {
MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
return res_json_str;
}
res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
return res_json_str;
}
void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
std::vector<TypeId> inputs_type;
std::vector<std::string> inputs_format;
std::vector<int> dyn_input_sizes;
size_t dyn_input_idx = 0;
size_t kernel_info_index = 0;
size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
}
for (size_t i = 0; i < inputs.size(); i++) {
MS_EXCEPTION_IF_NULL(inputs[i]);
std::string param_type = inputs[i]->param_type();
if (i >= real_input_num) {
MS_LOG(INFO) << "Input index: " << i << " is out of real_input_num:" << real_input_num;
continue;
}
auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
auto format = kOpFormat_DEFAULT;
if (param_type == "dynamic") {
if (!dyn_input_sizes.empty()) {
for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
kernel_info_index++;
inputs_type.emplace_back(type_id);
inputs_format.emplace_back(format);
}
dyn_input_idx++;
}
} else if (param_type == "required") {
kernel_info_index++;
inputs_type.emplace_back(type_id);
inputs_format.emplace_back(format);
} else {
if (kernel_info_index < real_input_num) {
MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
kernel_info_index++;
inputs_type.emplace_back(type_id);
inputs_format.emplace_back(format);
}
}
}
builder->SetInputsDeviceType(inputs_type);
builder->SetInputsFormat(inputs_format);
}
void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
std::vector<TypeId> outputs_type;
std::vector<std::string> outputs_format;
auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
size_t output_idx = 0;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
if (output_idx >= real_output_num) {
continue;
}
size_t output_num = 0;
if (output->param_type() == "dynamic") {
if (outputs.size() > 1) {
MS_EXCEPTION(ArgumentError) << "Dynamic output does not support multiple outputs!";
}
output_num = real_output_num;
} else if (output->param_type() == "required") {
output_num = 1;
} else {
if (output_idx < real_output_num) {
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
output_num = 1;
}
}
for (size_t i = 0; i < output_num; i++) {
auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
outputs_type.emplace_back(type_id);
outputs_format.emplace_back(kOpFormat_DEFAULT);
output_idx++;
}
}
builder->SetOutputsDeviceType(outputs_type);
builder->SetOutputsFormat(outputs_format);
}
void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node,
const std::vector<std::shared_ptr<OpIOInfo>> &inputs,
const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
builder_tmp->SetKernelType(TBE_KERNEL);
SetTidyInputsInfo(anf_node, builder_tmp, inputs);
SetTidyOutputsInfo(anf_node, builder_tmp, outputs);
AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get());
}
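// Note (descriptive, not part of the change): the "tidy" build info uses only
// kOpFormat_DEFAULT and the inferred dtypes, giving op_select_format a neutral
// starting point; ReplaceByDynamicFormatDtype below saves the node's real
// build info before generating it and restores that info afterwards.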
void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
const std::shared_ptr<OpInfo> &op_info_new_ptr) {
std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr();
std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr();
std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn;
std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn;
if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<AnfNodePtr>()))) {
// 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format
auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>();
auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static);
// 2.get dynamic format from op_impl
std::string res_json_str;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->execution_mode() != kPynativeMode) {
res_json_str = OpSelectFormat(kernel_node);
}
if (!res_json_str.empty()) {
(void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn);
}
if (inputs_static.size() != inputs_dyn.size()) {
inputs_dyn.clear();
}
if (outputs_static.size() != outputs_dyn.size()) {
outputs_dyn.clear();
}
// 3. resume kernel node's SelectKernelBuildInfo
// As it has been replaced by GenTidyKernelBuildInfo in order to call python func
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get());
}
// 4.replace by dynamic format and dtype
if (inputs_dyn.empty() && outputs_dyn.empty()) {
MS_LOG(INFO) << "Dynamic select format response is empty, use static register info.";
op_info_new_ptr->set_inputs_ptr(inputs_static);
op_info_new_ptr->set_outputs_ptr(outputs_static);
} else {
MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
for (size_t i = 0; i < inputs_static.size(); i++) {
inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
}
for (size_t j = 0; j < outputs_static.size(); j++) {
outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
}
op_info_new_ptr->set_inputs_ptr(inputs_dyn);
op_info_new_ptr->set_outputs_ptr(outputs_dyn);
}
// 5.copy other opinfo to new op_info_new
op_info_new_ptr->set_op_name(op_info_ptr->op_name());
op_info_new_ptr->set_imply_type(op_info_ptr->imply_type());
op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
}
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
for (const auto &c : reshape_type_str) {
switch (c) {
case 'N':
reshape_type_vec->push_back(kernel::N);
break;
case 'C':
reshape_type_vec->push_back(kernel::C);
break;
case 'H':
reshape_type_vec->push_back(kernel::H);
break;
case 'W':
reshape_type_vec->push_back(kernel::W);
break;
default:
MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type.";
return false;
}
}
return true;
}
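// E.g. (illustrative) StringToAxisVector("NC", &v) yields v == {kernel::N,
// kernel::C}; any character outside "NCHW" logs an error and returns false.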
bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
size_t builder_idex, const std::vector<int> &dyn_input_sizes,
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
MS_EXCEPTION_IF_NULL(builder);
std::vector<TypeId> inputs_device_type;
std::vector<std::string> inputs_format;
size_t dyn_input_idx = 0;
size_t kernel_info_index = 0;
MS_EXCEPTION_IF_NULL(inputs[0]);
size_t kernel_info_cnt = inputs[0]->dtypes().size();
std::vector<std::vector<Axis>> reshape_types;
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
std::string param_type = input->param_type();
std::vector<std::string> dtypes = input->dtypes();
std::vector<std::string> formats = input->formats();
if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
return false;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
if (param_type == "dynamic") {
if (dyn_input_sizes.empty()) {
MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
return false;
}
for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
kernel_info_index++;
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
dyn_input_idx++;
} else if (param_type == "required") {
kernel_info_index++;
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
} else {
if (kernel_info_index < real_input_num) {
MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
kernel_info_index++;
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
}
}
builder->SetInputReshapeType(reshape_types);
builder->SetInputsDeviceType(inputs_device_type);
builder->SetInputsFormat(inputs_format);
return true;
}
bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
const size_t &real_output_num,
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
// not now but in the next we need to support dynamic output case
MS_EXCEPTION_IF_NULL(builder);
size_t output_idx = 0;
std::vector<TypeId> outputs_device_type;
std::vector<std::string> outputs_format;
MS_EXCEPTION_IF_NULL(outputs[0]);
size_t kernel_info_cnt = outputs[0]->dtypes().size();
std::vector<std::vector<Axis>> reshape_types;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
if (output_idx >= real_output_num) {
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
continue;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
size_t output_num = 0;
if (output->param_type() == "dynamic") {
if (outputs.size() > 1) {
MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!";
}
output_num = real_output_num;
} else if (output->param_type() == "required") {
output_num = 1;
} else {
if (output_idx < real_output_num) {
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx;
output_num = 1;
}
}
for (size_t i = 0; i < output_num; i++) {
std::vector<std::string> dtypes = output->dtypes();
std::vector<std::string> formats = output->formats();
if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size.";
return false;
}
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
outputs_device_type.push_back(type_id);
outputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
output_idx++;
}
}
builder->SetOutputReshapeType(reshape_types);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_device_type);
return true;
}
void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
Processor processor, const std::shared_ptr<const OpInfo> &op_info_ptr) {
MS_EXCEPTION_IF_NULL(builder);
MS_EXCEPTION_IF_NULL(op_info_ptr);
builder->SetProcessor(processor);
std::string fusion_type = op_info_ptr->fusion_type();
if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
builder->SetFusionType(tbe::GetFusionType(fusion_type));
}
builder->SetOpPattern(op_info_ptr->op_pattern());
builder->SetKernelType(TBE_KERNEL);
}
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
}
if (!inputs.empty()) {
MS_EXCEPTION_IF_NULL(inputs[0]);
size_t kernel_info_cnt = inputs[0]->dtypes().size();
for (size_t j = 0; j < kernel_info_cnt; j++) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
return false;
}
if (!outputs.empty()) {
if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
return false;
}
}
kernel_info_list->push_back(builder->Build());
}
} else if (!outputs.empty()) {
MS_EXCEPTION_IF_NULL(outputs[0]);
size_t kernel_info_cnt = outputs[0]->dtypes().size();
for (size_t j = 0; j < kernel_info_cnt; j++) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
return false;
}
kernel_info_list->push_back(builder->Build());
}
}
return true;
}
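// Note (descriptive): ParseMetadata emits one KernelBuildInfo per registered
// dtype/format column, i.e. kernel_info_cnt == inputs[0]->dtypes().size(), so
// a registration with four dtype entries per input yields four candidate infos.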
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
// if the format is DefaultFormat, it means all formats are supported
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "Got unknown format " << format;
}
if (format == kOpFormat_DEFAULT) {
return true;
}
if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
return false;
}
// if the shape is empty, the tensor is a scalar
if (shape.empty()) {
return true;
}
if (shape.size() > kShape4dDims) {
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
return false;
}
return true;
}
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
const size_t kCAxis = 1;
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
return false;
}
return false;
}
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
return false;
}
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
}
return true;
}
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
if (op_info_ptr == nullptr) {
return;
}
// dynamic get op format and dtype and replace opinfo
auto op_info_new_ptr = std::make_shared<OpInfo>();
ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr);
if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) {
MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed.";
return;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
for (const auto &parse_info : parse_info_list) {
if (IsValidKernelInfo(kernel_node, *(parse_info))) {
if (CheckSupported(kernel_node, parse_info)) {
kernel_info_list->push_back(parse_info);
} else {
MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
}
}
if (kernel_info_list->empty()) {
MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
}
}
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -13,20 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
struct SupportFormat {
std::vector<std::vector<std::string>> input_format;
std::vector<std::vector<std::string>> output_format;
};
using SupportFormatItem = std::vector<std::string>;
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
constexpr char kDynInputKey[] = "dyn_input_sizes";
constexpr size_t kInputIndex_0 = 0;
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
constexpr size_t kAlignmented16 = 16;
// Format-selection cases for broadcast nodes:
// 1. all inputs are non-scalar and share the same shape
// 2. some inputs are scalar: the non-scalar inputs must meet the per-format
//    shape-size and 16-alignment requirements
// 3. all inputs are non-scalar but shapes differ (broadcast on some dims)
bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
MS_EXCEPTION_IF_NULL(support_format);
input_num_ = 0;
output_num_ = 0;
input_shapes_.clear();
output_shapes_.clear();
if (AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
MS_LOG(INFO) << "This broadcast node has dynamic input.";
auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
}
auto dynamic_input_shape0 = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
PadScalarShape(&dynamic_input_shape0);
input_shapes_.emplace_back(dynamic_input_shape0);
input_num_ = 1;
} else {
input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
for (size_t i = 0; i < input_num_; ++i) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
PadScalarShape(&input_shape);
input_shapes_.emplace_back(input_shape);
}
}
output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
for (size_t i = 0; i < output_num_; ++i) {
auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
PadScalarShape(&output);
output_shapes_.emplace_back(output);
}
AssignSupportFormat(kOpFormat_DEFAULT, support_format);
return true;
}
bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (IsSameShape()) {
if (!HasScalarInput()) {
AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
return true;
} else {
return false;
}
}
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (!Is4DShape(shape)) {
return false;
}
if (shape[kChannelC] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_NC1HWC0);
}
}
} else {
for (const auto &shape : input_shapes_) {
if (!Is4DShape(shape)) {
return false;
}
}
auto shape_tmp = input_shapes_[0];
auto broadcast_c_axis = std::any_of(
input_shapes_.begin(), input_shapes_.end(),
[&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
if (broadcast_c_axis) {
MS_LOG(INFO) << "This node broadcast c channel.";
return false;
}
input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
}
GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
return true;
}
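// Worked example (illustrative shapes): with inputs {1} (scalar, padded) and
// {2, 32, 4, 4}, the scalar keeps DefaultFormat while the 4D input qualifies
// for NC1HWC0 because its C dim (32) is 16-aligned; an input of {2, 20, 4, 4}
// would be rejected since 20 % 16 != 0.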
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (IsSameShape()) {
if (!HasScalarInput()) {
AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
return true;
} else {
return false;
}
}
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (!Is4DShape(shape)) {
return false;
}
if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_FRAC_Z);
}
}
} else {
return false;
}
GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
return true;
}
bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (IsSameShape()) {
if (!HasScalarInput()) {
AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
return true;
} else {
return false;
}
}
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (!Is4DShape(shape)) {
return false;
}
if (shape[kChannelN] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
}
}
} else {
for (const auto &shape : input_shapes_) {
if (!Is4DShape(shape)) {
return false;
}
}
auto shape_tmp = input_shapes_[0];
auto broadcast_nc_axis =
std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN));
});
if (broadcast_nc_axis) {
MS_LOG(INFO) << "This node broadcast n || c channel.";
return false;
}
input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
}
GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
return true;
}
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (IsSameShape()) {
if (!HasScalarInput()) {
AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
return true;
} else {
return false;
}
}
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
if (HasScalarInput()) {
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
input_support_format.emplace_back(kOpFormat_DEFAULT);
} else {
if (shape.size() < kShape2dDims) {
return false;
}
if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) {
return false;
}
input_support_format.emplace_back(kOpFormat_FRAC_NZ);
}
}
} else {
auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
[](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; });
if (less_2dims) {
MS_LOG(INFO) << "This node dim less 2.";
return false;
}
auto shape_tmp = input_shapes_[0];
auto broadcast_last_dim =
std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) ||
(shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2));
});
if (broadcast_last_dim) {
MS_LOG(INFO) << "This node broadcast last channel.";
return false;
}
input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
}
GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
return true;
}
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
return false;
}
bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
return shape.size() == kShape4dDims;
}
bool TbeKernelBroadCastSelecter::IsSameShape() const {
auto shape = input_shapes_.begin();
for (const auto &item : input_shapes_) {
if (shape->size() != item.size()) {
return false;
}
for (size_t i = 0; i < shape->size(); ++i) {
if (shape->at(i) != item.at(i)) {
return false;
}
}
}
return true;
}
void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const {
MS_EXCEPTION_IF_NULL(shape);
if (shape->empty()) {
shape->emplace_back(1);
}
}
bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const {
return (shape.size() == 1 && shape[0] == 1);
}
bool TbeKernelBroadCastSelecter::HasScalarInput() const {
bool ret = false;
for (const auto &shape : input_shapes_) {
if (IsScalarShape(shape)) {
ret = true;
break;
}
}
return ret;
}
void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format,
SupportFormatItem *output_support_item) const {
MS_EXCEPTION_IF_NULL(output_support_item);
for (const auto &shape : output_shapes_) {
if (IsScalarShape(shape)) {
output_support_item->emplace_back(kOpFormat_DEFAULT);
} else {
output_support_item->emplace_back(support_format);
}
}
}
void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str,
mindspore::kernel::SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
input_support_format.assign(input_num_, support_format_str);
output_support_format.assign(output_num_, support_format_str);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
}
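// Usage sketch (illustrative, assuming `node` is the CNodePtr of a broadcast op):
//   TbeKernelBroadCastSelecter selecter(node);
//   SupportFormat support_format;
//   (void)selecter.GetShapeInfo(&support_format);  // always adds DefaultFormat
//   (void)selecter.IsBroadCastSupport5HD(&support_format);
//   (void)selecter.IsBroadCastSupportFracZ(&support_format);
//   // each successful check appends one input/output format combination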
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
class TbeKernelBroadCastSelecter {
public:
explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
~TbeKernelBroadCastSelecter() = default;
bool GetShapeInfo(SupportFormat *support_format);
bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;
private:
bool IsSameShape() const;
void PadScalarShape(std::vector<size_t> *shape) const;
bool Is4DShape(const std::vector<size_t> &shape) const;
bool IsScalarShape(const std::vector<size_t> &shape) const;
bool HasScalarInput() const;
void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
// broadcast
CNodePtr cnode_ptr_;
size_t input_num_{};
size_t output_num_{};
std::vector<std::vector<size_t>> input_shapes_;
std::vector<std::vector<size_t>> output_shapes_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include <string>
#include <vector>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
constexpr char kKeepDims[] = "keep_dims";
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
constexpr size_t kInputIndex_0 = 0;
constexpr size_t kOutputIndex_0 = 0;
constexpr size_t kChannelN = 0;
constexpr size_t kChannelC = 1;
constexpr size_t kReduceNZMinDim = 3;
bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) {
MS_EXCEPTION_IF_NULL(support_format);
input_shape_.clear();
output_shape_.clear();
axis_.clear();
auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
if (input_num != 1 || output_num != 1) {
MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num
<< ", output num: " << output_num;
}
// get input/output shape
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
PadScalarShape(&input_shape_);
output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0);
PadScalarShape(&output_shape_);
// get keep dim attr
GetReduceAttrKeepDim();
// get axis attr
GetReduceAttrAxis();
AssignSupportFormat(kOpFormat_DEFAULT, support_format);
return true;
}
bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (!Is4DShape(input_shape_)) {
return false;
}
if (!keep_dims_ || axis_.empty()) {
return false;
}
auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); });
if (reduce_c_axis) {
return false;
}
AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
return true;
}
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
// similar to 5HD, but NDC1HWC0 is not supported yet
return false;
}
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format);
}
bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format);
}
bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (input_shape_.size() < kReduceNZMinDim) {
return false;
}
if (axis_.empty()) {
return false;
}
auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) {
return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2));
});
if (reduce_last_axis) {
return false;
}
AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
return true;
}
bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format,
mindspore::kernel::SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
if (!Is4DShape(input_shape_)) {
return false;
}
if (!keep_dims_ || axis_.empty()) {
return false;
}
auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(),
[](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); });
if (reduce_n_c_axis) {
return false;
}
AssignSupportFormat(format, support_format);
return true;
}
void TbeKernelReduceSelecter::GetReduceAttrAxis() {
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
MS_EXCEPTION_IF_NULL(primitive);
auto axis = primitive->GetAttr(kAxis);
if (axis == nullptr) {
MS_LOG(INFO) << "This node does't have axie attr.";
return;
}
auto type = axis->type();
MS_EXCEPTION_IF_NULL(type);
std::vector<int> axis_list;
if (type->ToString() == kTypeInt32) {
axis_list.emplace_back(GetValue<int>(axis));
} else {
axis_list = GetValue<std::vector<int>>(axis);
}
for (const auto &elem : axis_list) {
if (elem < 0) {
axis_.emplace_back(input_shape_.size() + elem);
} else {
axis_.emplace_back(IntToSize(elem));
}
}
}
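// E.g. (illustrative) for input shape {8, 16, 4, 4}, an axis attr of -1 is
// normalized to 3, while an axis of 1 (the C channel) makes
// IsReduceSupport5HD reject the NC1HWC0 format.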
void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) {
MS_LOG(INFO) << "This node does't have keep_attr.";
keep_dims_ = false;
return;
}
keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims);
}
void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
mindspore::kernel::SupportFormat *support_format) const {
MS_EXCEPTION_IF_NULL(support_format);
SupportFormatItem input_support_format;
SupportFormatItem output_support_format;
input_support_format.emplace_back(support_format_str);
output_support_format.emplace_back(support_format_str);
support_format->input_format.emplace_back(input_support_format);
support_format->output_format.emplace_back(output_support_format);
}
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; }
void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
MS_EXCEPTION_IF_NULL(shape);
if (shape->empty()) {
shape->emplace_back(1);
}
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
class TbeKernelReduceSelecter {
public:
explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
~TbeKernelReduceSelecter() = default;
bool GetShapeInfo(SupportFormat *support_format);
bool IsReduceSupport5HD(SupportFormat *support_format) const;
bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
bool IsReduceSupportFracZ(SupportFormat *support_format) const;
bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
bool IsReduceSupportFracNZ(SupportFormat *support_format) const;
private:
bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
void GetReduceAttrAxis();
void GetReduceAttrKeepDim();
void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
bool Is4DShape(const std::vector<size_t> &shape) const;
void PadScalarShape(std::vector<size_t> *shape) const;
CNodePtr cnode_ptr_;
std::vector<size_t> input_shape_{};
std::vector<size_t> output_shape_{};
std::vector<size_t> axis_{};
bool keep_dims_ = false;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include <memory>
#include <map>
#include <set>
#include <utility>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "nlohmann/json.hpp"
#include "utils/context/ms_context.h"
#include "kernel/tbe/tbe_python_funcs.h"
#include "pre_activate/common/helper.h"
#include "kernel/tbe/tbe_convert_utils.h"
#include "parallel/ops_info/ops_utils.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
constexpr char kDynInputKey[] = "dyn_input_sizes";
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequired[] = "required";
constexpr char kParamTypeOptional[] = "optional";
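// Entry point: collect all candidate TBE kernel build infos for the given node.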
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
tbe_selecter.TbeMetadataInfoEx();
}
TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
: cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
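// Look up the node's op info and dispatch to the pattern-specific kernel info generator, then filter the result.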
void TbeKernelSelect::TbeMetadataInfoEx() {
MS_EXCEPTION_IF_NULL(cnode_ptr_);
MS_EXCEPTION_IF_NULL(kernel_info_list_);
node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
if (!op_info_ptr) {
MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_;
return;
}
MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_
<< ", node name: " << cnode_ptr_->fullname_with_scope();
OpPattern pattern = op_info_ptr->op_pattern();
if (pattern == kCommonPattern) {
GetCommonPatternKernelInfo(*op_info_ptr);
} else if (pattern == kDynamicFormatPattern) {
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
} else if (pattern == kFormatAgnosticPattern) {
GetAgnosticPatternKernelInfo(*op_info_ptr);
} else if (pattern == kBroadcastPattern) {
GetBroadcastPatternKernelInfo(*op_info_ptr);
} else if (pattern == kReducePattern) {
GetReducePatternKernelInfo(*op_info_ptr);
} else {
MS_LOG(INFO) << "Warning: op pattern is invailed.";
}
// check support
  FilterInvalidKernelInfo();
  MS_LOG(INFO) << "End tbe metadata info, kernel build info size after tbe select: " << kernel_info_list_->size();
}
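// Build one KernelBuildInfo per dtype/format column registered in the op info.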
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
// get dynamic inputs
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
MS_EXCEPTION_IF_NULL(primitive);
std::vector<int> dyn_input_sizes;
if (primitive->HasAttr(kDynInputKey)) {
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kDynInputKey));
}
// get real input/output num
size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
const auto inputs_info = op_info.inputs_ptr();
size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
const auto outputs_info = op_info.outputs_ptr();
if (inputs_info.empty() && outputs_info.empty()) {
MS_LOG(EXCEPTION) << "op info input & output is null, please check.";
}
// create kernel build info from opinfo
size_t kernel_build_info_num =
inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
SetTbeBuildCommonInfo(op_info, &builder);
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_device_type;
std::vector<std::vector<Axis>> inputs_reshape_type;
// input
if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
&inputs_format, &inputs_device_type, &inputs_reshape_type)) {
break;
}
builder.SetInputsDeviceType(inputs_device_type);
builder.SetInputsFormat(inputs_format);
builder.SetInputReshapeType(inputs_reshape_type);
// output
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type;
std::vector<std::vector<Axis>> outputs_reshape_type;
if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
&outputs_format, &outputs_device_type, &outputs_reshape_type)) {
break;
}
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(INFO) << "end.";
}
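// Dynamic format pattern: rebuild the op info from TBE's op_select_format result, then reuse the common path.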
void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
  // Ask TBE which formats the op actually supports, rebuild the op info accordingly, then reuse the common path.
OpInfo op_info_new;
CreateNewOpInfo(op_info, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}
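// Format agnostic pattern: propagate the format of the first input to every input and output.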
void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
if (op_info.inputs_ptr().size() != 1) {
MS_LOG(EXCEPTION) << "AgnosticPattern only support one input.";
}
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(INFO) << "Got the unknown format " << format;
format = kOpFormat_DEFAULT;
}
SupportFormat support_format;
SupportFormatItem input_item;
SupportFormatItem output_item;
input_item.assign(op_info.inputs_ptr().size(), format);
output_item.assign(op_info.outputs_ptr().size(), format);
support_format.input_format.emplace_back(input_item);
support_format.output_format.emplace_back(output_item);
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}
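// Broadcast pattern: probe which special formats the broadcast supports and extend the op info with them.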
void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
SupportFormat support_format;
broadcast_selecter.GetShapeInfo(&support_format);
if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD.";
}
if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ.";
}
if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0.";
}
if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
}
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}
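// Reduce pattern: probe which special formats the reduce supports and extend the op info with them.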
void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "start.";
auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
SupportFormat support_format;
reduce_selecter.GetShapeInfo(&support_format);
  if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support 5HD.";
  }
  if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracZ.";
  }
  if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support C1HWNCoC0.";
  }
  if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracNZ.";
  }
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
GetCommonPatternKernelInfo(op_info_new);
MS_LOG(INFO) << "end.";
}
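// Erase candidates whose shapes do not match their formats or that fail the TBE check_supported query.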
void TbeKernelSelect::FilterInvalidKernelInfo() {
if (kernel_info_list_->empty()) {
MS_LOG(INFO) << "Warning: get kernel build info failed.";
return;
}
auto kernel_build_info_iter = kernel_info_list_->begin();
while (kernel_build_info_iter != kernel_info_list_->end()) {
    if (!FilterInvalidShape(kernel_build_info_iter)) {
      MS_LOG(INFO) << "Filter invalid shape, filter item info: " << (*kernel_build_info_iter)->ToString();
kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
continue;
}
if (!TbeCheckSupported(kernel_build_info_iter)) {
MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString();
kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
continue;
}
kernel_build_info_iter++;
}
}
bool TbeKernelSelect::FilterInvalidShape(
  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
auto format = kernel_build_info_inputs_format.at(i);
if (!IsShapeMatchFormat(shape, format)) {
MS_LOG(INFO) << "The " << i << "th input check failed.";
return false;
}
}
auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
auto format = kernel_build_info_outputs_format.at(j);
if (!IsShapeMatchFormat(shape, format)) {
MS_LOG(INFO) << "The " << j << "th input check failed.";
return false;
}
}
return true;
}
bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // the default format means all formats are supported
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  static const std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  }
// server not support format with C04 suffix
if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
kServerNotSupportFormat.end()) {
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
return false;
}
// not support format:
// 1 NDHWC with shape size != 5
// 2 FRAC_NZ with shape size < 2
// 3 !NDHWC with shape size > 4
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
(format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
return false;
}
return true;
}
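// For a whitelist of op types, temporarily select the candidate build info, generate the kernel json,
// and ask the TBE python frontend whether the kernel is supported; the original build info is restored afterwards.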
bool TbeKernelSelect::TbeCheckSupported(
const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
parallel::BATCHMATMUL,
parallel::TOPK,
parallel::IN_TOPK,
parallel::PACK,
parallel::GATHER_ND,
parallel::UNSORTEF_SEGMENT_MIND,
parallel::UNSORTEF_SEGMENT_PRODD,
parallel::CAST};
auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_);
if (iter == kCheckSupportedOpType.end()) {
return true;
}
MS_LOG(INFO) << "Check support start.";
// replace kernel_info with current kernel info
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
nlohmann::json kernel_json;
TbeKernelJsonCreator creator(CHECK_SUPPORTED);
bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
if (!ret) {
MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
}
ret = TbePythonFuncs::CheckSupported(kernel_json);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
return ret;
}
void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
MS_EXCEPTION_IF_NULL(builder);
builder->SetProcessor(AICORE);
std::string fusion_type = op_info.fusion_type();
if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
builder->SetFusionType(tbe::GetFusionType(fusion_type));
}
builder->SetOpPattern(op_info.op_pattern());
builder->SetKernelType(TBE_KERNEL);
}
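// Expand one dtype/format column of the op io info into per-tensor format/dtype/reshape-type lists,
// honoring dynamic, required and optional parameter types.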
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
MS_EXCEPTION_IF_NULL(formats);
MS_EXCEPTION_IF_NULL(device_types);
MS_EXCEPTION_IF_NULL(reshape_types);
size_t dynamic_input_index = 0;
size_t real_io_tensor_index = 0;
size_t io_info_index = 0;
size_t io_info_num = ios_info.size();
for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
std::string kernel_build_info_format;
if (!io_info_item->formats().empty()) {
kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
}
std::string io_param_type = io_info_item->param_type();
std::vector<Axis> reshape_type;
StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
if (io_param_type == kParamTypeDynamic) {
// dynamic io
if (is_input) {
if (dynamic_input_index >= dyn_input_sizes.size()) {
MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
<< ", dyn_input_sizes size: " << dyn_input_sizes.size();
}
int dynamic_input_size = dyn_input_sizes[dynamic_input_index];
for (int i = 0; i < dynamic_input_size; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
}
dynamic_input_index++;
real_io_tensor_index += dynamic_input_size;
} else {
if (ios_info.size() != 1) {
MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
}
for (size_t i = 0; i < real_io_tensor_num; ++i) {
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
}
real_io_tensor_index += real_io_tensor_num;
}
    } else if (io_param_type == kParamTypeRequired || io_param_type == kParamTypeOptional) {
      // required or optional io
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
formats->emplace_back(kernel_build_info_format);
reshape_types->emplace_back(reshape_type);
real_io_tensor_index++;
} else {
MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
}
}
if (io_info_index != io_info_num) {
MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
<< "), this node may has optional input/output.";
}
if (real_io_tensor_index != real_io_tensor_num) {
std::string io_type = is_input ? "inputs " : "outputs";
    MS_LOG(INFO) << node_name_ << "'s " << io_type << " io info num: " << io_info_num
                 << ", real io tensor num: " << real_io_tensor_num << ", real_io_tensor_index("
                 << real_io_tensor_index << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
return false;
}
return true;
}
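// Convert a reshape type string such as "NC" into the corresponding vector of Axis enums.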
void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
MS_EXCEPTION_IF_NULL(reshape_type_vec);
for (const auto &c : reshape_type_str) {
switch (c) {
case 'N':
reshape_type_vec->push_back(kernel::N);
break;
case 'C':
reshape_type_vec->push_back(kernel::C);
break;
case 'H':
reshape_type_vec->push_back(kernel::H);
break;
case 'W':
reshape_type_vec->push_back(kernel::W);
break;
default:
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
}
}
}
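// Clone an OpIOInfo, repeating its dtypes once per supported format combination and flattening
// column `index` of the support format matrix into the new format list.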
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,
mindspore::kernel::OpIOInfo *op_io_info_new) {
MS_EXCEPTION_IF_NULL(op_io_info_new);
op_io_info_new->set_index(op_io_info.index());
op_io_info_new->set_name(op_io_info.name());
op_io_info_new->set_param_type(op_io_info.param_type());
op_io_info_new->set_need_compile(op_io_info.need_compile());
op_io_info_new->set_reshape_type(op_io_info.reshape_type());
op_io_info_new->set_shape(op_io_info.shape());
// dtype
std::vector<std::string> dtype_new;
auto dtype = op_io_info.dtypes();
for (size_t i = 0; i < support_format_item.size(); ++i) {
dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end());
}
op_io_info_new->set_dtypes(dtype_new);
// format
std::vector<std::string> format_new;
for (const auto &formats : support_format_item) {
auto format = formats.at(index);
for (size_t j = 0; j < dtype.size(); ++j) {
format_new.emplace_back(format);
}
}
op_io_info_new->set_formats(format_new);
}
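// Split a comma-separated op_select_format item into tokens, mapping TBE format names to MindSpore ones.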
std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
const std::map<std::string, std::string> kDynamicFormatMap = {
{"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}};
if (op_select_json_item.empty()) {
MS_LOG(EXCEPTION) << "Op select ret item is null.";
}
const char space = ' ';
const char sep = ',';
std::string op_select_tmp = op_select_json_item + ",";
std::vector<std::string> ret;
auto begin = op_select_tmp.find_first_not_of(space, 0);
auto sep_pos = op_select_tmp.find(sep);
while (sep_pos != std::string::npos) {
auto obj = op_select_tmp.substr(begin, sep_pos - begin);
if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
obj = kDynamicFormatMap.at(obj);
}
ret.emplace_back(obj);
begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
sep_pos = op_select_tmp.find(sep, begin);
}
return ret;
}
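// Generate the kernel json for this node and call the TBE op_select_format python function,
// returning its raw json string.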
std::string TbeKernelSelect::OpSelectFormat() {
nlohmann::json kernel_json;
std::string res_json_str;
TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
if (!ret) {
MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
}
res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
if (res_json_str.empty()) {
MS_LOG(EXCEPTION) << "op select format error.";
}
MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
return res_json_str;
}
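// Rebuild the op info so that every supported format combination contributes an extra dtype/format column.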
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
mindspore::kernel::OpInfo *op_info_new) {
MS_EXCEPTION_IF_NULL(op_info_new);
if (op_info.inputs_ptr().size() != support_format.input_format[0].size() ||
op_info.outputs_ptr().size() != support_format.output_format[0].size()) {
MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size()
<< ", input support size: " << support_format.input_format[0].size()
<< ", op info output size: " << op_info.outputs_ptr().size()
<< ", output support size: " << support_format.output_format[0].size();
}
*op_info_new = op_info;
op_info_new->ClearInputs();
op_info_new->ClearOutputs();
for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
auto input = op_info.inputs_ptr().at(i);
auto input_new = std::make_shared<OpIOInfo>();
CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get());
op_info_new->add_inputs_ptr(input_new);
}
for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) {
auto output = op_info.outputs_ptr().at(j);
auto output_new = std::make_shared<OpIOInfo>();
CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get());
op_info_new->add_outputs_ptr(output_new);
}
}
struct SelectOpIOInfo {
std::string name;
std::vector<std::string> dtypes;
std::vector<std::string> formats;
};
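// Dynamic format pattern: parse the op_select_format json and rebuild the op info from the
// dtypes and formats it returns.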
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
mindspore::kernel::OpInfo *op_info_new) {
MS_EXCEPTION_IF_NULL(op_info_new);
  auto op_select_json = OpSelectFormat();
  if (!op_select_json.empty()) {
    nlohmann::json json_obj = nlohmann::json::parse(op_select_json);
    if (!json_obj.is_object()) {
      MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_select_json;
}
std::vector<SelectOpIOInfo> inputs;
std::vector<SelectOpIOInfo> outputs;
for (const auto &item : json_obj.items()) {
const std::string &item_name = item.key();
bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
if (!is_input && !is_output) {
MS_LOG(EXCEPTION) << "op select ret json is error.";
}
if (is_input) {
SelectOpIOInfo select_input;
select_input.name = item.value().at(kName);
std::string input_dtype_item = item.value().at(kDtype);
select_input.dtypes = SplitStrToVec(input_dtype_item);
std::string input_format_item = item.value().at(kFormat);
select_input.formats = SplitStrToVec(input_format_item);
inputs.emplace_back(select_input);
} else if (is_output) {
SelectOpIOInfo select_output;
select_output.name = item.value().at(kName);
        std::string output_dtype_item = item.value().at(kDtype);
        select_output.dtypes = SplitStrToVec(output_dtype_item);
        std::string output_format_item = item.value().at(kFormat);
        select_output.formats = SplitStrToVec(output_format_item);
outputs.emplace_back(select_output);
}
}
if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
}
*op_info_new = op_info;
op_info_new->ClearInputs();
op_info_new->ClearOutputs();
for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
auto input_new = std::make_shared<OpIOInfo>();
CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
op_info_new->add_inputs_ptr(input_new);
}
for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
auto output_new = std::make_shared<OpIOInfo>();
CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get());
op_info_new->add_outputs_ptr(output_new);
}
}
}
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
const std::vector<std::string> &support_dtype,
const std::vector<std::string> &support_format,
mindspore::kernel::OpIOInfo *op_io_info_new) {
MS_EXCEPTION_IF_NULL(op_io_info_new);
op_io_info_new->set_index(op_io_info.index());
op_io_info_new->set_name(op_io_info.name());
op_io_info_new->set_param_type(op_io_info.param_type());
op_io_info_new->set_need_compile(op_io_info.need_compile());
op_io_info_new->set_reshape_type(op_io_info.reshape_type());
op_io_info_new->set_shape(op_io_info.shape());
// dtype
std::vector<std::string> dtype_new;
for (size_t i = 0; i < support_format.size(); ++i) {
dtype_new.insert(dtype_new.end(), support_dtype.begin(), support_dtype.end());
}
op_io_info_new->set_dtypes(dtype_new);
// format
std::vector<std::string> format_new;
for (const auto &format : support_format) {
for (size_t j = 0; j < support_dtype.size(); ++j) {
format_new.emplace_back(format);
}
}
op_io_info_new->set_formats(format_new);
}
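// Log every supported input --> output format combination for debugging.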
void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
if (support_format.input_format.size() != support_format.output_format.size()) {
MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
<< support_format.output_format.size() << ") size not match.";
}
for (size_t i = 0; i < support_format.input_format.size(); ++i) {
auto input_items = support_format.input_format.at(i);
auto output_items = support_format.output_format.at(i);
std::string print_str = "[";
for (const auto &input : input_items) {
print_str.append(input);
print_str.append(", ");
}
print_str.append("] -->");
for (const auto &output : output_items) {
print_str.append(output);
print_str.append(", ");
}
MS_LOG(INFO) << "Support format: " << print_str;
}
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
#define MINDSPORE_TBE_KERNEL_SELECT_H
#include <string>
#include <vector>
#include <memory>
#include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
namespace mindspore {
namespace kernel {
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
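// Selects candidate TBE kernel build infos for a node according to its registered op pattern.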
class TbeKernelSelect {
using OpInfoPtr = std::shared_ptr<OpInfo>;
using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;
public:
TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
~TbeKernelSelect() = default;
void TbeMetadataInfoEx();
private:
void GetCommonPatternKernelInfo(const OpInfo &op_info);
void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info);
void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
void GetReducePatternKernelInfo(const OpInfo &op_info);
  void FilterInvalidKernelInfo();
  bool FilterInvalidShape(const KernelBuildInfoIter &kernel_build_info_iter);
static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, const std::vector<int> &dyn_input_sizes,
std::vector<std::string> *formats, std::vector<TypeId> *device_types,
std::vector<std::vector<Axis>> *reshape_types);
static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,
OpIOInfo *op_io_info_new);
// op select(dynamic)
void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new);
static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype,
const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new);
static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item);
std::string OpSelectFormat();
static void PrintSupportedFormat(const SupportFormat &support_format);
private:
CNodePtr cnode_ptr_;
std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;
std::string node_name_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
......@@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg";
constexpr char BATCH_MATMUL[] = "BatchMatMul";
constexpr char EXPAND_DIMS[] = "ExpandDims";
constexpr char SQUARE[] = "Square";
constexpr char BATCHMATMUL[] = "BatchMatMul";
constexpr char TOPK[] = "TopK";
constexpr char IN_TOPK[] = "InTopK";
constexpr char PACK[] = "Pack";
constexpr char GATHER_ND[] = "GatherNd";
constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
// Parallel don't care
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
......
......@@ -21,7 +21,6 @@
#include <vector>
#include "device/ascend/kernel_select_ascend.h"
#include "kernel/kernel_query.h"
#include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
......
......@@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) {
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
......
......@@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \
.compute_cost(10) \
.kernel_name("abs_grad") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "y", None, "required", None) \
.input(1, "dy", None, "required", None) \
.output(0, "z", False, "required", "all") \
......
......@@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
......
......@@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \
.attr("n", "required", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
......
......@@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
.input(0, "x", False, "required", "all") \
.input(1, "bias", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
......
......@@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.output(0, "sum", False, "required", "all") \
.output(1, "square_sum", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.input(5, "batch_mean", False, "required", "all") \
.input(6, "batch_variance", False, "required", "all") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
......@@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.input(3, "batch_variance", False, "required", "all") \
.output(0, "diff_scale", False, "required", "all") \
.output(1, "diff_offset", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
......@@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.output(1, "batch_mean", False, "required", "all") \
.output(2, "batch_variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
......
......@@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \
.attr("dst_type", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.op_pattern("formatAgnostic") \
.dtype_format(DataType.BOOL_None, DataType.F16_None) \
.dtype_format(DataType.BOOL_None, DataType.U8_None) \
.dtype_format(DataType.BOOL_None, DataType.F32_None) \
.dtype_format(DataType.BOOL_None, DataType.I32_None) \
.dtype_format(DataType.I8_None, DataType.F16_None) \
.dtype_format(DataType.I8_None, DataType.F32_None) \
.dtype_format(DataType.I8_None, DataType.I32_None) \
.dtype_format(DataType.U8_None, DataType.F16_None) \
.dtype_format(DataType.U8_None, DataType.F32_None) \
.dtype_format(DataType.U8_None, DataType.I32_None) \
.dtype_format(DataType.I32_None, DataType.BOOL_None) \
.dtype_format(DataType.I32_None, DataType.F16_None) \
.dtype_format(DataType.I32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I8_None) \
.dtype_format(DataType.I32_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F32_None) \
.dtype_format(DataType.F16_None, DataType.I32_None) \
.dtype_format(DataType.F32_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.I32_None) \
.get_op_info()
......
......@@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
......
......@@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.compute_cost(10) \
.kernel_name("conv2d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
......@@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default,
DataType.F16_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
.input(1, "mask", False, "required", "all") \
.input(2, "keep_prob", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
......
......@@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
......
......@@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
.output(0, "y", False, "required", "all") \
.output(1, "mean", False, "required", "all") \
.output(2, "variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
.input(3, "mean", False, "required", "all") \
.output(0, "pd_gamma", False, "required", "all") \
.output(1, "pd_beta", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
.input(3, "mean", False, "required", "all") \
.input(4, "gamma", False, "required", "all") \
.output(0, "pd_x", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
......
......@@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "z", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
......
......@@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.op_pattern("reduce") \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
......
......@@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
.kernel_name("relu_grad_v2") \
.partial_flag(True) \
.input(0, "gradients", False, "required", "all") \
.input(1, "mask", False, "rerequired", "all") \
.input(1, "mask", False, "required", "all") \
.output(0, "backprops", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
......
......@@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
......
......@@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \
.input(0, "x", None, "required", None) \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.get_op_info()
......
......@@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD,
......
......@@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \
.input(1, "features", False, "required", "all") \
.output(0, "backprops", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \
.attr("output_num", "required", "int", "all") \
.input(0, "value", False, "required", "all") \
.output(0, "output", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
......
......@@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
......
......@@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
.input(0, "x", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
......
......@@ -97,6 +97,7 @@ class RegOp:
"""
if not isinstance(value, str):
raise TypeError("%s value must be str" % str(value))
return True
def _is_int(self, value):
"""
......@@ -110,6 +111,7 @@ class RegOp:
"""
if not isinstance(value, int):
raise TypeError("%s value must be int" % str(value))
return True
def _is_bool(self, value):
"""
......@@ -123,6 +125,7 @@ class RegOp:
"""
if not isinstance(value, bool):
raise TypeError("%s value must be bool" % str(value))
return True
def _check_param(self, param_list, key_list, fn_list, kwargs):
"""
......@@ -494,6 +497,7 @@ class DataType:
The current list below maybe not completed. If necessary, please add it.
"""
None_None = ("", "")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
......