diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py index 85c22bb7a7258b170cd90c700363e18bce6170f6..89a84a5a0954c6d5bcd5d27f72f9d8dfe717a5e4 100644 --- a/imperative/python/megengine/core/tensor/dtype.py +++ b/imperative/python/megengine/core/tensor/dtype.py @@ -62,6 +62,21 @@ def get_zero_point(dtype): return metadata["zero_point"] +def is_equal(dt0, dt1): + def _get_zero_point(dtype): + assert is_quantize(dtype) + metadata = dtype.metadata["mgb_dtype"] + return metadata.get("zero_point") + + if is_quantize(dt0) and is_quantize(dt1): + return get_scale(dt0) == get_scale(dt1) and _get_zero_point( + dt0 + ) == _get_zero_point(dt1) + if not (is_quantize(dt0) or is_quantize(dt1)): + return dt0 == dt1 + return False + + def _check_zero_point(zp: int, dtype_str: str): qmin = _metadata_dict[dtype_str].qmin qmax = _metadata_dict[dtype_str].qmax diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 400e152335edc2e9c273ec8539c4bbea5c1e95cc..adb1bd1a9ca6befe7dff06a207623ec7a63b452e 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -14,6 +14,7 @@ import numpy as np from ..ops import builtin from ..ops.special import Const from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply +from .dtype import is_equal def dtype_promotion(inputs): @@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None): def astype(x, dtype): dtype = np.dtype(dtype) - if x.dtype != dtype: + if not is_equal(x.dtype, dtype): (x,) = apply(builtin.TypeCvt(param=dtype), x) return x diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 2f1a590db128a838e087dcf02e6d41ed2d5be3b0..8dff1633050807c1942e23e6c74f5432884517ca 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -8,6 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np +from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.tensor_wrapper import TensorWrapper @@ -71,3 +72,17 @@ def test_transpose(): x = np.random.rand(2, 5).astype("float32") xx = TensorWrapper(x) np.testing.assert_almost_equal(xx.T.numpy(), x.T) + + +def test_as_type(): + x = TensorWrapper([1, 2, 3], dtype=np.float32) + y = x.astype(qint8(0.1)) + np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) + z = y.astype(qint8(0.2)) + np.testing.assert_almost_equal(get_scale(z.dtype), 0.2) + a = z.astype(quint8(0.3, 127)) + np.testing.assert_almost_equal(get_scale(a.dtype), 0.3) + np.testing.assert_equal(get_zero_point(a.dtype), 127) + b = a.astype(quint8(0.3, 128)) + np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) + np.testing.assert_equal(get_zero_point(b.dtype), 128)