Commit caf77d00 authored by: Megvii Engine Team

feat(mge/module): add quantize dtype load support for module load_state_dict

GitOrigin-RevId: 0a94cb6b17005dd5e6a81c6a4b8c51c7044d2751
Parent: dedb7a3f
@@ -235,6 +235,14 @@ class Tensor:
            return self.__val.dtype
        return self._symvar.dtype

    def set_dtype(self, dtype: str = None):
        r"""Set the data type of the tensor. Existing values are
        converted via ``astype``, so the values they represent are kept.
        """
        if self.__val is not None:
            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
        elif self.__sym is not None:
            self.__sym = self.__sym.astype(dtype)

    @property
    def _comp_node(self):
        if self.__val is not None:
......
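
For intuition, here is a minimal numpy sketch of the re-quantization that set_dtype performs through astype; requantize is a hypothetical helper, and the actual rounding mode is up to the backend:

import numpy as np

def requantize(int_values, old_scale, new_scale):
    # Dequantize to real values, then quantize to the new scale.
    real = int_values.astype(np.float32) * old_scale
    return np.round(real / new_scale).astype(np.int8)

# A payload of ones at scale 1.0 becomes 3 at scale 0.3
# (1.0 / 0.3 ~= 3.33 -> 3), matching check_dtype_value(t, 0.3, 3)
# in the test below.
print(requantize(np.ones((3, 4), dtype=np.int8), 1.0, 0.3))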
@@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger
@@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta):
            ), "param `{}` shape mismatch, should be {}, got {}".format(
                k, var.shape, to_be_load.shape
            )
            # For a quantized dtype, the scale/zero_point baked into the
            # initialized dtype may be invalid; use the pretrained dtype
            # from the state dict instead.
            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
                var.set_dtype(to_be_load.dtype)
            var.set_value(to_be_load)
            loaded.append(k)
......
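
The guard matters because a quantized payload is only meaningful together with its scale. A hypothetical numpy illustration, using scales that appear in the tests below:

import numpy as np

# The same int8 payload means different real values under different scales,
# so copying checkpoint values without adopting the checkpoint dtype would
# silently change the weights.
payload = np.full((2, 2), 7, dtype=np.int8)  # values stored in the checkpoint
print(payload * 0.001)  # 0.007: meaning under the pretrained scale
print(payload * 0.2)    # 1.4: meaning if a stale initialized scale were kept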
@@ -10,6 +10,7 @@ import numpy as np
import pytest

import megengine as mge
import megengine._internal as mgb


def test_wrong_dtype():
@@ -26,3 +27,48 @@ def test_tensor_routine():
    mge.tensor([1])
    mge.tensor(1.5)
def test_tensor_set_dtype():
    def check_dtype_value(tensor, dtype_scale, value):
        if mgb.dtype.is_quantize(tensor.dtype):
            if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5:
                raise AssertionError(
                    "compare scale failed expect {} got {}".format(
                        dtype_scale, mgb.dtype.get_scale(tensor.dtype)
                    )
                )
        if np.abs(tensor.numpy()[0][0] - value) > 1e-5:
            raise AssertionError(
                "compare value failed expect {} got {}".format(
                    value, tensor.numpy()[0][0]
                )
            )

    # set_dtype on a float32 Parameter quantizes the stored values.
    t = mge.Parameter(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    # set_dtype on an already-quantized Parameter re-quantizes to the new scale.
    t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    t.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(t, 0.1, 10)

    t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
    t.set_dtype(mgb.dtype.qint8(0.3))
    check_dtype_value(t, 0.3, 3)

    # set_dtype also works on symbolic tensors produced by expressions.
    t = mge.Buffer(np.ones((3, 4), dtype="float32"))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.2))
    check_dtype_value(s, 0.2, 10)

    # Here s represents 1.8 (payload 6 at scale 0.3), so re-quantizing
    # to scale 0.1 yields the integer 18.
    t.set_dtype(mgb.dtype.qint8(0.3))
    s = t + 1
    s.set_dtype(mgb.dtype.qint8(0.1))
    check_dtype_value(s, 0.1, 18)

    s.set_dtype("float32")
    check_dtype_value(s, 0, 1.8)
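
As a sanity check on the last two expectations, the arithmetic works out as follows, assuming the addend is quantized to t's scale before the add:

import numpy as np

t_int = np.round(1.0 / 0.3)       # 3: payload of t after set_dtype(qint8(0.3))
one_int = np.round(1.0 / 0.3)     # 3: the addend 1 quantized to the same scale
real = (t_int + one_int) * 0.3    # ~1.8: value represented by s
print(np.round(real / 0.1))       # 18.0, matching check_dtype_value(s, 0.1, 18)
print(real)                       # ~1.8, matching the float32 check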
@@ -14,8 +14,10 @@ import pytest
from helpers import MLP

import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose
@@ -347,3 +349,38 @@ def test_dump_model():
    pred = mlp(data)

    with tempfile.NamedTemporaryFile() as f:
        mge.dump(pred, f.name)
def test_load_quantized():
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape), dtype="float32")
    data = data.astype(mgb.dtype.qint8(0.1))
    mlp = MLP()
    quantize_qat(mlp)
    quantize(mlp)
    mlp.dense0.weight = Parameter(
        mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy()
    )
    mlp.dense1.weight = Parameter(
        mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy()
    )
    mlp.eval()
    pred0 = mlp(data)

    with BytesIO() as fout:
        mge.save(mlp.state_dict(), fout)
        fout.seek(0)
        checkpoint = mge.load(fout)
        # Change the weights' quantization scales; load_state_dict should
        # restore the pretrained dtype (scale) from the checkpoint.
        mlp.dense0.weight = Parameter(
            mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy()
        )
        mlp.dense1.weight = Parameter(
            mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy()
        )
        mlp.load_state_dict(checkpoint)
        pred1 = mlp(data)

        assertTensorClose(
            pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
        )