Commit f63bca03 authored by Megvii Engine Team

test(mge/dtype): add quant and lowbit dtype test

GitOrigin-RevId: 97ca1a393a7dc21e2c65d2cb9c38f7cb6b8581e0
Parent b80fade3
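The diff below adds three test files: bfloat16 dtype tests, low-bit (intb1/intb2/intb4) dtype tests, and quantized (qint/quint) dtype tests. As background for the quantized cases, the affine convention those tests exercise can be sketched in plain NumPy; the helper names below are hypothetical and this is only an illustration, not MegEngine's convert_to_quint8/convert_from_quint8 implementation.

import numpy as np


def quantize_quint8(x, scale, zero_point):
    # q = clip(round(x / scale) + zero_point, 0, 255), stored as uint8
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)


def dequantize_quint8(q, scale, zero_point):
    # x is recovered approximately as (q - zero_point) * scale
    return (q.astype(np.float32) - zero_point) * scale


x = np.array([0.05, -0.12, 0.3], dtype=np.float32)
q = quantize_quint8(x, scale=0.01, zero_point=135)
print(dequantize_quint8(q, scale=0.01, zero_point=135))  # recovers x up to one scale step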
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

import numpy as np

from megengine.core.tensor.dtype import bfloat16
from megengine.core.tensor.raw_tensor import as_raw_tensor


def test_define():
    np.testing.assert_allclose(
        np.array([0.5, 0.13425, 3.4687, -1.34976, -9.34673, 0.0], dtype=bfloat16),
        np.array([0.5, 0.133789, 3.46875, -1.351562, -9.375, 0.0], dtype=np.float32),
        atol=1e-6,
    )


def test_cast():
    dtypes = [np.int8, np.int16, np.int32, np.float32, np.float64]
    fp32_values = [0.34985, 10.943, -0.5, -19.3, 21.49673]
    bf16_values = [0.349609, 10.9375, -0.5, -19.25, 21.5]
    int_values = [34, 10, -5, -19, 21]
    for dtype in dtypes:
        np.testing.assert_allclose(
            np.array(fp32_values, dtype=bfloat16).astype(dtype),
            np.array(bf16_values, dtype=dtype),
            atol=1e-6,
        )
        np.testing.assert_allclose(
            np.array(int_values, dtype=dtype),
            np.array(int_values, dtype=bfloat16).astype(dtype),
            atol=1e-6,
        )


def test_shared_nd():
    data = np.array([-3.4, 1.394683, 2.323497, -7.439948, -5.2397], dtype=bfloat16)
    snd = as_raw_tensor(data, dtype=bfloat16, device="xpux")
    assert snd.numpy().dtype == bfloat16
    np.testing.assert_allclose(
        snd.numpy(), [-3.40625, 1.398438, 2.328125, -7.4375, -5.25], atol=1e-6
    )
    data = np.array([-9.34964, -8.342, 9.4385, 0.18746, 1.48], dtype=bfloat16)
    snd = as_raw_tensor(data, dtype=bfloat16, device="xpux")
    np.testing.assert_allclose(
        snd.numpy(), [-9.375, -8.3125, 9.4375, 0.1875, 1.476562], atol=1e-6
    )


def test_pickle():
    x = np.ascontiguousarray(np.random.rand(8192), dtype=bfloat16)
    pkl = pickle.dumps(x, pickle.HIGHEST_PROTOCOL)
    y = pickle.loads(pkl)
    assert x.dtype is y.dtype
    np.testing.assert_allclose(x.astype(np.float32), y.astype(np.float32), atol=1e-6)
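The rounded constants asserted above (for example 0.13425 becoming 0.133789) follow from bfloat16 keeping 8 exponent bits but only 7 explicit mantissa bits. For reference, round-to-nearest-even bfloat16 rounding can be emulated in plain NumPy with the usual bit trick; round_to_bfloat16 below is a hypothetical helper assuming IEEE-754 float32 layout and ignoring NaN handling, not MegEngine's dtype implementation.

import numpy as np


def round_to_bfloat16(x):
    # View the float32 bit pattern, add the round-to-nearest-even bias,
    # then zero out the 16 low mantissa bits.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))
    return ((bits + bias) & np.uint32(0xFFFF0000)).view(np.float32)


print(round_to_bfloat16([0.5, 0.13425, 3.4687]))  # approximately [0.5, 0.133789, 3.46875]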
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

import numpy as np
import pytest

from megengine.core.tensor.dtype import intb1, intb2, intb4
from megengine.core.tensor.raw_tensor import as_raw_tensor


def bit_define_test(bit, low_bit_type):
    max_value = (1 << bit) - 1
    min_value = 1 - (1 << bit)
    a = np.array([i for i in range(min_value, max_value + 2, 2)], dtype=low_bit_type)
    for i in range(max_value + 1):
        np.testing.assert_equal(a[i], i * 2 - max_value)
        np.testing.assert_equal(str(a[i]), str(i * 2 - max_value))
    with pytest.raises(ValueError):
        np.arange(min_value, max_value, dtype=low_bit_type)
    with pytest.raises(ValueError):
        np.arange(min_value - 2, max_value + 4, 2, dtype=low_bit_type)
    np.testing.assert_allclose(
        np.arange(min_value, 12, 2, dtype=low_bit_type),
        (np.arange((13 - min_value) // 2, dtype=np.int8) % (max_value + 1)) * 2
        - max_value,
    )
    np.testing.assert_allclose(
        np.arange(max_value, max_value - 20, -2, dtype=low_bit_type),
        (np.arange(max_value, max_value - 10, -1, dtype=np.int8) % (max_value + 1)) * 2
        - max_value,
    )


def test_define():
    bit_define_test(1, intb1)
    bit_define_test(2, intb2)
    bit_define_test(4, intb4)


def _bit_cast_test(bit, low_bit_type):
    dtypes = [np.int8, np.int16, np.int32, np.float32, np.float64]
    max_value = (1 << bit) - 1
    min_value = 1 - (1 << bit)
    for dtype in dtypes:
        np.testing.assert_allclose(
            np.arange(min_value, max_value + 2, 2, dtype=low_bit_type).astype(dtype),
            np.arange(min_value, max_value + 2, 2, dtype=dtype),
        )
    with pytest.raises(ValueError):
        np.array([2, 1, -1], dtype=int).astype(low_bit_type)
    with pytest.raises(ValueError):
        np.array([min_value - 2, 1, max_value + 2], dtype=int).astype(low_bit_type)


def test_cast():
    _bit_cast_test(1, intb1)
    _bit_cast_test(2, intb2)
    _bit_cast_test(4, intb4)


def _shared_nd_test(bit, low_bit_type):
    max_value = (1 << bit) - 1
    min_value = 1 - (1 << bit)
    data = np.arange(min_value, max_value + 2, 2, dtype=low_bit_type)
    snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux")
    np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 2))
    data = np.arange(min_value, max_value + 2, 4, dtype=low_bit_type)
    snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux")
    np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 4))


def test_shared_nd():
    _shared_nd_test(1, intb1)
    _shared_nd_test(2, intb2)
    _shared_nd_test(4, intb4)


def test_pickle():
    x = np.ascontiguousarray(np.random.randint(2, size=8192) * 2 - 1, dtype=intb1)
    pkl = pickle.dumps(x, pickle.HIGHEST_PROTOCOL)
    y = pickle.loads(pkl)
    assert x.dtype is y.dtype
    np.testing.assert_allclose(x.astype(np.float32), y.astype(np.float32))
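As the helpers above encode, an N-bit intb type covers the 2**N odd integers from 1 - 2**N up to 2**N - 1 in steps of two. A small standalone restatement of that range logic, mirroring the min/max expressions used in the tests (intb_range is a hypothetical helper, not part of MegEngine):

def intb_range(bit):
    # Representable values of the N-bit low-bit type: odd integers
    # from 1 - 2**bit to 2**bit - 1, i.e. 2**bit values in total.
    max_value = (1 << bit) - 1
    min_value = 1 - (1 << bit)
    return list(range(min_value, max_value + 2, 2))


assert intb_range(1) == [-1, 1]
assert intb_range(2) == [-3, -1, 1, 3]
assert len(intb_range(4)) == 16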
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial

import numpy as np
import pytest

import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.dtype import (
    _metadata_dict,
    convert_from_qint4,
    convert_from_qint8,
    convert_from_quint4,
    convert_from_quint8,
    convert_to_qint4,
    convert_to_qint8,
    convert_to_quint4,
    convert_to_quint8,
    get_scale,
    get_zero_point,
    is_quantize,
    qint4,
    qint8,
    quint4,
    quint8,
)
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.distributed.helper import get_device_count_by_fork


def test_dtype_quint8():
    with pytest.raises(ValueError):
        blah = quint8(0.05, 0.233)
    with pytest.raises(ValueError):
        blah = quint8(0.02, 777)
    with pytest.raises(ValueError):
        blah = quint8(0.02, -1)
    dt = quint8(0.01, 135)
    assert isinstance(dt, np.dtype)
    assert "mgb_dtype" in dt.metadata
    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
    np.testing.assert_equal(dt.metadata["mgb_dtype"]["zero_point"], 135)
    assert is_quantize(dt)
    np.testing.assert_allclose(get_scale(dt), 0.01)
    np.testing.assert_equal(get_zero_point(dt), 135)


def test_dtype_qint8():
    dt = qint8(0.01)
    assert isinstance(dt, np.dtype)
    assert "mgb_dtype" in dt.metadata
    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
    assert is_quantize(dt)
    np.testing.assert_allclose(get_scale(dt), 0.01)


def _get_compiled_result(inp, dtype, shape, device, calc_func=None):
    graph = G.Graph()
    # graph.options.async_exec_level = 0b100
    inp_node = G.InputNode(device=device, dtype=dtype, shape=shape, graph=graph)
    temp_rst = calc_func(inp_node.outputs[0])
    oup_node = G.OutputNode(temp_rst)
    func = graph.compile(oup_node.outputs[0])
    inp_node.set_value(as_raw_tensor(inp, dtype=dtype, device=device)._dev_tensor())
    func.execute()
    return oup_node.get_value().numpy()


def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
    metadata = _metadata_dict[dtype_str]
    assert "mgb_dtype" in oup.dtype.metadata
    assert is_quantize(oup.dtype)
    np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name)
    np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype))
    if is_unsigned:
        np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype))


def test_dtype_int8_ffi_handle():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    def identity(x):
        return x

    dtype = quint8(0.01, 127)
    inp = convert_to_quint8(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "quint8")
    np.testing.assert_allclose(convert_from_quint8(oup), convert_from_quint8(inp))
    dtype = qint8(0.01)
    inp = convert_to_qint8(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "qint8", is_unsigned=False)
    np.testing.assert_allclose(convert_from_qint8(oup), convert_from_qint8(inp))


def test_quint8_typecvt():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    def typecvt(x, dt=None):
        (y,) = apply(ops.TypeCvt(param=dt), x)
        return y

    # convert to quint8
    dtype = quint8(0.01, 135)
    oup = _get_compiled_result(
        data, np.float32, shape, device, calc_func=partial(typecvt, dt=dtype)
    )
    _check_result_attr(oup, dtype, "quint8")
    np.testing.assert_equal(oup, convert_to_quint8(data, dtype))
    # convert from quint8 to float32
    oup_float = _get_compiled_result(
        oup, dtype, shape, device, calc_func=partial(typecvt, dt=np.float32)
    )
    assert oup_float.dtype == np.float32
    np.testing.assert_equal(
        oup_float, convert_from_quint8(convert_to_quint8(data, dtype))
    )


def test_dtype_quint4():
    with pytest.raises(ValueError):
        blah = quint4(0.05, 0.233)
    with pytest.raises(ValueError):
        blah = quint4(0.02, 18)
    with pytest.raises(ValueError):
        blah = quint4(0.02, -1)
    dt = quint4(0.01, 8)
    assert isinstance(dt, np.dtype)
    assert "mgb_dtype" in dt.metadata
    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
    np.testing.assert_equal(dt.metadata["mgb_dtype"]["zero_point"], 8)
    assert is_quantize(dt)
    np.testing.assert_allclose(get_scale(dt), 0.01)
    np.testing.assert_equal(get_zero_point(dt), 8)


def test_dtype_qint4():
    dt = qint4(0.01)
    assert isinstance(dt, np.dtype)
    assert "mgb_dtype" in dt.metadata
    np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01)
    assert is_quantize(dt)
    np.testing.assert_allclose(get_scale(dt), 0.01)


def test_dtype_int4_ffi_handle():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1
    print(data)

    def identity(x):
        return x

    dtype = quint4(0.01, 7)
    inp = convert_to_quint4(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "quint4")
    np.testing.assert_allclose(convert_from_quint4(oup), convert_from_quint4(inp))
    dtype = qint4(0.01)
    inp = convert_to_qint4(data, dtype)
    oup = _get_compiled_result(inp, dtype, shape, device, calc_func=identity)
    _check_result_attr(oup, dtype, "qint4", is_unsigned=False)
    np.testing.assert_allclose(convert_from_qint4(oup), convert_from_qint4(inp))


@pytest.mark.skipif(
    get_device_count_by_fork("gpu") != 0,
    reason="TypeCvt to quint4 is not supported on GPU",
)
def test_quint4_typecvt():
    device = "xpux"
    shape = (3, 3, 3)
    data = np.random.random(shape).astype(np.float32) * 5 - 1

    def typecvt(x, dt=None):
        (y,) = apply(ops.TypeCvt(param=dt), x)
        return y

    # convert to quint4
    dtype = quint4(0.01, 5)
    oup = _get_compiled_result(
        data, np.float32, shape, device, calc_func=partial(typecvt, dt=dtype)
    )
    _check_result_attr(oup, dtype, "quint4")
    np.testing.assert_equal(oup, convert_to_quint4(data, dtype))
    # convert from quint4 to float32
    oup_float = _get_compiled_result(
        oup, dtype, shape, device, calc_func=partial(typecvt, dt=np.float32)
    )
    assert oup_float.dtype == np.float32
    np.testing.assert_equal(
        oup_float, convert_from_quint4(convert_to_quint4(data, dtype))
    )
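For reference, the zero-point validation that the quint8/quint4 constructors are expected to perform, rejecting non-integer and out-of-range values as asserted above, can be summarized as follows; check_zero_point is a hypothetical standalone sketch, not MegEngine's implementation.

def check_zero_point(zero_point, bits):
    # quint dtypes carry zero_point as an unsigned integer in [0, 2**bits - 1];
    # the tests above expect ValueError for 0.233, 777 and -1 (quint8) and 18 (quint4).
    if not isinstance(zero_point, int):
        raise ValueError("zero_point must be an integer")
    if not 0 <= zero_point <= (1 << bits) - 1:
        raise ValueError("zero_point out of range for %d-bit unsigned storage" % bits)


check_zero_point(135, 8)  # valid, as in quint8(0.01, 135)
check_zero_point(8, 4)    # valid, as in quint4(0.01, 8)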