From 8cf4b1c23476bb843c6f2e8e425b559b28f5ed7a Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Fri, 11 Aug 2023 15:00:48 +0800
Subject: [PATCH] [NewIR]Polish IR code (#56087)

* perfect code

* delete __all__
---
 .../ir/dialect/op_generator/op_build_gen.py |   4 +-
 .../fluid/ir/dialect/op_generator/op_gen.py |   6 +-
 paddle/fluid/ir/dialect/pd_api.cc           |  16 +-
 paddle/fluid/ir/dialect/pd_api.h            |   8 +-
 paddle/fluid/ir/dialect/utils.cc            | 171 ++++++++++++++++++
 paddle/fluid/ir/dialect/utils.h             | 150 +--------------
 paddle/fluid/pybind/ir.cc                   | 146 +++++++++++++--
 python/paddle/fluid/framework.py            |  28 +--
 python/paddle/ir/__init__.py                |  11 +-
 python/paddle/ir/core.py                    |  18 ++
 10 files changed, 346 insertions(+), 212 deletions(-)
 create mode 100644 paddle/fluid/ir/dialect/utils.cc

diff --git a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
index 5c3696d02c8..f4d91d5c068 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py
@@ -65,7 +65,7 @@ def GenBuildInputArgsStr(
             ]
             if (
                 op_attribute_build_arg_type_list[attr_idx]
-                != "std::string"
+                != "const std::string&"
             ):
                 if (
                     default_value[0] == "'"
@@ -106,7 +106,7 @@ def GenBuildInputArgsStr(
                     op_non_mutable_attribute_build_arg_type_list[
                         attr_idx
                     ]
-                    != "std::string"
+                    != "const std::string&"
                 ):
                     if (
                         default_value[0] == "'"
diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index ba4424aa7bd..d990141add5 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -258,12 +258,12 @@ class OpInfoParser:
             'ir::ArrayAttribute',
             'const std::vecot&',
         ],
-        'str': ['ir::StrAttribute', 'std::string'],
+        'str': ['ir::StrAttribute', 'const std::string&'],
         'str[]': [
             'ir::ArrayAttribute<ir::StrAttribute>',
             'const std::vector<std::string>&',
         ],
-        'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
+        'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
         'DataLayout': [
             'paddle::dialect::DataLayoutAttribute',
             'DataLayout',
@@ -577,7 +577,7 @@ class OpInfoParser:
                     temp_type = attribute_info['data_type']
             if 'IntArray' in temp_type:
                 if 'data_type' in attribute_info:
-                    temp_type = attribute_info['data_type']
+                    temp_type = "const " + attribute_info['data_type'] + "&"
             type_list.append(self.get_phi_dtype_name(temp_type))
         return type_list
diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_api.cc
index f65b1e25f9c..6405f7dce7e 100644
--- a/paddle/fluid/ir/dialect/pd_api.cc
+++ b/paddle/fluid/ir/dialect/pd_api.cc
@@ -29,7 +29,9 @@ ir::OpResult add_n(std::vector<ir::OpResult> x) {
   return add_n_op.out();
 }
 
-ir::OpResult mean(ir::OpResult x, std::vector<int64_t> axis, bool keepdim) {
+ir::OpResult mean(ir::OpResult x,
+                  const std::vector<int64_t>& axis,
+                  bool keepdim) {
   paddle::dialect::MeanOp mean_op =
       APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanOp>(
           x, axis, keepdim);
@@ -37,27 +39,27 @@ ir::OpResult mean(ir::OpResult x, std::vector<int64_t> axis, bool keepdim) {
   return mean_op.out();
 }
 
 ir::OpResult sum(ir::OpResult x,
-                 std::vector<int64_t> axis,
+                 const std::vector<int64_t>& axis,
                  phi::DataType dtype,
                  bool keepdim) {
-  paddle::dialect::SumOp sum_op =
+  auto sum_op =
       APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SumOp>(
           x, axis, dtype, keepdim);
   return sum_op.out();
 }
 
 ir::OpResult divide(ir::OpResult x, ir::OpResult y) {
-  paddle::dialect::DivideOp divide_op =
+  auto divide_op =
       APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::DivideOp>(x, y);
   return divide_op.out();
 }
 
-ir::OpResult full(std::vector<int64_t> shape,
+ir::OpResult full(const std::vector<int64_t>& shape,
                   float value,
                   phi::DataType dtype,
-                  phi::Place place) {
-  paddle::dialect::FullOp full_op =
+                  const phi::Place& place) {
+  auto full_op =
       APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::FullOp>(
           shape, value, dtype, place);
   return full_op.out();
diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_api.h
index 5d3b2376314..9581e0a4e7e 100644
--- a/paddle/fluid/ir/dialect/pd_api.h
+++ b/paddle/fluid/ir/dialect/pd_api.h
@@ -25,20 +25,20 @@ namespace dialect {
 ir::OpResult add_n(std::vector<ir::OpResult> x);
 
 ir::OpResult mean(ir::OpResult x,
-                  std::vector<int64_t> axis = {},
+                  const std::vector<int64_t>& axis = {},
                   bool keepdim = false);
 
 ir::OpResult sum(ir::OpResult x,
-                 std::vector<int64_t> axis = {},
+                 const std::vector<int64_t>& axis = {},
                  phi::DataType dtype = phi::DataType::UNDEFINED,
                  bool keepdim = false);
 
 ir::OpResult divide(ir::OpResult x, ir::OpResult y);
 
-ir::OpResult full(std::vector<int64_t> shape,
+ir::OpResult full(const std::vector<int64_t>& shape,
                   float value,
                   phi::DataType dtype = phi::DataType::FLOAT32,
-                  phi::Place place = phi::CPUPlace());
+                  const phi::Place& place = phi::CPUPlace());
 
 ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);
diff --git a/paddle/fluid/ir/dialect/utils.cc b/paddle/fluid/ir/dialect/utils.cc
new file mode 100644
index 00000000000..cd6ff35ef7f
--- /dev/null
+++ b/paddle/fluid/ir/dialect/utils.cc
@@ -0,0 +1,171 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/ir/dialect/utils.h"
+
+namespace paddle {
+namespace dialect {
+
+enum class AttrType {
+  UNDEFINED = 0,
+  BOOL,
+  INT32,
+  INT64,
+
+  FLOAT,
+  DOUBLE,
+
+  ARRAY,
+  INT_ARRAY,
+
+  SCALAR,
+  DATA_TYPE,
+  DATA_LAYOUT,
+  PLACE,
+
+  STRING,
+
+  NUM_ATTR_TYPES,
+};
+
+static inline AttrType GetAttributeType(const ir::Attribute& attr) {
+  if (attr.isa<ir::BoolAttribute>()) {
+    return AttrType::BOOL;
+  } else if (attr.isa<ir::FloatAttribute>()) {
+    return AttrType::FLOAT;
+  } else if (attr.isa<ir::DoubleAttribute>()) {
+    return AttrType::DOUBLE;
+  } else if (attr.isa<ir::Int32Attribute>()) {
+    return AttrType::INT32;
+  } else if (attr.isa<ir::Int64Attribute>()) {
+    return AttrType::INT64;
+  } else if (attr.isa<ir::ArrayAttribute>()) {
+    return AttrType::ARRAY;
+  } else if (attr.isa<ir::StrAttribute>()) {
+    return AttrType::STRING;
+  } else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
+    return AttrType::INT_ARRAY;
+  } else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
+    return AttrType::DATA_TYPE;
+  } else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
+    return AttrType::PLACE;
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Unsupported ir Attribute type when casting it into "
+        "AttrType."));
+  }
+}
+
+static std::unordered_map<AttrType,
+                          std::function<VariantType(const ir::Attribute&)>>
+    kAttrCastMap = {
+        {AttrType::BOOL,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::BoolAttribute>().data()};
+         }},
+        {AttrType::FLOAT,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::FloatAttribute>().data()};
+         }},
+        {AttrType::DOUBLE,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::DoubleAttribute>().data()};
+         }},
+        {AttrType::INT32,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::Int32Attribute>().data()};
+         }},
+        {AttrType::INT64,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::Int64Attribute>().data()};
+         }},
+        {AttrType::INT_ARRAY,
+         [](const ir::Attribute& attr) {
+           return VariantType{
+               attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
+                   .data()
+                   .GetData()};
+         }},
+        {AttrType::STRING,
+         [](const ir::Attribute& attr) {
+           return VariantType{attr.dyn_cast<ir::StrAttribute>().AsString()};
+         }},
+        {AttrType::DATA_TYPE,
+         [](const ir::Attribute& attr) {
+           return VariantType{
+               attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
+         }},
+        {AttrType::PLACE,
+         [](const ir::Attribute& attr) {
+           return VariantType{
+               attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
+         }},
+        {AttrType::ARRAY,
+         [](const ir::Attribute& attr) {
+           auto attr_vec = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
+           if (attr_vec.size() == 0) {
+             return VariantType{std::vector<int>()};
+           }
+           AttrType element_type = GetAttributeType(attr_vec[0]);
+
+           if (element_type == AttrType::BOOL) {
+             std::vector<bool> vec_bools;
+             for (auto vec_element : attr_vec) {
+               vec_bools.push_back(
+                   vec_element.dyn_cast<ir::BoolAttribute>().data());
+             }
+             return VariantType{vec_bools};
+           } else if (element_type == AttrType::INT32) {
+             std::vector<int> vec_int32;
+             for (auto vec_element : attr_vec) {
+               vec_int32.push_back(
+                   vec_element.dyn_cast<ir::Int32Attribute>().data());
+             }
+             return VariantType{vec_int32};
+           } else if (element_type == AttrType::INT64) {
+             std::vector<int64_t> vec_int64;
+             for (auto vec_element : attr_vec) {
+               vec_int64.push_back(
+                   vec_element.dyn_cast<ir::Int64Attribute>().data());
+             }
+             return VariantType{vec_int64};
+           } else if (element_type == AttrType::FLOAT) {
+             std::vector<float> vec_float;
+             for (auto vec_element : attr_vec) {
+               vec_float.push_back(
+                   vec_element.dyn_cast<ir::FloatAttribute>().data());
+             }
+             return VariantType{vec_float};
+           } else if (element_type == AttrType::DOUBLE) {
+             std::vector<double> vec_double;
+             for (auto vec_element : attr_vec) {
+               vec_double.push_back(
+                   vec_element.dyn_cast<ir::DoubleAttribute>().data());
+             }
+             return VariantType{vec_double};
+           } else {
+             PADDLE_THROW(phi::errors::Unimplemented(
+                 "Unsupported ir Attribute type when casting it into "
+                 "vector."));
+           }
+         }},
+};
+
+VariantType GetAttributeData(const ir::Attribute& attr) {
+  AttrType attr_type = GetAttributeType(attr);
+  return kAttrCastMap[attr_type](attr);
+}
+
+}  // namespace dialect
+}  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h
index 3dbeaa9cc5a..a81febc0cba 100644
--- a/paddle/fluid/ir/dialect/utils.h
+++ b/paddle/fluid/ir/dialect/utils.h
@@ -141,155 +141,7 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
   }
 }
 
-enum class AttrType {
-  UNDEFINED = 0,
-  BOOL,
-  INT32,
-  INT64,
-
-  FLOAT,
-  DOUBLE,
-
-  ARRAY,
-  INT_ARRAY,
-
-  SCALAR,
-  DATA_TYPE,
-  DATA_LAYOUT,
-  PLACE,
-
-  STRING,
-
-  NUM_ATTR_TYPES,
-};
-
-static inline AttrType GetAttributeType(const ir::Attribute& attr) {
-  if (attr.isa<ir::BoolAttribute>()) {
-    return AttrType::BOOL;
-  } else if (attr.isa<ir::FloatAttribute>()) {
-    return AttrType::FLOAT;
-  } else if (attr.isa<ir::DoubleAttribute>()) {
-    return AttrType::DOUBLE;
-  } else if (attr.isa<ir::Int32Attribute>()) {
-    return AttrType::INT32;
-  } else if (attr.isa<ir::Int64Attribute>()) {
-    return AttrType::INT64;
-  } else if (attr.isa<ir::ArrayAttribute>()) {
-    return AttrType::ARRAY;
-  } else if (attr.isa<ir::StrAttribute>()) {
-    return AttrType::STRING;
-  } else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
-    return AttrType::INT_ARRAY;
-  } else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
-    return AttrType::DATA_TYPE;
-  } else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
-    return AttrType::PLACE;
-  } else {
-    PADDLE_THROW(phi::errors::Unimplemented(
-        "Unsupported ir Attribute type when casting it into "
-        "AttrType."));
-  }
-}
-
-static std::unordered_map<AttrType,
-                          std::function<VariantType(const ir::Attribute&)>>
-    attr_cast_map = {
-        {AttrType::BOOL,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::BoolAttribute>().data()};
-         }},
-        {AttrType::FLOAT,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::FloatAttribute>().data()};
-         }},
-        {AttrType::DOUBLE,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::DoubleAttribute>().data()};
-         }},
-        {AttrType::INT32,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::Int32Attribute>().data()};
-         }},
-        {AttrType::INT64,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::Int64Attribute>().data()};
-         }},
-        {AttrType::INT_ARRAY,
-         [](const ir::Attribute& attr) {
-           return VariantType{
-               attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
-                   .data()
-                   .GetData()};
-         }},
-        {AttrType::STRING,
-         [](const ir::Attribute& attr) {
-           return VariantType{attr.dyn_cast<ir::StrAttribute>().AsString()};
-         }},
-        {AttrType::DATA_TYPE,
-         [](const ir::Attribute& attr) {
-           return VariantType{
-               attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
-         }},
-        {AttrType::PLACE,
-         [](const ir::Attribute& attr) {
-           return VariantType{
-               attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
-         }},
-        {AttrType::ARRAY,
-         [](const ir::Attribute& attr) {
-           auto attr_vec = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
-           if (attr_vec.size() == 0) {
-             return VariantType{std::vector<int>()};
-           }
-           AttrType element_type = GetAttributeType(attr_vec[0]);
-
-           if (element_type == AttrType::BOOL) {
-             std::vector<bool> vec_bools;
-             for (auto vec_element : attr_vec) {
-               vec_bools.push_back(
-                   vec_element.dyn_cast<ir::BoolAttribute>().data());
-             }
-             return VariantType{vec_bools};
-           } else if (element_type == AttrType::INT32) {
-             std::vector<int> vec_int32;
-             for (auto vec_element : attr_vec) {
-               vec_int32.push_back(
-                   vec_element.dyn_cast<ir::Int32Attribute>().data());
-             }
-             return VariantType{vec_int32};
-           } else if (element_type == AttrType::INT64) {
-             std::vector<int64_t> vec_int64;
-             for (auto vec_element : attr_vec) {
-               vec_int64.push_back(
-                   vec_element.dyn_cast<ir::Int64Attribute>().data());
-             }
-             return VariantType{vec_int64};
-           } else if (element_type == AttrType::FLOAT) {
-             std::vector<float> vec_float;
-             for (auto vec_element : attr_vec) {
-               vec_float.push_back(
-                   vec_element.dyn_cast<ir::FloatAttribute>().data());
-             }
-             return VariantType{vec_float};
-           } else if (element_type == AttrType::DOUBLE) {
-             std::vector<double> vec_double;
-             for (auto vec_element : attr_vec) {
-               vec_double.push_back(
-                   vec_element.dyn_cast<ir::DoubleAttribute>().data());
-             }
-             return VariantType{vec_double};
-           } else {
-             PADDLE_THROW(phi::errors::Unimplemented(
-                 "Unsupported ir Attribute type when casting it into "
-                 "vector."));
-           }
-         }},
-};
-
-static inline VariantType GetAttributeData(const ir::Attribute& attr) {
-  AttrType attr_type = GetAttributeType(attr);
-  return attr_cast_map[attr_type](attr);
-}
+VariantType GetAttributeData(const ir::Attribute& attr);
 
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index 2081d327a23..a6da23bc78e 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -57,7 +57,46 @@ PyTypeObject *g_ir_opresult_pytype = nullptr;
 void BindOpsAPI(pybind11::module *module);
 
 void BindProgram(py::module *m) {
-  py::class_<Program> program(*m, "Program");
+  py::class_<Program> program(*m, "Program", R"DOC(
+    Create a Python Program. A Program is an abstraction of the model
+    structure, divided into computational graphs and weights. It has a main
+    block that stores the computational graph.
+
+    A set of Programs usually contains a startup program and a main program.
+    The startup program holds one-time initialization work, e.g. initializing
+    the ``Parameter``, while the main program holds the network structure and
+    the variables used for training.
+
+    A set of Programs can be used for training or testing. For training,
+    Paddle keeps everything needed to build the network; for testing, it
+    prunes content irrelevant to testing, e.g. backward ops and variables.
+
+    **Notes**:
+        **We have** :ref:`api_paddle_static_default_startup_program` **and**
+        :ref:`api_paddle_static_default_main_program` **by default, and the pair
+        shares parameters. The** :ref:`api_paddle_static_default_startup_program`
+        **runs only once to initialize parameters, while**
+        :ref:`api_paddle_static_default_main_program` **runs in every mini-batch
+        and adjusts the weights.**
+
+    Returns:
+        Program: An empty Program.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.static as static
+
+            paddle.enable_static()
+
+            main_program = static.Program()
+            startup_program = static.Program()
+            with static.program_guard(main_program=main_program, startup_program=startup_program):
+                x = static.data(name="x", shape=[-1, 784], dtype='float32')
+                y = static.data(name="y", shape=[-1, 1], dtype='int32')
+                z = static.nn.fc(name="fc", x=x, size=10, activation="relu")
+
+            print("main program is: {}".format(main_program))
+            print("start up program is: {}".format(startup_program))
+    )DOC");
   program
       .def(
           "__init__",
@@ -78,7 +117,13 @@ void BindProgram(py::module *m) {
 }
 
 void BindBlock(py::module *m) {
-  py::class_<Block> block(*m, "Block");
+  py::class_<Block> block(*m, "Block", R"DOC(
+    In IR, a Block holds a list of Operations and can represent a sub
+    computational graph.
+
+    Notes:
+        The constructor of Block should not be invoked directly. You can
+        use `Program.block()` to get a block.
+    )DOC");
   block.def("front", &Block::front, return_value_policy::reference)
       .def("get_parent_program",
           [](Block &self) { return self.GetParentOp()->GetParentProgram(); })
@@ -91,14 +136,35 @@ void BindBlock(py::module *m) {
             }
             return op_list;
           })
-      .def("remove_op", [](Block &self, Operation *op) {
-        auto op_iter = std::find(self.begin(), self.end(), op);
-        self.erase(op_iter);
-      });
+      .def(
+          "remove_op",
+          [](Block &self, Operation *op) {
+            auto op_iter = std::find(self.begin(), self.end(), op);
+            self.erase(op_iter);
+          },
+          R"DOC(
+          Remove the specified operation from this block.
+
+          Args:
+              op (Operation): the operation to be removed.
+
+          Returns:
+              None
+
+          )DOC");
 }
 
 void BindOperation(py::module *m) {
-  py::class_<Operation> op(*m, "Operation");
+  py::class_<Operation> op(*m, "Operation", R"DOC(
+    In IR, every operation is represented by an Operation, and an Operation
+    is regarded as an instruction of a Block. Users can call Python APIs to
+    describe their neural network.
+
+    Notes:
+        The constructor of Operation should not be invoked directly. Use the
+        Python APIs instead; for example, paddle.mean builds a mean
+        operation.
+
+    )DOC");
   op.def("name", &Operation::name)
      .def("get_parent_block",
          py::overload_cast<>(&Operation::GetParent),
          return_value_policy::reference)
@@ -170,7 +236,15 @@
 }
 
 void BindValue(py::module *m) {
-  py::class_<Value> value(*m, "Value");
+  py::class_<Value> value(*m, "Value", R"DOC(
+    Value class represents the SSA value in the IR system. It is a directed
+    edge in the computation graph and the base class of other value types.
+
+    Notes:
+        The constructor of Value should not be invoked directly. Values are
+        constructed automatically when the network is built.
+
+    )DOC");
   value
       .def("get_defining_op",
           &Value::GetDefiningOp,
@@ -185,7 +259,16 @@
 }
 
 void BindOpOperand(py::module *m) {
-  py::class_<OpOperand> op_operand(*m, "OpOperand");
+  py::class_<OpOperand> op_operand(*m,
+                                   "OpOperand",
+                                   R"DOC(
+    OpOperand class represents an operand (input) of an operation.
+
+    Notes:
+        The constructor of OpOperand should not be invoked directly. OpOperands
+        are constructed automatically when the network is built.
+
+    )DOC");
   op_operand
      .def("source",
          [](OpOperand &self) { return self.source().dyn_cast<OpResult>(); })
@@ -228,7 +311,13 @@ void SetStopGradient(const OpResult &self, bool stop_gradient) {
 }
 
 void BindOpResult(py::module *m) {
-  py::class_<OpResult> op_result(*m, "OpResult");
+  py::class_<OpResult> op_result(*m, "OpResult", R"DOC(
+    OpResult class represents a value (output) defined by a result of an
+    operation.
+
+    Notes:
+        The constructor of OpResult should not be invoked directly. OpResults
+        are constructed automatically when the network is built.
+    )DOC");
   g_ir_opresult_pytype = reinterpret_cast<PyTypeObject *>(op_result.ptr());
   op_result.def("__eq__", &OpResult::operator==)
       .def("__eq__",
@@ -301,7 +390,42 @@ void BindUtils(pybind11::module *m) {
       []() { APIBuilder::Instance().ResetInsertionPointToStart(); });
   m->def("reset_insertion_point_to_end",
          []() { APIBuilder::Instance().ResetInsertionPointToEnd(); });
-  m->def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram);
+  m->def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram, R"DOC(
+        Convert a legacy Fluid Program into a New IR Program.
+
+        Args:
+            legacy_program (ProgramDesc): The Fluid Program to be converted.
+
+        Returns:
+            Program: The New IR Program.
+
+        Raises:
+            PreconditionNotMet: Raised if ``legacy_program`` contains more
+                than one block.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle import ir
+                paddle.enable_static()
+
+                x = paddle.randn([4, 4])
+                main_program, start_program = (
+                    paddle.static.Program(),
+                    paddle.static.Program(),
+                )
+                with paddle.static.program_guard(main_program, start_program):
+                    x_s = paddle.static.data('x', [4, 4], x.dtype)
+                    x_s.stop_gradient = False
+                    y_s = paddle.matmul(x_s, x_s)
+                    z_s = paddle.add(y_s, y_s)
+                    k_s = paddle.tanh(z_s)
+                newir_program = ir.translate_to_new_ir(main_program.desc)
+
+                print(newir_program)
+
+    )DOC");
 }
 
 void BindNewIR(pybind11::module *module) {
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 05731baad24..b375cca76c1 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1016,32 +1016,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
     dtype = np.dtype(np_dtype)
 
     if ir.core._use_new_ir_api():
-        if dtype == np.float32:
-            return core.DataType.FLOAT32
-        elif dtype == np.float64:
-            return core.DataType.FLOAT64
-        elif dtype == np.float16:
-            return core.DataType.FLOAT16
-        elif dtype == np.int32:
-            return core.DataType.INT32
-        elif dtype == np.int16:
-            return core.DataType.INT16
-        elif dtype == np.int64:
-            return core.DataType.INT64
-        elif dtype == np.bool_:
-            return core.DataType.BOOL
-        elif dtype == np.uint16:
-            # since there is still no support for bfloat16 in NumPy,
-            # uint16 is used for casting bfloat16
-            return core.DataType.UINT16
-        elif dtype == np.uint8:
-            return core.DataType.UINT8
-        elif dtype == np.int8:
-            return core.DataType.INT8
-        elif dtype == np.complex64:
-            return core.DataType.COMPLEX64
-        elif dtype == np.complex128:
-            return core.DataType.COMPLEX128
+        if dtype in ir.core.np_type_to_paddle_type.keys():
+            return ir.core.np_type_to_paddle_type[dtype]
         else:
             raise ValueError("Not supported numpy dtype %s" % dtype)
     else:
diff --git a/python/paddle/ir/__init__.py b/python/paddle/ir/__init__.py
index 0d272cf88ae..f023cfc0539 100644
--- a/python/paddle/ir/__init__.py
+++ b/python/paddle/ir/__init__.py
@@ -31,13 +31,4 @@ from paddle.fluid.libpaddle.ir import (
 
 from . import core
 
-__all__ = [  # noqa
-    'Program',
-    'Block',
-    'Operation',
-    'Value',
-    'OpOperand',
-    'OpResult',
-    'Type',
-    'translate_to_new_ir',
-]
+__all__ = []
diff --git a/python/paddle/ir/core.py b/python/paddle/ir/core.py
index 9310c9b75bf..ea73d266cca 100644
--- a/python/paddle/ir/core.py
+++ b/python/paddle/ir/core.py
@@ -13,11 +13,29 @@
 # limitations under the License.
 
 
+import numpy as np
+
 import paddle
+from paddle.fluid.libpaddle import DataType
 from paddle.fluid.libpaddle.ir import Program, set_global_program
 
 from ..fluid.wrapped_decorator import signature_safe_contextmanager
 
+np_type_to_paddle_type = {
+    np.dtype("float32"): DataType.FLOAT32,
+    np.dtype("float64"): DataType.FLOAT64,
+    np.dtype("float16"): DataType.FLOAT16,
+    np.dtype("int32"): DataType.INT32,
+    np.dtype("int16"): DataType.INT16,
+    np.dtype("int64"): DataType.INT64,
+    np.dtype("bool_"): DataType.BOOL,
+    np.dtype("uint16"): DataType.UINT16,
+    np.dtype("uint8"): DataType.UINT8,
+    np.dtype("int8"): DataType.INT8,
+    np.dtype("complex64"): DataType.COMPLEX64,
+    np.dtype("complex128"): DataType.COMPLEX128,
+}
+
 
 def _use_new_ir_api():
     """
-- 
GitLab
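
The framework.py change above replaces a twelve-branch if/elif chain with a single
dictionary lookup keyed by ``np.dtype``. Below is a minimal, self-contained sketch
of that lookup pattern; the ``DataType`` enum here is a stand-in for the real
``paddle.fluid.libpaddle.DataType`` binding, so the snippet runs with NumPy alone
and only illustrates the logic the patch adopts.

.. code-block:: python

    from enum import Enum

    import numpy as np


    class DataType(Enum):
        # Stand-in for paddle.fluid.libpaddle.DataType; illustrative only.
        FLOAT32 = 0
        FLOAT64 = 1
        INT32 = 2
        BOOL = 3


    np_type_to_paddle_type = {
        np.dtype("float32"): DataType.FLOAT32,
        np.dtype("float64"): DataType.FLOAT64,
        np.dtype("int32"): DataType.INT32,
        np.dtype("bool"): DataType.BOOL,
    }


    def convert_np_dtype_to_dtype_(np_dtype):
        # np.dtype() canonicalizes aliases ("float32", np.float32, "f4"),
        # so one hashable key per dtype suffices for the table lookup.
        dtype = np.dtype(np_dtype)
        if dtype in np_type_to_paddle_type:
            return np_type_to_paddle_type[dtype]
        raise ValueError("Not supported numpy dtype %s" % dtype)


    assert convert_np_dtype_to_dtype_("float32") is DataType.FLOAT32
    assert convert_np_dtype_to_dtype_(np.int32) is DataType.INT32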
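The polished pybind docstrings also pin down the ``Block.remove_op`` contract: it
takes the ``Operation`` object itself, not an index. The hedged sketch below
exercises that flow end to end. It assumes a Paddle build that ships the new IR
bindings (``paddle.ir``) and uses ``Program.block()`` and ``Block.ops()`` exactly
as the docstrings in this patch describe; treat it as illustrative rather than a
tested example from the PR.

.. code-block:: python

    import paddle
    from paddle import ir

    paddle.enable_static()

    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data("x", [4, 4], "float32")
        y = paddle.tanh(x)

    # Translate the legacy ProgramDesc into a New IR Program, then
    # inspect and edit its main block in place.
    newir_program = ir.translate_to_new_ir(main_program.desc)
    block = newir_program.block()
    ops = block.ops()
    print([op.name() for op in ops])

    # remove_op erases the given Operation from the block.
    block.remove_op(ops[-1])
    print([op.name() for op in block.ops()])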