提交 caf77d00 编写于 作者: M Megvii Engine Team

feat(mge/module): add quantize dtype load support for module load_state_dict

GitOrigin-RevId: 0a94cb6b17005dd5e6a81c6a4b8c51c7044d2751
上级 dedb7a3f
...@@ -235,6 +235,14 @@ class Tensor: ...@@ -235,6 +235,14 @@ class Tensor:
return self.__val.dtype return self.__val.dtype
return self._symvar.dtype return self._symvar.dtype
def set_dtype(self, dtype: str = None):
    r"""Convert this tensor to *dtype* in place.

    Handles both backing representations: an eager tensor (``__val``
    set) has its shared storage rebuilt on the same device with the
    converted data, while a purely symbolic tensor (``__sym`` set) is
    replaced by the converted symbolic var.
    """
    if self.__val is not None:
        # Rebuild the shared value storage with the data converted
        # to the requested dtype.
        converted = self.astype(dtype).numpy()
        self.__val = mgb.make_shared(self.device, value=converted)
        return
    if self.__sym is not None:
        self.__sym = self.__sym.astype(dtype)
@property @property
def _comp_node(self): def _comp_node(self):
if self.__val is not None: if self.__val is not None:
......
...@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
import numpy as np import numpy as np
from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger from ..logger import get_logger
...@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): ...@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
), "param `{}` shape mismatch, should be {}, get {}".format( ), "param `{}` shape mismatch, should be {}, get {}".format(
k, var.shape, to_be_load.shape k, var.shape, to_be_load.shape
) )
# For quantized dtype, the initialized dtype
# scale/zero_points maybe invalid, use pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var.set_dtype(to_be_load.dtype)
var.set_value(to_be_load) var.set_value(to_be_load)
loaded.append(k) loaded.append(k)
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
import pytest import pytest
import megengine as mge import megengine as mge
import megengine._internal as mgb
def test_wrong_dtype(): def test_wrong_dtype():
...@@ -26,3 +27,48 @@ def test_tensor_routine(): ...@@ -26,3 +27,48 @@ def test_tensor_routine():
mge.tensor([1]) mge.tensor([1])
mge.tensor(1.5) mge.tensor(1.5)
def test_tensor_set_dtype():
    """set_dtype should requantize the stored values so the real
    (dequantized) value is preserved when the quantized dtype changes."""

    def check_dtype_value(tensor, dtype_scale, value):
        # For quantized dtypes, first verify the scale carried by the dtype.
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        # Compare the first element against the expected stored value.
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            raise AssertionError(
                # BUGFIX: arguments were swapped — "expect" must report the
                # expected value and "got" the actual tensor value.
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    # Parameter: real value 1.0 quantized with scale 0.1 -> stored int 10.
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    # Requantize qint8(scale=1) -> qint8(scale=0.3): real 1.0 -> stored 3.
    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Same two cases for Buffer.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # Symbolic tensors (results of expressions) also support set_dtype.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)

    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)

    # Converting back to float32 dequantizes the stored values.
    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)
...@@ -14,8 +14,10 @@ import pytest ...@@ -14,8 +14,10 @@ import pytest
from helpers import MLP from helpers import MLP
import megengine as mge import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -347,3 +349,38 @@ def test_dump_model(): ...@@ -347,3 +349,38 @@ def test_dump_model():
pred = mlp(data) pred = mlp(data)
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
mge.dump(pred, f.name) mge.dump(pred, f.name)
def test_load_quantized():
    """Loading a state dict with quantized weights must restore the
    pretrained quantized dtypes, so predictions match the saved model."""
    # Build quantized input data with scale 0.1.
    inp = tensor(np.random.random((2, 28)), dtype="float32")
    inp = inp.astype(mgb.dtype.qint8(0.1))

    # Quantize the MLP and give its weights distinctive scales.
    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    mlp.dense0.weight = Parameter(
        mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy()
    )
    mlp.dense1.weight = Parameter(
        mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy()
    )
    mlp.eval()
    pred_before = mlp(inp)

    with BytesIO() as buf:
        # Round-trip the state dict through an in-memory file.
        mge.save(mlp.state_dict(), buf)
        buf.seek(0)
        snapshot = mge.load(buf)

        # Clobber the weights with different scales, then restore.
        mlp.dense0.weight = Parameter(
            mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy()
        )
        mlp.dense1.weight = Parameter(
            mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy()
        )
        mlp.load_state_dict(snapshot)
        pred_after = mlp(inp)

        assertTensorClose(
            pred_before.astype("float32").numpy(),
            pred_after.astype("float32").numpy(),
            max_err=5e-6,
        )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册