From 82b0f67770cb43c73bac0b05dafa550ffaf60c71 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 16 Nov 2020 11:41:11 +0800
Subject: [PATCH] fix(mge/core): fix dtype promotion issue for quantized dtype

GitOrigin-RevId: 9d09c8fa6f355b3c84ba4b9ec7872c1a21393726
---
 imperative/python/megengine/core/tensor/utils.py  |  4 ++--
 imperative/python/megengine/module/module.py      |  7 +------
 .../python/test/unit/functional/test_elemwise.py  | 16 ++++++++++++++++
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index adb1bd1a9..82b5b5042 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -14,7 +14,7 @@ import numpy as np
 from ..ops import builtin
 from ..ops.special import Const
 from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
-from .dtype import is_equal
+from .dtype import is_equal, is_quantize
 
 
 def dtype_promotion(inputs):
@@ -122,7 +122,7 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
     tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
     assert len(tensors) > 0
     if isinstance(v, (TensorWrapperBase, TensorBase)):
-        v = astype(v, dtype)
+        v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
     else:
         (v,) = Const(v, dtype=dtype, device=device)(*tensors)
     return v
diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index 2af53c910..c23367297 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -12,7 +12,6 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 
 import numpy as np
 
-from ..core.tensor.dtype import is_quantize
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
@@ -529,11 +528,7 @@ class Module(metaclass=ABCMeta):
             ), "param `{}` shape mismatch, should be {}, get {}".format(
                 k, var.shape, to_be_load.shape
             )
-            # For quantized dtype, the initialized dtype
-            # scale/zero_points maybe invalid, use pretrained dtype instead.
-            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
-                var = var.astype(to_be_load.dtype)
-            var._reset(to_be_load)
+            var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
             loaded.append(k)
 
         return set(loaded), set(skipped)
diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py
index 3436f145b..a089d0066 100644
--- a/imperative/python/test/unit/functional/test_elemwise.py
+++ b/imperative/python/test/unit/functional/test_elemwise.py
@@ -10,6 +10,7 @@ import numpy as np
 
 import megengine.functional as F
 from megengine import tensor
+from megengine.core.tensor import dtype
 from megengine.functional.elemwise import _elwise
 
 
@@ -150,3 +151,18 @@ def test_logical_oprs():
     np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy())
     np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy())
     np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
+
+
+def test_qadd():
+    inp_scale = 0.5
+    outp_scale = 0.2
+    x = np.arange(6).reshape(2, 3).astype("float32")
+    y = np.arange(6).reshape(2, 3).astype("float32")
+    x = tensor(x, dtype=dtype.qint8(inp_scale))
+    y = tensor(y, dtype=dtype.qint8(inp_scale))
+    result_mge = F.elemwise._elemwise_multi_type(
+        x, y, mode="QADD", dtype=dtype.qint8(outp_scale)
+    )
+    result_mge = result_mge.astype("float32").numpy()
+    result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
+    np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)
-- 
GitLab
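
Editor's note (not part of the patch): below is a minimal sketch of the behavior the new test_qadd exercises, using only APIs that appear in the diff above (megengine.tensor, megengine.core.tensor.dtype.qint8, and F.elemwise._elemwise_multi_type). The input arrays and the 0.5/0.2 scales are illustrative values taken from the test; exact outputs depend on qint8 rounding and clamping, which is why the test compares with assert_almost_equal rather than exact equality.

    import numpy as np

    import megengine.functional as F
    from megengine import tensor
    from megengine.core.tensor import dtype

    # Two qint8 tensors sharing the same input scale, as in test_qadd.
    inp_scale = 0.5
    x = tensor(np.arange(6, dtype="float32").reshape(2, 3), dtype=dtype.qint8(inp_scale))
    y = tensor(np.ones((2, 3), dtype="float32"), dtype=dtype.qint8(inp_scale))

    # QADD dequantizes both inputs, adds them, and requantizes the sum to the
    # requested output dtype (qint8 with scale 0.2 here).
    out = F.elemwise._elemwise_multi_type(x, y, mode="QADD", dtype=dtype.qint8(0.2))

    # Casting back to float32 recovers the (approximate) real-valued sum.
    print(out.astype("float32").numpy())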