diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 6f77a54a10b108e1f1ddc2c587f3604f9ffa92e8..284ce8c52faef22869ab3dc9a9d277d67329a260 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -9,7 +9,12 @@ from enum import Enum from functools import partial, update_wrapper, wraps from typing import Dict +import numpy as np + from .. import functional as F +from ..core.ops import builtin +from ..core.tensor import megbrain_graph +from ..core.tensor.core import apply from ..core.tensor.dtype import _metadata_dict from ..core.tensor.function import Function from ..tensor import Tensor @@ -81,16 +86,20 @@ def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor """ scale = q_dict["scale"] - zero_point = 0 + zero_point = Tensor([0.0], dtype=np.float32) if q_dict["mode"] == QuantMode.ASYMMERTIC: zero_point = q_dict["zero_point"] - # Quant - oup = Round()(inp / scale) + zero_point - # Clip - oup = F.minimum(F.maximum(oup, qmin), qmax) - # Dequant - oup = (oup - zero_point) * scale - return oup + + assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type" + assert isinstance( + scale, (Tensor, megbrain_graph.VarNode) + ), "scale must be Tensor type" + assert isinstance( + zero_point, (Tensor, megbrain_graph.VarNode) + ), "zero point must be Tensor type" + + op = builtin.FakeQuant(qmin=qmin, qmax=qmax) + return apply(op, inp, scale, zero_point)[0] def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index 3d36847e5357ea04e4f25fee31802947a611c3cb..59b366d579ffcda60d3894c677e06ec7e65363c8 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -11,8 +11,12 @@ import pytest import megengine as mge from megengine import tensor +from megengine.core.autodiff.grad import Grad +from megengine.core.tensor.function import Function +from megengine.core.tensor.utils import make_shape_tuple from megengine.quantization.fake_quant import TQT_Function from megengine.quantization.internal_fake_quant import * +from megengine.quantization.utils import QuantMode, fake_quant_tensor class numpy_TQT_Function: @@ -77,3 +81,65 @@ def test_TQT(): check_inp(a, b, b, a_np, b_np, b_np) + + +def _save_to(self, name="grad"): + def callback(tensor, grad): + setattr(self, name, grad) + + return callback + + +class Round(Function): + def forward(self, x): + return F.round(x) + + def backward(self, output_grads): + return output_grads + + +def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax): + oup = Round()(inp / scale) + zero_point + oup = F.minimum(F.maximum(oup, qmin), qmax) + oup = (oup - zero_point) * scale + return oup + + +def test_fakequant(): + qmin = -126 + qmax = 129 + + def run(zero_point, scale): + q_dict = {} + q_dict["mode"] = QuantMode.ASYMMERTIC + q_dict["scale"] = scale + q_dict["zero_point"] = zero_point + inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32)) + inp = tensor(inp_data, dtype=np.float32) + # test forward + oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy() + oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy() + assert np.allclose(oup, oup_gt) + assert oup.shape == oup_gt.shape + + # test backward + x = tensor(inp_data, dtype=np.float32) + grad = Grad().wrt(x, callback=_save_to(x)) + y = fake_quant_tensor(x, qmin, qmax, q_dict) + grad(y, tensor(F.ones_like(x))) + + x1 = tensor(inp_data, dtype=np.float32) + grad = Grad().wrt(x1, callback=_save_to(x1)) + y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax) + grad(y1, tensor(F.ones_like(x1))) + + assert np.allclose(x.grad.numpy(), x1.grad.numpy()) + assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape) + + zero_point = tensor([1.0], dtype=np.float32) + scale = tensor([4.0], dtype=np.float32) + run(zero_point, scale) + + zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) + scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) + run(zero_point, scale)