Commit 82b0f677 authored by Megvii Engine Team

fix(mge/core): fix dtype promotion issue for quantized dtype

GitOrigin-RevId: 9d09c8fa6f355b3c84ba4b9ec7872c1a21393726
Parent 8118a594
...
@@ -14,7 +14,7 @@ import numpy as np
 from ..ops import builtin
 from ..ops.special import Const
 from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
-from .dtype import is_equal
+from .dtype import is_equal, is_quantize


 def dtype_promotion(inputs):
@@ -122,7 +122,7 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
     tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
     assert len(tensors) > 0
     if isinstance(v, (TensorWrapperBase, TensorBase)):
-        v = astype(v, dtype)
+        v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
     else:
         (v,) = Const(v, dtype=dtype, device=device)(*tensors)
     return v
...
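The change above makes `convert_single_value` keep a quantized operand's own dtype instead of casting it to the promoted dtype, which would silently drop the scale/zero_point metadata. A minimal sketch of the rule, assuming MegEngine at this revision (the `promoted` name is illustrative, not from the diff):

    import numpy as np
    from megengine import tensor
    from megengine.core.tensor import dtype
    from megengine.core.tensor.dtype import is_quantize

    x = tensor(np.ones((2, 3), dtype="float32"), dtype=dtype.qint8(0.5))
    promoted = np.dtype("float32")  # stand-in for what dtype_promotion would pick

    # Old behavior: astype(x, promoted) re-cast the tensor and lost the qint8 scale.
    # New behavior: a quantized input keeps its own dtype.
    target = x.dtype if is_quantize(x.dtype) else promoted
    assert is_quantize(target)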
...
@@ -12,7 +12,6 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 import numpy as np

-from ..core.tensor.dtype import is_quantize
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
@@ -529,11 +528,7 @@ class Module(metaclass=ABCMeta):
                 ), "param `{}` shape mismatch, should be {}, get {}".format(
                     k, var.shape, to_be_load.shape
                 )
-                # For quantized dtype, the initialized dtype
-                # scale/zero_points maybe invalid, use pretrained dtype instead.
-                if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
-                    var = var.astype(to_be_load.dtype)
-                var._reset(to_be_load)
+                var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
                 loaded.append(k)

         return set(loaded), set(skipped)
...
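The Module change above removes the quantized-dtype special case: instead of casting the freshly initialized parameter to the pretrained dtype and then resetting it in place, the loader rebuilds the wrapper directly from the checkpoint value, taking the pretrained dtype (scale/zero_point included) and the parameter's original device in one step. A hedged sketch of that path for a single state-dict entry; `var` and `to_be_load` follow the names in the hunk, and the setup values are illustrative:

    import numpy as np
    from megengine import Parameter
    from megengine.core.tensor import dtype
    from megengine.core.tensor.dtype import is_quantize

    var = Parameter(np.zeros((4,), dtype="float32"))  # freshly initialized module param
    to_be_load = Parameter(np.ones((4,), dtype="float32"), dtype=dtype.qint8(0.5))  # pretrained

    # Rebuild the wrapper from the pretrained value; dtype and device are fixed
    # up together, so no is_quantize branch is needed anymore.
    var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
    assert is_quantize(var.dtype)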
...
@@ -10,6 +10,7 @@ import numpy as np
 import megengine.functional as F
 from megengine import tensor
+from megengine.core.tensor import dtype
 from megengine.functional.elemwise import _elwise
@@ -150,3 +151,18 @@ def test_logical_oprs():
     np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy())
     np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy())
     np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
+
+
+def test_qadd():
+    inp_scale = 0.5
+    outp_scale = 0.2
+    x = np.arange(6).reshape(2, 3).astype("float32")
+    y = np.arange(6).reshape(2, 3).astype("float32")
+    x = tensor(x, dtype=dtype.qint8(inp_scale))
+    y = tensor(y, dtype=dtype.qint8(inp_scale))
+    result_mge = F.elemwise._elemwise_multi_type(
+        x, y, mode="QADD", dtype=dtype.qint8(outp_scale)
+    )
+    result_mge = result_mge.astype("float32").numpy()
+    result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
+    np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)
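The new test exercises the quantized add end to end, and it relies on the arithmetic being exact here: with inp_scale = 0.5, every input value 0..5 is a multiple of the scale, so quantization is lossless and the dequantized QADD output can be compared against plain float32 addition. A pure-NumPy sketch of that round trip (the helper names are illustrative, not MegEngine API):

    import numpy as np

    def quantize_qint8(x, scale):
        # qint8 stores round(x / scale), clamped to the int8 range
        return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

    def dequantize(q, scale):
        # astype("float32") on a qint8 tensor corresponds to q * scale
        return q.astype(np.float32) * scale

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    q = quantize_qint8(x, 0.5)
    assert np.allclose(dequantize(q, 0.5), x)  # exact: inputs are multiples of 0.5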