提交 2f4a75e7 编写于 作者: M Megvii Engine Team

feat(mge/utils): redesign dtype promotion

GitOrigin-RevId: 4f2fe1b6ce1430e96cb8bac34ffbbc46548007f5
上级 2e9ba679
......@@ -16,39 +16,74 @@ from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
def dtype_promotion(raw_inputs):
def add_dtype(i):
if type(i) == int:
return np.array(i, dtype=np.int32)
if type(i) == float:
return np.array(i, dtype=np.float32)
if type(i) == bool:
return np.array(i, dtype=np.bool_)
return None
scalar_inputs = [
add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i)
]
inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
assert len(scalar_inputs + inputs) > 0
dtype = None
if len(inputs) > 0:
dtype = np.result_type(*inputs)
dtype_all = np.result_type(*(inputs + scalar_inputs))
assert (
dtype != np.float64 and dtype != np.int64
), "unsupport dtype {} by dtype_promotion, please use explict type convert".format(
dtype
)
if dtype_all == np.bool_:
for i in raw_inputs:
if not hasattr(i, "dtype") or i.dtype != np.bool_:
raise TypeError(
"bool dtype can not be operated with an element without bool dtype"
)
if dtype_all == np.float64:
dtype_all = np.float32
return dtype_all
def dtype_promotion(inputs):
"""
Returns the dtype that would result from performing an arithmetic
operation on the provided input tensors and scalars.
"""
# map numpy.dtype.kind to priority
category_priority = {
"f": 3, # floating-point
"i": 2, # signed integer
"u": 2, # unsigned integer
"b": 1, # boolean
}
def scalar2dtype(x):
"""
For scalar `x`, returns its corresponding type. A floating point scalar
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'.
A boolean scalar has dtype 'bool'.
"""
if isinstance(x, bool):
return np.bool_
if isinstance(x, int):
return np.int32
if isinstance(x, float):
return np.float32
def promote_types(types, cat):
"""
Returns the data type with sufficient size to hold all types of
category `cat` in the list `types`.
"""
used_types = [
i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat
]
assert len(used_types) > 0
res = used_types[0]
for i in used_types:
res = np.promote_types(res, i)
return res
def max_priority(types):
"""
Returns the maximum value of the priority of each type in the list
`types`.
"""
if not types:
return 0
else:
return max([category_priority.get(np.dtype(i).kind, 0) for i in types])
scalars = []
tensors = []
for data in inputs:
if hasattr(data, "dtype"):
tensors.append(data.dtype)
elif isinstance(data, (float, int, bool)):
scalars.append(scalar2dtype(data))
max_pri_scalars = max_priority(scalars)
max_pri_tensors = max_priority(tensors)
assert max_pri_scalars > 0 or max_pri_tensors > 0
if max_pri_scalars > max_pri_tensors:
return promote_types(scalars, max_pri_scalars)
else:
return promote_types(tensors, max_pri_tensors)
def get_device(inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册