From 51ca769d098888634986ae4965cfa4fb85998221 Mon Sep 17 00:00:00 2001 From: zjun Date: Wed, 1 Apr 2020 19:41:21 +0800 Subject: [PATCH] add new mode for operator info register --- mindspore/ccsrc/kernel/oplib/opinfo.h | 9 + mindspore/ccsrc/kernel/oplib/oplib.cc | 61 ++- mindspore/ccsrc/kernel/oplib/oplib.h | 4 +- mindspore/ops/__init__.py | 4 +- .../_op_impl/tbe/adam_apply_one_with_decay.py | 231 ++------- mindspore/ops/op_info_register.py | 441 +++++++++++++++++- 6 files changed, 535 insertions(+), 215 deletions(-) diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h index 7861da34d..56abea926 100644 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ b/mindspore/ccsrc/kernel/oplib/opinfo.h @@ -61,6 +61,7 @@ class OpIOInfo { std::string name() const { return name_; } bool need_compile() const { return need_compile_; } std::string param_type() const { return param_type_; } + std::string reshape_type() const { return reshape_type_; } std::string shape() const { return shape_; } std::vector dtypes() const { return dtypes_; } std::vector formats() const { return formats_; } @@ -69,6 +70,7 @@ class OpIOInfo { void set_name(const std::string& name) { name_ = name; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } void set_param_type(const std::string& param_type) { param_type_ = param_type; } + void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } void set_shape(const std::string& shape) { shape_ = shape; } void set_dtypes(const std::vector& dtype) { dtypes_ = dtype; } void set_formats(const std::vector& formats) { formats_ = formats; } @@ -78,6 +80,7 @@ class OpIOInfo { std::string name_; bool need_compile_ = false; std::string param_type_; + std::string reshape_type_; std::string shape_; std::vector dtypes_; std::vector formats_; @@ -96,6 +99,8 @@ class OpInfo { int compute_cost() const { return compute_cost_; } std::string kernel_name() const { return kernel_name_; } bool partial_flag() const { return partial_flag_; } + bool dynamic_format() const { return dynamic_format_; } + std::string op_pattern() const { return op_pattern_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } @@ -110,6 +115,8 @@ class OpInfo { void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } + void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } + void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } void add_attrs_ptr(const std::shared_ptr& attr) { attrs_ptr_.push_back(attr); } void add_inputs_ptr(const std::shared_ptr& input) { inputs_ptr_.push_back(input); } void add_outputs_ptr(const std::shared_ptr& output) { outputs_ptr_.push_back(output); } @@ -129,6 +136,8 @@ class OpInfo { int compute_cost_ = 0; std::string kernel_name_; bool partial_flag_ = false; + bool dynamic_format_ = false; + std::string op_pattern_; std::vector> attrs_ptr_; std::vector> inputs_ptr_; std::vector> outputs_ptr_; diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index b20bd741f..4059b8e24 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -26,18 +26,22 @@ namespace mindspore { namespace kernel { constexpr auto kImplyType = "imply_type"; constexpr auto kOpName = "op_name"; -constexpr auto kTbe = "TBE"; -constexpr auto kAkg = "akg"; -constexpr auto kAutodiff = "AutoDiff"; constexpr auto kFusionType = "fusion_type"; constexpr auto kAsyncFlag = "async_flag"; constexpr auto kBinfileName = "binfile_name"; constexpr auto kComputeCost = "compute_cost"; constexpr auto kKernelName = "kernel_name"; constexpr auto kPartialFlag = "partial_flag"; +constexpr auto kReshapeType = "reshape_type"; +constexpr auto kOpPattern = "op_pattern"; +constexpr auto kDynamicFormat = "dynamic_format"; +constexpr auto kDtypeFormat = "dtype_format"; constexpr auto kAttr = "attr"; constexpr auto kIputs = "inputs"; constexpr auto kOutputs = "outputs"; +constexpr auto kTbe = "TBE"; +constexpr auto kAkg = "akg"; +constexpr auto kAutodiff = "AutoDiff"; constexpr auto kName = "name"; constexpr auto kParamType = "param_type"; constexpr auto kDtype = "dtype"; @@ -89,8 +93,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI std::shared_ptr op_info = std::make_shared(); MS_EXCEPTION_IF_NULL(op_info); op_info->set_op_name(obj.at(kOpName)); - op_info->set_imply_type(imply_type); op_info->set_impl_path(impl_path); + op_info->set_imply_type(imply_type); op_info->set_fusion_type(obj.at(kFusionType)); if (imply_type == kTBE) { op_info->set_async_flag(obj.at(kAsyncFlag)); @@ -98,6 +102,12 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI op_info->set_compute_cost(obj.at(kComputeCost)); op_info->set_kernel_name(obj.at(kKernelName)); op_info->set_partial_flag(obj.at(kPartialFlag)); + if (obj.find(kOpPattern) != obj.end()) { + op_info->set_op_pattern(obj.at(kOpPattern)); + } + if (obj.find(kDynamicFormat) != obj.end()) { + op_info->set_dynamic_format(obj.at(kDynamicFormat)); + } } auto attrs = obj.at(kAttr); for (const auto& attr : attrs) { @@ -106,16 +116,20 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI return false; } } + nlohmann::json dtype_format; + if (obj.find(kDtypeFormat) != obj.end()) { + dtype_format = obj.at(kDtypeFormat); + } auto inputs = obj.at(kIputs); for (const auto& input : inputs) { - if (!DecodeInputOutput(input, imply_type, kInput, op_info)) { + if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; } } auto outputs = obj.at(kOutputs); for (const auto& output : outputs) { - if (!DecodeInputOutput(output, imply_type, kOutput, op_info)) { + if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; } @@ -156,16 +170,42 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, return ret; } +bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, + size_t index) { + bool ret = true; + try { + std::vector dtype; + std::vector format; + for (const auto& it : dtype_format) { + dtype.emplace_back(it[index][0]); + format.emplace_back(it[index][1]); + } + op_io->set_dtypes(dtype); + op_io->set_formats(format); + } catch (const std::exception& e) { + MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); + ret = false; + } + return ret; +} + bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info) { + const std::shared_ptr& op_info, const nlohmann::json& dtype_format) { bool ret = true; try { std::shared_ptr op_io = std::make_shared(); MS_EXCEPTION_IF_NULL(op_io); op_io->set_index(obj.at(kIndex)); op_io->set_name(obj.at(kName)); - op_io->set_dtypes(obj.at(kDtype)); - op_io->set_formats(obj.at(kFormat)); + if (!dtype_format.empty()) { + if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { + MS_LOG(ERROR) << "Decode dtype format failed"; + return false; + } + } else { + op_io->set_dtypes(obj.at(kDtype)); + op_io->set_formats(obj.at(kFormat)); + } if (op_io->dtypes().size() != op_io->formats().size()) { MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes() << "is not equal to format size:" << op_io->formats(); @@ -181,6 +221,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply if (obj.find(kShape) != obj.end()) { op_io->set_shape(obj.at(kShape)); } + if (obj.find(kReshapeType) != obj.end()) { + op_io->set_reshape_type(obj.at(kReshapeType)); + } } if (io_type == kInput) { diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h index 37c3fdcfc..a4c5e04bb 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ b/mindspore/ccsrc/kernel/oplib/oplib.h @@ -38,8 +38,10 @@ class OpLib { static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, const std::shared_ptr& op_info); + static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, + size_t index); static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info); + const std::shared_ptr& op_info, const nlohmann::json& dtype_format); static bool GetRefInfo(const std::shared_ptr& op_info); static bool CheckRepetition(const std::shared_ptr& op_info); }; diff --git a/mindspore/ops/__init__.py b/mindspore/ops/__init__.py index 23109b386..6f4f68067 100644 --- a/mindspore/ops/__init__.py +++ b/mindspore/ops/__init__.py @@ -30,7 +30,7 @@ Note: from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry -from .op_info_register import op_info_register +from .op_info_register import op_info_register, TBERegOp, DataType from .primitive import constexpr from .._c_expression import signature_rw, signature_kind @@ -40,6 +40,6 @@ __primitive__ = [ ] __all__ = ["get_vm_impl_fn", "vm_impl_registry", - "op_info_register", + "op_info_register", "TBERegOp", "DataType", "constexpr"] __all__.extend(__primitive__) diff --git a/mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py b/mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py index d1c43ca95..a8911e81b 100644 --- a/mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +++ b/mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py @@ -14,208 +14,41 @@ # ============================================================================ """AdamApplyOneWithDecay op""" -from mindspore.ops.op_info_register import op_info_register +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType +adam_apply_one_with_decay_op_info = TBERegOp("AdamApplyOneWithDecay") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("adam_apply_one_with_decay.so") \ + .compute_cost(10) \ + .kernel_name("adam_apply_one_with_decay") \ + .partial_flag(True) \ + .input(0, "input0", False, "required", "all") \ + .input(1, "input1", False, "required", "all") \ + .input(2, "input2", False, "required", "all") \ + .input(3, "input3", False, "required", "all") \ + .input(4, "input4", False, "required", "all") \ + .input(5, "mul0_x", False, "required", "all") \ + .input(6, "mul1_x", False, "required", "all") \ + .input(7, "mul2_x", False, "required", "all") \ + .input(8, "mul3_x", False, "required", "all") \ + .input(9, "mul4_x", False, "required", "all") \ + .input(10, "add2_y", False, "required", "all") \ + .output(0, "output0", False, "required", "all") \ + .output(1, "output1", False, "required", "all") \ + .output(2, "output2", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() -@op_info_register("""{ - "op_name": "AdamApplyOneWithDecay", - "imply_type": "TBE", - "fusion_type": "OPAQUE", - "async_flag": false, - "binfile_name": "adam_apply_one_with_decay.so", - "compute_cost": 10, - "kernel_name": "adam_apply_one_with_decay", - "partial_flag": true, - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "input0", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "input1", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "input2", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 3, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "input3", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 4, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "input4", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 5, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "mul0_x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 6, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "mul1_x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 7, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "mul2_x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 8, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "mul3_x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 9, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "mul4_x", - "need_compile": false, - "param_type": "required", - "shape": "all" - }, - { - "index": 10, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "add2_y", - "need_compile": false, - "param_type": "required", - "shape": "all" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "output0", - "need_compile": true, - "param_type": "required", - "shape": "all" - }, - { - "index": 1, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "output1", - "need_compile": true, - "param_type": "required", - "shape": "all" - }, - { - "index": 2, - "dtype": [ - "float16", "float" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "output2", - "need_compile": true, - "param_type": "required", - "shape": "all" - } - ] -}""") +@op_info_register(adam_apply_one_with_decay_op_info) def _adam_apply_one_with_decay_tbe(): """AdamApplyOneWithDecay TBE register""" return diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 80f40ff1d..6a42099c8 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -16,6 +16,7 @@ """Operators info register.""" import os +import json import inspect from mindspore._c_expression import Oplib from mindspore._checkparam import ParamValidator as validator @@ -32,21 +33,453 @@ def op_info_register(op_info): 'op_info' must be a str of json format represent the op info, the op info will be added into oplib. Args: - op_info (str): op info of json format. + op_info (str or dict): op info of json format. Returns: Function, returns a decorator for op info register. """ def register_decorator(func): - validator.check_type("op_info", op_info, [str]) + if isinstance(op_info, dict): + op_info_real = json.dumps(op_info) + else: + op_info_real = op_info + validator.check_type("op_info", op_info_real, [str]) op_lib = Oplib() file_path = os.path.realpath(inspect.getfile(func)) # keep the path custom ops implementation. imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path - if not op_lib.reg_op(op_info, imply_path): - raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info)) + if not op_lib.reg_op(op_info_real, imply_path): + raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real)) def wrapped_function(*args, **kwargs): return func(*args, **kwargs) return wrapped_function return register_decorator + + +class RegOp(): + """ + Base class for op info register. + + Args: + op_name (str): Name of op. + inputs (list): Inputs inoformation of the op. + outputs (list): Outputs information of the op. + attr_ (list): Attribute information of the op. + dtype_format_ (list): Dtype and format information of the op. + """ + + def __init__(self, op_name=""): + if not isinstance(op_name, str): + raise ValueError("op name value must be string") + if not op_name.strip(): + raise ValueError("op name is empty") + self.op_name = op_name + self.inputs = [] + self.outputs = [] + self.attr_ = [] + self.dtype_format_ = [] + + def is_string(self, value): + """ + Check if the value is a str type. + + Args: + value: Parameter to to check. + + Raises: + TypeError: If the type of value is not a str. + """ + if not isinstance(value, str): + raise TypeError("%s value must be str" % str(value)) + + def is_int(self, value): + """ + Check if the value is a int. + + Args: + value: Parameter to to check. + + Raises: + TypeError: If the type of value is not a int. + """ + if not isinstance(value, int): + raise TypeError("%s value must be int" % str(value)) + + def is_bool(self, value): + """ + Check if the value is a bool. + + Args: + value: Parameter to to check. + + Raises: + TypeError: If the type of value is not a bool. + """ + if not isinstance(value, bool): + raise TypeError("%s value must be bool" % str(value)) + + def dtype_format(self, *args): + """ + Register dtype and format. + + Args: + args (tuple): Value of dtype and format. + + Raises: + ValueError: If the size of args not equal to input size add output size. + TypeError: If the type of args is not tuple. + """ + if len(self.inputs) + len(self.outputs) != len(args): + raise ValueError("input size add output size must be equal to detype format size") + dtype_format = [] + for arg in args: + if not isinstance(arg, tuple) or len(arg) != 2: + raise ValueError("dtype and format value must be tuple of two elements") + self.is_string(arg[0]) + self.is_string(arg[1]) + dtype_format.append(arg) + self.dtype_format_.append(tuple(dtype_format)) + return self + + def get_op_info(self): + """ + Return all registration information for this instance. + + The '_' character ending the key is removed here for compatibility with previous version. + + Key will be unified into an underlined form later. + """ + op_info = {} + for key, value in self.__dict__.items(): + if isinstance(key, str) and key.endswith('_'): + op_info[key.rstrip('_')] = value + else: + op_info[key] = value + return op_info + + +class TBERegOp(RegOp): + """Class for TBE op info register.""" + + def __init__(self, op_name=""): + super(TBERegOp, self).__init__(op_name) + self.imply_type = "TBE" + self.fusion_type_ = '' + self.async_flag_ = False + self.binfile_name_ = '' + self.compute_cost_ = 10 + self.kernel_name_ = '' + self.partial_flag_ = False + self.reshape_type_ = '' + self.dynamic_format_ = False + self.op_pattern_ = "" + + def fusion_type(self, fusion_type): + """ + Register fusion type. + + Args: + fusion_type (str): Value of fusion type. + """ + self.is_string(fusion_type) + self.fusion_type_ = fusion_type + return self + + def async_flag(self, async_flag): + """ + Register async flag. + + Args: + async_flag (bool): Value of async flag. + """ + self.is_bool(async_flag) + self.async_flag_ = async_flag + return self + + def binfile_name(self, binfile_name): + """ + Register binfile name. + + Args: + binfile_name (str): Name of op binfile. + """ + self.is_string(binfile_name) + self.binfile_name_ = binfile_name + return self + + def compute_cost(self, compute_cost): + """ + Register compute cost. + + Args: + compute_cost (int): Value of compute cost. + """ + self.is_int(compute_cost) + self.compute_cost_ = compute_cost + return self + + def kernel_name(self, kernel_name): + """ + Register kernel name. + + Args: + kernel_name (str): Name of op kernel. + """ + self.is_string(kernel_name) + self.kernel_name_ = kernel_name + return self + + def partial_flag(self, partial_flag): + """ + Register partial flag. + + Args: + partial_flag (bool): Value of partial flag. + """ + self.is_bool(partial_flag) + self.partial_flag_ = partial_flag + return self + + def reshape_type(self, reshape_type): + """ + Register reshape type. + + Args: + reshape_type (str): Value of reshape type. + """ + self.is_string(reshape_type) + self.reshape_type_ = reshape_type + return self + + def dynamic_format(self, dynamic_format): + """ + Register dynamic format. + + Args: + reshape_type (bool): Value of dynamic format. + """ + self.is_bool(dynamic_format) + self.dynamic_format_ = dynamic_format + return self + + def op_pattern(self, pattern=None): + """ + Register op pattern information. + + Args: + pattern (str): Value of op pattern. + """ + if pattern is not None and self.istring(pattern): + self.op_pattern_ = pattern + return self + + def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs): + """ + Register op attribute information. + + Args: + name (str): Name of the attribute. Default: None. + param_type (str): Param type of the attribute. Default: None. + type (str): Type of the attribute. Default: None. + value (str): Value of the attribute. Default: None. + default_value (str): Default value of attribute. Default: None. + kwargs (dict): Other information for the attribute. + """ + param_list = [name, param_type, value_type, value, default_value] + attr_dict = {} + for index, element in enumerate(param_list): + if element is not None: + self.is_string(element) + if index == 0: + attr_dict["name"] = element + elif index == 1: + attr_dict["param_type"] = element + elif index == 2: + attr_dict["type"] = element + elif index == 3: + attr_dict["value"] = element + elif index == 4: + attr_dict["default_value"] = element + if kwargs: + attr_dict = dict(attr_dict, **kwargs) + self.attr_.append(attr_dict) + return self + + def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): + """ + Register op input information. + + Args: + index (int): Order of the input. Default: None. + name (str): Name of the input. Default: None. + need_compile (bool): The input need compile whether or not. Default: None. + param_type (str): Type of the input. Default: None. + shape (str): Shape of the input. Default: None. + kwargs (dict): Other information for the input. + """ + param_list = [index, name, need_compile, param_type, shape] + input_dict = {} + for idx, element in enumerate(param_list): + if element is not None: + if idx == 0: + self.is_int(element) + input_dict["index"] = element + elif idx == 1: + self.is_string(element) + input_dict["name"] = element + elif idx == 2: + self.is_bool(element) + input_dict["need_compile"] = element + elif idx == 3: + self.is_string(element) + input_dict["param_type"] = element + elif idx == 4: + self.is_string(element) + input_dict["shape"] = element + if kwargs: + input_dict = dict(input_dict, **kwargs) + self.inputs.append(input_dict) + return self + + def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs): + """ + Register op output information. + + Args: + index (int): Order of the output. Default: None. + name (str): Name of the output. Default: None. + need_compile (bool): The output need compile whether or not. Default: None. + param_type (str): Type of the output. Default: None. + shape (str): Shape of the output. Default: None. + kwargs (dict): Other information for the output. + """ + param_list = [index, name, need_compile, param_type, shape] + output_dict = {} + for idx, element in enumerate(param_list): + if element is not None: + if idx == 0: + self.is_int(element) + output_dict["index"] = element + elif idx == 1: + self.is_string(element) + output_dict["name"] = element + elif idx == 2: + self.is_bool(element) + output_dict["need_compile"] = element + elif idx == 3: + self.is_string(element) + output_dict["param_type"] = element + elif idx == 4: + self.is_string(element) + output_dict["shape"] = element + if kwargs: + output_dict = dict(output_dict, **kwargs) + self.outputs.append(output_dict) + return self + +class DataType(): + """ + Various combinations of dtype and formatself. + + The current list below maybe not completed. If necessary, please add it. + """ + + BOOL_None = ("bool", "") + BOOL_Default = ("bool", "DefaultFormat") + BOOL_5HD = ("bool", "NC1HWC0") + BOOL_NCHW = ("bool", "NCHW") + BOOL_NHWC = ("bool", "NHWC") + BOOL_HWCN = ("bool", "HWCN") + + I8_None = ("int8", "") + I8_Default = ("int8", "DefaultFormat") + I8_5HD = ("int8", "NC1HWC0") + I8_FracZ = ("int8", "Fracz") + I8_FracNZ = ("int8", "FRACTAL_NZ") + I8_NCHW = ("int8", "NCHW") + I8_NHWC = ("int8", "NHWC") + I8_HWCN = ("int8", "HWCN") + + U8_None = ("uint8", "") + U8_Default = ("uint8", "DefaultFormat") + U8_5HD = ("uint8", "NC1HWC0") + U8_FracZ = ("uint8", "Fracz") + U8_FracNZ = ("uint8", "FRACTAL_NZ") + U8_NCHW = ("uint8", "NCHW") + U8_NHWC = ("uint8", "NHWC") + U8_HWCN = ("uint8", "HWCN") + + I16_None = ("int16", "") + I16_Default = ("int16", "DefaultFormat") + I16_5HD = ("int16", "NC1HWC0") + I16_FracZ = ("int16", "Fracz") + I16_FracNZ = ("int16", "FRACTAL_NZ") + I16_NCHW = ("int16", "NCHW") + I16_NHWC = ("int16", "NHWC") + I16_HWCN = ("int16", "HWCN") + + U16_None = ("uint16", "") + U16_Default = ("uint16", "DefaultFormat") + U16_5HD = ("uint16", "NC1HWC0") + U16_FracZ = ("uint16", "Fracz") + U16_FracNZ = ("uint16", "FRACTAL_NZ") + U16_NCHW = ("uint16", "NCHW") + U16_NHWC = ("uint16", "NHWC") + U16_HWCN = ("uint16", "HWCN") + + I32_None = ("int32", "") + I32_Default = ("int32", "DefaultFormat") + I32_5HD = ("int32", "NC1HWC0") + I32_FracZ = ("int32", "Fracz") + I32_FracNZ = ("int32", "FRACTAL_NZ") + I32_NCHW = ("int32", "NCHW") + I32_NHWC = ("int32", "NHWC") + I32_HWCN = ("int32", "HWCN") + + U32_None = ("uint32", "") + U32_Default = ("uint32", "DefaultFormat") + U32_5HD = ("uint32", "NC1HWC0") + U32_FracZ = ("uint32", "Fracz") + U32_FracNZ = ("uint32", "FRACTAL_NZ") + U32_NCHW = ("uint32", "NCHW") + U32_NHWC = ("uint32", "NHWC") + U32_HWCN = ("uint32", "HWCN") + + I64_None = ("int64", "") + I64_Default = ("int64", "DefaultFormat") + I64_5HD = ("int64", "NC1HWC0") + I64_FracZ = ("int64", "Fracz") + I64_FracNZ = ("int64", "FRACTAL_NZ") + I64_NCHW = ("int64", "NCHW") + I64_NHWC = ("int64", "NHWC") + I64_HWCN = ("int64", "HWCN") + + U64_None = ("uint64", "") + U64_Default = ("uint64", "DefaultFormat") + U64_5HD = ("uint64", "NC1HWC0") + U64_FracZ = ("uint64", "Fracz") + U64_FracNZ = ("uint64", "FRACTAL_NZ") + U64_NCHW = ("uint64", "NCHW") + U64_NHWC = ("uint64", "NHWC") + U64_HWCN = ("uint64", "HWCN") + + F16_None = ("float16", "") + F16_Default = ("float16", "DefaultFormat") + F16_5HD = ("float16", "NC1HWC0") + F16_FracZ = ("float16", "Fracz") + F16_FracNZ = ("float16", "FRACTAL_NZ") + F16_C1HWNCoC0 = ("float16", "C1HWNCoC0") + F16_NCHW = ("float16", "NCHW") + F16_NHWC = ("float16", "NHWC") + F16_HWCN = ("float16", "HWCN") + + F32_None = ("float32", "") + F32_Default = ("float32", "DefaultFormat") + F32_5HD = ("float32", "NC1HWC0") + F32_FracZ = ("float32", "Fracz") + F32_FracNZ = ("float32", "FRACTAL_NZ") + F32_C1HWNCoC0 = ("float32", "C1HWNCoC0") + F32_NCHW = ("float32", "NCHW") + F32_NHWC = ("float32", "NHWC") + F32_HWCN = ("float32", "HWCN") -- GitLab