提交 be511a56 编写于 作者: M Megvii Engine Team

fix(mge/imperative): fix tensor astype failed for quantized type

GitOrigin-RevId: 383458acbf18fa956ca1ccaa376255ff1b06735a
上级 b309890c
......@@ -62,6 +62,21 @@ def get_zero_point(dtype):
return metadata["zero_point"]
def is_equal(dt0, dt1):
def _get_zero_point(dtype):
assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"]
return metadata.get("zero_point")
if is_quantize(dt0) and is_quantize(dt1):
return get_scale(dt0) == get_scale(dt1) and _get_zero_point(
dt0
) == _get_zero_point(dt1)
if not (is_quantize(dt0) or is_quantize(dt1)):
return dt0 == dt1
return False
def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
......
......@@ -14,6 +14,7 @@ import numpy as np
from ..ops import builtin
from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .dtype import is_equal
def dtype_promotion(inputs):
......@@ -112,7 +113,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
dtype = np.dtype(dtype)
if x.dtype != dtype:
if not is_equal(x.dtype, dtype):
(x,) = apply(builtin.TypeCvt(param=dtype), x)
return x
......
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.core.tensor.tensor_wrapper import TensorWrapper
......@@ -71,3 +72,17 @@ def test_transpose():
x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type():
x = TensorWrapper([1, 2, 3], dtype=np.float32)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))
np.testing.assert_almost_equal(get_scale(z.dtype), 0.2)
a = z.astype(quint8(0.3, 127))
np.testing.assert_almost_equal(get_scale(a.dtype), 0.3)
np.testing.assert_equal(get_zero_point(a.dtype), 127)
b = a.astype(quint8(0.3, 128))
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3)
np.testing.assert_equal(get_zero_point(b.dtype), 128)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册