diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 9c795fbe56c521c9bd3c303d741a66e413a15e58..400e152335edc2e9c273ec8539c4bbea5c1e95cc 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -16,39 +16,74 @@ from ..ops.special import Const
 from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
 
 
-def dtype_promotion(raw_inputs):
-    def add_dtype(i):
-        if type(i) == int:
-            return np.array(i, dtype=np.int32)
-        if type(i) == float:
-            return np.array(i, dtype=np.float32)
-        if type(i) == bool:
-            return np.array(i, dtype=np.bool_)
-        return None
-
-    scalar_inputs = [
-        add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i)
-    ]
-    inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
-    assert len(scalar_inputs + inputs) > 0
-    dtype = None
-    if len(inputs) > 0:
-        dtype = np.result_type(*inputs)
-    dtype_all = np.result_type(*(inputs + scalar_inputs))
-    assert (
-        dtype != np.float64 and dtype != np.int64
-    ), "unsupport dtype {} by dtype_promotion, please use explict type convert".format(
-        dtype
-    )
-    if dtype_all == np.bool_:
-        for i in raw_inputs:
-            if not hasattr(i, "dtype") or i.dtype != np.bool_:
-                raise TypeError(
-                    "bool dtype can not be operated with an element without bool dtype"
-                )
-    if dtype_all == np.float64:
-        dtype_all = np.float32
-    return dtype_all
+def dtype_promotion(inputs):
+    """
+    Returns the dtype that would result from performing an arithmetic
+    operation on the provided input tensors and scalars.
+    """
+    # map numpy.dtype.kind to priority
+    category_priority = {
+        "f": 3,  # floating-point
+        "i": 2,  # signed integer
+        "u": 2,  # unsigned integer
+        "b": 1,  # boolean
+    }
+
+    def scalar2dtype(x):
+        """
+        For scalar `x`, returns its corresponding type. A floating point scalar
+        has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'.
+        A boolean scalar has dtype 'bool'.
+        """
+        if isinstance(x, bool):
+            return np.bool_
+        if isinstance(x, int):
+            return np.int32
+        if isinstance(x, float):
+            return np.float32
+
+    def promote_types(types, cat):
+        """
+        Returns the data type with sufficient size to hold all types of
+        category `cat` in the list `types`.
+        """
+        used_types = [
+            i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat
+        ]
+        assert len(used_types) > 0
+        res = used_types[0]
+        for i in used_types:
+            res = np.promote_types(res, i)
+        return res
+
+    def max_priority(types):
+        """
+        Returns the maximum value of the priority of each type in the list
+        `types`.
+        """
+        if not types:
+            return 0
+        else:
+            return max([category_priority.get(np.dtype(i).kind, 0) for i in types])
+
+    scalars = []
+    tensors = []
+
+    for data in inputs:
+        if hasattr(data, "dtype"):
+            tensors.append(data.dtype)
+        elif isinstance(data, (float, int, bool)):
+            scalars.append(scalar2dtype(data))
+
+    max_pri_scalars = max_priority(scalars)
+    max_pri_tensors = max_priority(tensors)
+
+    assert max_pri_scalars > 0 or max_pri_tensors > 0
+
+    if max_pri_scalars > max_pri_tensors:
+        return promote_types(scalars, max_pri_scalars)
+    else:
+        return promote_types(tensors, max_pri_tensors)
 
 
 def get_device(inputs):
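
For context, a minimal usage sketch of the new promotion rules (not part of the patch, and assuming the patched module is importable from an installed build). dtype_promotion only inspects a .dtype attribute, so plain numpy arrays stand in for tensors here; the expected results follow from the category priorities above (floating beats integer beats boolean), with tensor dtypes winning ties against Python scalars.

import numpy as np
from megengine.core.tensor.utils import dtype_promotion  # path of the patched file

t_int32 = np.arange(3, dtype=np.int32)   # anything with a .dtype attribute works
t_f16 = np.ones(3, dtype=np.float16)

# int scalar vs int32 array: same priority category, so the tensor dtype wins
assert dtype_promotion([t_int32, 1]) == np.int32
# float scalar vs int32 array: the floating category outranks integer, so float32
assert dtype_promotion([t_int32, 1.0]) == np.float32
# two floating dtypes: promoted within the floating category via np.promote_types
assert dtype_promotion([t_f16, np.ones(3, dtype=np.float32)]) == np.float32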