From caf77d000edde883878d7968b298f9d2df160a3d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 2 Jul 2020 17:55:58 +0800 Subject: [PATCH] feat(mge/module): add quantize dtype load support for module load_state_dict GitOrigin-RevId: 0a94cb6b17005dd5e6a81c6a4b8c51c7044d2751 --- python_module/megengine/core/tensor.py | 8 ++++ python_module/megengine/module/module.py | 5 ++ python_module/test/unit/core/test_tensor.py | 46 +++++++++++++++++++ python_module/test/unit/module/test_module.py | 37 +++++++++++++++ 4 files changed, 96 insertions(+) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index 8aefcb31d..b8d6e1c95 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -235,6 +235,14 @@ class Tensor: return self.__val.dtype return self._symvar.dtype + def set_dtype(self, dtype: str = None): + r"""Set the data type of the tensor. + """ + if self.__val is not None: + self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) + elif self.__sym is not None: + self.__sym = self.__sym.astype(dtype) + @property def _comp_node(self): if self.__val is not None: diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index b55bdd894..af63cdcff 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union import numpy as np +from .._internal.dtype import is_quantize from ..core import Buffer, Parameter, Tensor from ..logger import get_logger @@ -460,6 +461,10 @@ class Module(metaclass=ABCMeta): ), "param `{}` shape mismatch, should be {}, get {}".format( k, var.shape, to_be_load.shape ) + # For quantized dtype, the initialized dtype + # scale/zero_points may be invalid, use pretrained dtype instead. 
+ if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): + var.set_dtype(to_be_load.dtype) var.set_value(to_be_load) loaded.append(k) diff --git a/python_module/test/unit/core/test_tensor.py b/python_module/test/unit/core/test_tensor.py index 3c8a47b01..5f8770760 100644 --- a/python_module/test/unit/core/test_tensor.py +++ b/python_module/test/unit/core/test_tensor.py @@ -10,6 +10,7 @@ import numpy as np import pytest import megengine as mge +import megengine._internal as mgb def test_wrong_dtype(): @@ -26,3 +27,48 @@ def test_tensor_routine(): mge.tensor([1]) mge.tensor(1.5) + + +def test_tensor_set_dtype(): + def check_dtype_value(tensor, dtype_scale, value): + if mgb.dtype.is_quantize(tensor.dtype): + if np.abs(mgb.dtype.get_scale(tensor.dtype) - dtype_scale) > 1e-5: + raise AssertionError( + "compare scale failed expect {} got {}".format( + dtype_scale, mgb.dtype.get_scale(tensor.dtype) + ) + ) + if np.abs(tensor.numpy()[0][0] - value) > 1e-5: + raise AssertionError( + "compare value failed expect {} got {}".format( + tensor.numpy()[0][0], value + ) + ) + + t = mge.Parameter(np.ones((3, 4), dtype="float32")) + t.set_dtype(mgb.dtype.qint8(0.1)) + check_dtype_value(t, 0.1, 10) + + t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) + t.set_dtype(mgb.dtype.qint8(0.3)) + check_dtype_value(t, 0.3, 3) + + t = mge.Buffer(np.ones((3, 4), dtype="float32")) + t.set_dtype(mgb.dtype.qint8(0.1)) + check_dtype_value(t, 0.1, 10) + + t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) + t.set_dtype(mgb.dtype.qint8(0.3)) + check_dtype_value(t, 0.3, 3) + + t = mge.Buffer(np.ones((3, 4), dtype="float32")) + s = t + 1 + s.set_dtype(mgb.dtype.qint8(0.2)) + check_dtype_value(s, 0.2, 10) + + t.set_dtype(mgb.dtype.qint8(0.3)) + s = t + 1 + s.set_dtype(mgb.dtype.qint8(0.1)) + check_dtype_value(s, 0.1, 18) + s.set_dtype("float32") + check_dtype_value(s, 0, 1.8) diff --git a/python_module/test/unit/module/test_module.py b/python_module/test/unit/module/test_module.py 
index 16aaf08f0..3790d2c43 100644 --- a/python_module/test/unit/module/test_module.py +++ b/python_module/test/unit/module/test_module.py @@ -14,8 +14,10 @@ import pytest from helpers import MLP import megengine as mge +import megengine._internal as mgb from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential +from megengine.quantization.quantize import quantize, quantize_qat from megengine.test import assertTensorClose @@ -347,3 +349,38 @@ def test_dump_model(): pred = mlp(data) with tempfile.NamedTemporaryFile() as f: mge.dump(pred, f.name) + + +def test_load_quantized(): + data_shape = (2, 28) + data = tensor(np.random.random(data_shape), dtype="float32") + data = data.astype(mgb.dtype.qint8(0.1)) + mlp = MLP() + quantize_qat(mlp) + quantize(mlp) + mlp.dense0.weight = Parameter( + mlp.dense0.weight.astype(mgb.dtype.qint8(0.001)).numpy() + ) + mlp.dense1.weight = Parameter( + mlp.dense1.weight.astype(mgb.dtype.qint8(0.0002)).numpy() + ) + mlp.eval() + pred0 = mlp(data) + + with BytesIO() as fout: + mge.save(mlp.state_dict(), fout) + fout.seek(0) + checkpoint = mge.load(fout) + # change mlp weight. + mlp.dense0.weight = Parameter( + mlp.dense0.weight.astype(mgb.dtype.qint8(0.00001)).numpy() + ) + mlp.dense1.weight = Parameter( + mlp.dense1.weight.astype(mgb.dtype.qint8(0.2)).numpy() + ) + mlp.load_state_dict(checkpoint) + pred1 = mlp(data) + + assertTensorClose( + pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 + ) -- GitLab