提交 be511a56 编写于 作者: M Megvii Engine Team

fix(mge/imperative): fix tensor astype failed for quantized type

GitOrigin-RevId: 383458acbf18fa956ca1ccaa376255ff1b06735a
上级 b309890c
...@@ -62,6 +62,21 @@ def get_zero_point(dtype): ...@@ -62,6 +62,21 @@ def get_zero_point(dtype):
return metadata["zero_point"] return metadata["zero_point"]
def is_equal(dt0, dt1):
def _get_zero_point(dtype):
assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"]
return metadata.get("zero_point")
if is_quantize(dt0) and is_quantize(dt1):
return get_scale(dt0) == get_scale(dt1) and _get_zero_point(
dt0
) == _get_zero_point(dt1)
if not (is_quantize(dt0) or is_quantize(dt1)):
return dt0 == dt1
return False
def _check_zero_point(zp: int, dtype_str: str): def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax qmax = _metadata_dict[dtype_str].qmax
......
...@@ -14,6 +14,7 @@ import numpy as np ...@@ -14,6 +14,7 @@ import numpy as np
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .dtype import is_equal
def dtype_promotion(inputs): def dtype_promotion(inputs):
...@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None): ...@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype): def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if x.dtype != dtype: if not is_equal(x.dtype, dtype):
(x,) = apply(builtin.TypeCvt(param=dtype), x) (x,) = apply(builtin.TypeCvt(param=dtype), x)
return x return x
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.core.tensor.tensor_wrapper import TensorWrapper from megengine.core.tensor.tensor_wrapper import TensorWrapper
...@@ -71,3 +72,17 @@ def test_transpose(): ...@@ -71,3 +72,17 @@ def test_transpose():
x = np.random.rand(2, 5).astype("float32") x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x) xx = TensorWrapper(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T) np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type():
x = TensorWrapper([1, 2, 3], dtype=np.float32)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))
np.testing.assert_almost_equal(get_scale(z.dtype), 0.2)
a = z.astype(quint8(0.3, 127))
np.testing.assert_almost_equal(get_scale(a.dtype), 0.3)
np.testing.assert_equal(get_zero_point(a.dtype), 127)
b = a.astype(quint8(0.3, 128))
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3)
np.testing.assert_equal(get_zero_point(b.dtype), 128)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册