提交 ae8c3c81 编写于 作者: M Megvii Engine Team

feat(mge/functional): add python wrapper for fake quant opr

GitOrigin-RevId: 5f205198be9af2990d51984e1dac582b157d56a0
上级 b60cc8ca
......@@ -9,7 +9,12 @@ from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Dict
import numpy as np
from .. import functional as F
from ..core.ops import builtin
from ..core.tensor import megbrain_graph
from ..core.tensor.core import apply
from ..core.tensor.dtype import _metadata_dict
from ..core.tensor.function import Function
from ..tensor import Tensor
......@@ -81,16 +86,20 @@ def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor
"""
scale = q_dict["scale"]
zero_point = 0
zero_point = Tensor([0.0], dtype=np.float32)
if q_dict["mode"] == QuantMode.ASYMMERTIC:
zero_point = q_dict["zero_point"]
# Quant
oup = Round()(inp / scale) + zero_point
# Clip
oup = F.minimum(F.maximum(oup, qmin), qmax)
# Dequant
oup = (oup - zero_point) * scale
return oup
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
assert isinstance(
scale, (Tensor, megbrain_graph.VarNode)
), "scale must be Tensor type"
assert isinstance(
zero_point, (Tensor, megbrain_graph.VarNode)
), "zero point must be Tensor type"
op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
return apply(op, inp, scale, zero_point)[0]
def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
......
......@@ -11,8 +11,12 @@ import pytest
import megengine as mge
from megengine import tensor
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.function import Function
from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.fake_quant import TQT_Function
from megengine.quantization.internal_fake_quant import *
from megengine.quantization.utils import QuantMode, fake_quant_tensor
class numpy_TQT_Function:
......@@ -77,3 +81,65 @@ def test_TQT():
check_inp(a, b, b, a_np, b_np, b_np)
def _save_to(self, name="grad"):
def callback(tensor, grad):
setattr(self, name, grad)
return callback
class Round(Function):
def forward(self, x):
return F.round(x)
def backward(self, output_grads):
return output_grads
def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
oup = Round()(inp / scale) + zero_point
oup = F.minimum(F.maximum(oup, qmin), qmax)
oup = (oup - zero_point) * scale
return oup
def test_fakequant():
qmin = -126
qmax = 129
def run(zero_point, scale):
q_dict = {}
q_dict["mode"] = QuantMode.ASYMMERTIC
q_dict["scale"] = scale
q_dict["zero_point"] = zero_point
inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32))
inp = tensor(inp_data, dtype=np.float32)
# test forward
oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy()
oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy()
assert np.allclose(oup, oup_gt)
assert oup.shape == oup_gt.shape
# test backward
x = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qmin, qmax, q_dict)
grad(y, tensor(F.ones_like(x)))
x1 = tensor(inp_data, dtype=np.float32)
grad = Grad().wrt(x1, callback=_save_to(x1))
y1 = fake_quant_tensor_gt(x1, scale, zero_point, qmin, qmax)
grad(y1, tensor(F.ones_like(x1)))
assert np.allclose(x.grad.numpy(), x1.grad.numpy())
assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
zero_point = tensor([1.0], dtype=np.float32)
scale = tensor([4.0], dtype=np.float32)
run(zero_point, scale)
zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32)
run(zero_point, scale)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册