From be511a56f99c45c75cdb05958c8a287453673a9e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 3 Nov 2020 11:52:12 +0800 Subject: [PATCH] fix(mge/imperative): fix tensor astype failed for quantized type GitOrigin-RevId: 383458acbf18fa956ca1ccaa376255ff1b06735a --- imperative/python/megengine/core/tensor/dtype.py | 15 +++++++++++++++ imperative/python/megengine/core/tensor/utils.py | 3 ++- .../python/test/unit/core/test_tensor_wrapper.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py index 85c22bb7a..89a84a5a0 100644 --- a/imperative/python/megengine/core/tensor/dtype.py +++ b/imperative/python/megengine/core/tensor/dtype.py @@ -62,6 +62,21 @@ def get_zero_point(dtype): return metadata["zero_point"] +def is_equal(dt0, dt1): + def _get_zero_point(dtype): + assert is_quantize(dtype) + metadata = dtype.metadata["mgb_dtype"] + return metadata.get("zero_point") + + if is_quantize(dt0) and is_quantize(dt1): + return get_scale(dt0) == get_scale(dt1) and _get_zero_point( + dt0 + ) == _get_zero_point(dt1) + if not (is_quantize(dt0) or is_quantize(dt1)): + return dt0 == dt1 + return False + + def _check_zero_point(zp: int, dtype_str: str): qmin = _metadata_dict[dtype_str].qmin qmax = _metadata_dict[dtype_str].qmax diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 400e15233..adb1bd1a9 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -14,6 +14,7 @@ import numpy as np from ..ops import builtin from ..ops.special import Const from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply +from .dtype import is_equal def dtype_promotion(inputs): @@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None): def astype(x, dtype): dtype = np.dtype(dtype) - if x.dtype != dtype: + if not is_equal(x.dtype, dtype): (x,) = apply(builtin.TypeCvt(param=dtype), x) return x diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 2f1a590db..8dff16330 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -8,6 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np +from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.tensor_wrapper import TensorWrapper @@ -71,3 +72,17 @@ def test_transpose(): x = np.random.rand(2, 5).astype("float32") xx = TensorWrapper(x) np.testing.assert_almost_equal(xx.T.numpy(), x.T) + + +def test_as_type(): + x = TensorWrapper([1, 2, 3], dtype=np.float32) + y = x.astype(qint8(0.1)) + np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) + z = y.astype(qint8(0.2)) + np.testing.assert_almost_equal(get_scale(z.dtype), 0.2) + a = z.astype(quint8(0.3, 127)) + np.testing.assert_almost_equal(get_scale(a.dtype), 0.3) + np.testing.assert_equal(get_zero_point(a.dtype), 127) + b = a.astype(quint8(0.3, 128)) + np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) + np.testing.assert_equal(get_zero_point(b.dtype), 128) -- GitLab