Commit a1ca50c9 authored by Megvii Engine Team

feat(mge/quantization): add name for quantized module

GitOrigin-RevId: edefbec7b70953144105c558bae34e3f792c02ec
Parent d0f70a44
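In short: `Module.__init__` now accepts an optional `name`, float modules forward extra kwargs to it, and the QAT/quantized conversion helpers propagate that name (and Parameter names) through `quantize_qat`/`quantize`. A minimal usage sketch of the new API (the names "net" and "fc" are illustrative):

```python
import megengine.module as M
from megengine.quantization.quantize import quantize, quantize_qat

class Net(M.Module):
    def __init__(self):
        super().__init__(name="net")  # new: name kwarg on Module
        self.quant = M.QuantStub()
        self.fc = M.Linear(3, 3, bias=True, name="fc")
        self.dequant = M.DequantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

net = Net()
quantize_qat(net)  # QAT modules are built with name=float_module.name
quantize(net)      # quantized modules are built with name=qat_module.name
assert net.fc.name == "fc"
```

The name then shows up as the scope prefix of dumped operator names (e.g. `net.fc.MatrixMul`), as the new tests at the end of this diff verify.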
@@ -641,6 +641,7 @@ class DeformableConv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -657,6 +658,7 @@ class DeformableConv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -21,8 +21,9 @@ class DeformablePSROIPooling(Module):
         sample_per_part,
         spatial_scale,
         trans_std: float = 0.1,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.no_trans = no_trans
         self.part_size = part_size
         self.pooled_h = pooled_h
@@ -69,7 +69,17 @@ class Module(metaclass=ABCMeta):
     Base Module class.
     """

-    def __init__(self, name=""):
+    def __init__(self, name=None):
+        """
+        :param name: module's name, can be initialized by the ``kwargs``
+            parameter of a child class.
+        """
+        if name is not None:
+            assert (
+                isinstance(name, str) and name.strip()
+            ), "Module's name must be a non-empty string"
         self.name = name

         # runtime attributes
@@ -109,7 +119,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        auto_naming.push_scope(self.name if self.name else self._name)
+        auto_naming.push_scope(self.name if self.name is not None else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
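One behavioral consequence of switching the default from `""` to `None`: `None` now means "fall back to the auto-assigned `_name`", while an explicit name must be a non-empty string. A quick sketch of the new validation (assuming this commit):

```python
import megengine.module as M

m = M.Linear(3, 3, name="fc")   # fine: non-empty string

try:
    M.Linear(3, 3, name="   ")  # whitespace-only fails name.strip()
except AssertionError as exc:
    print(exc)                  # "Module's name must be a non-empty string"
```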
@@ -28,6 +28,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
             float_module.in_features,
             float_module.out_features,
             float_module.bias is not None,
+            name=float_module.name,
         )
         qat_module.weight = float_module.weight
         qat_module.bias = float_module.bias
@@ -27,4 +27,4 @@ class Concat(Float.Concat, QATModule):
         Return a :class:`~.QATModule` instance converted from
         a float :class:`~.Module` instance.
         """
-        return cls()
+        return cls(name=float_module.name)
@@ -43,6 +43,7 @@ class Conv2d(Float.Conv2d, QATModule):
             float_module.bias is not None,
             float_module.conv_mode,
             float_module.compute_mode,
+            name=float_module.name,
         )
         qat_module.weight = float_module.weight
         qat_module.bias = float_module.bias
@@ -155,6 +155,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.bias is not None,
             float_module.conv.conv_mode,
             float_module.conv.compute_mode,
+            name=float_module.name,
         )
         qat_module.conv.weight = float_module.conv.weight
         qat_module.conv.bias = float_module.conv.bias
@@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule):
         Return a :class:`~.QATModule` instance converted from
         a float :class:`~.Module` instance.
         """
-        return cls(float_module.method)
+        return cls(float_module.method, name=float_module.name)
@@ -36,7 +36,9 @@ class Linear(Float.Linear, QATModule):
         Return a :class:`~.QATModule` instance converted from
         a float :class:`~.Module` instance.
         """
-        qmod = cls(float_module.in_features, float_module.out_features)
+        qmod = cls(
+            float_module.in_features, float_module.out_features, name=float_module.name
+        )
         qmod.weight = float_module.weight
         qmod.bias = float_module.bias
         return qmod
@@ -26,8 +26,8 @@ class QATModule(Module):
     with_weight = True
     with_act = True

-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)

         self.weight_observer = None  # type: Observer
         self.act_observer = None  # type: Observer
@@ -26,7 +26,7 @@ class QuantStub(Float.QuantStub, QATModule):
         Return a :class:`~.QATModule` instance converted from
         a float :class:`~.Module` instance.
         """
-        return cls()
+        return cls(name=float_module.name)


 class DequantStub(Float.DequantStub, QATModule):
@@ -47,4 +47,4 @@ class DequantStub(Float.DequantStub, QATModule):
         Return a :class:`~.QATModule` instance converted from
         a float :class:`~.Module` instance.
         """
-        return cls()
+        return cls(name=float_module.name)
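Each `from_float_module` above now forwards the float module's name, so even a hand-converted QAT module keeps it. A minimal sketch (assuming the QAT classes are importable as `megengine.module.qat`):

```python
import megengine.module as M
import megengine.module.qat as QAT

float_linear = M.Linear(3, 3, bias=True, name="fc")
qat_linear = QAT.Linear.from_float_module(float_linear)
assert qat_linear.name == "fc"  # forwarded via name=float_module.name
```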
@@ -61,13 +61,14 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
             qat_module.out_features,
             qat_module.bias is not None,
             dtype=output_dtype,
+            name=qat_module.name,
         )
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         weight = expand_dims(weight, [-1, -2])
-        qbmm.weight = Parameter(weight.numpy())
+        qbmm.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
         if qat_module.bias is not None:
             bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
-            qbmm.bias = Parameter(bias.numpy())
+            qbmm.bias = Parameter(bias.numpy(), name=qat_module.bias.name)
         else:
             qbmm.bias = Parameter(
                 np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
@@ -18,8 +18,8 @@ class Concat(QuantizedModule):
     A :class:`~.QuantizedModule` to do quantized :func:`~.concat`, used for inference only.
     """

-    def __init__(self, dtype=None):
-        super().__init__()
+    def __init__(self, dtype=None, **kwargs):
+        super().__init__(**kwargs)
         self.output_dtype = dtype

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
@@ -32,4 +32,4 @@ class Concat(QuantizedModule):
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
-        return cls(qat_module.get_activation_dtype())
+        return cls(qat_module.get_activation_dtype(), name=qat_module.name)
@@ -37,6 +37,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
         dtype=None,
+        **kwargs
     ):
         super().__init__(
             in_channels,
@@ -86,11 +87,12 @@ class Conv2d(Float.Conv2d, QuantizedModule):
             qat_module.dilation,
             qat_module.groups,
             dtype=output_dtype,
+            name=qat_module.name,
         )
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
-        qconv.weight = Parameter(weight.numpy())
+        qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
         if qat_module.bias is not None:
-            qconv.bias = Parameter(qat_module.bias.numpy())
+            qconv.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
         else:
             qconv.bias = Parameter(
                 np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
@@ -33,13 +33,14 @@ class _ConvBnActivation2d(Conv2d):
             qat_module.conv.dilation,
             qat_module.conv.groups,
             dtype=output_dtype,
+            name=qat_module.name,
         )
         w_fold, b_fold = qat_module.fold_weight_bias(
             qat_module.bn.running_mean, qat_module.bn.running_var
         )
         weight = w_fold.astype(qat_module.get_weight_dtype())
-        qconv.weight = Parameter(weight.numpy())
-        qconv.bias = Parameter(b_fold.numpy())
+        qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name)
+        qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name)

         return qconv
@@ -14,8 +14,8 @@ from .module import QuantizedModule
 class Elemwise(QuantizedModule):
     r"""Quantized version of :class:`~.qat.Elemwise`."""

-    def __init__(self, method, dtype=None):
-        super().__init__()
+    def __init__(self, method, dtype=None, **kwargs):
+        super().__init__(**kwargs)
         self.method = "Q" + method
         self.output_dtype = dtype
@@ -30,4 +30,6 @@ class Elemwise(QuantizedModule):
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
-        return cls(qat_module.method, qat_module.get_activation_dtype())
+        return cls(
+            qat_module.method, qat_module.get_activation_dtype(), name=qat_module.name
+        )
@@ -17,8 +17,8 @@ from .module import QuantizedModule
 class Linear(QuantizedModule):
     r"""Quantized version of :class:`~.qat.Linear`."""

-    def __init__(self, dtype: np.dtype = None):
-        super().__init__()
+    def __init__(self, dtype: np.dtype = None, **kwargs):
+        super().__init__(**kwargs)
         self.weight = None
         self.bias = None
         self.output_dtype = dtype
@@ -44,9 +44,9 @@ class Linear(QuantizedModule):
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
-        qmod = cls(dtype=output_dtype)
+        qmod = cls(dtype=output_dtype, name=qat_module.name)
         weight = qat_module.weight.astype(qat_module.get_weight_dtype())
-        qmod.weight = Parameter(weight.numpy())
+        qmod.weight = Parameter(weight.numpy(), name=qat_module.weight.name)
         if qat_module.bias is not None:
-            qmod.bias = Parameter(qat_module.bias.numpy())
+            qmod.bias = Parameter(qat_module.bias.numpy(), name=qat_module.bias.name)
         return qmod
@@ -15,8 +15,8 @@ class QuantStub(QuantizedModule):
     will convert input to quantized dtype.
     """

-    def __init__(self, dtype=None):
-        super().__init__()
+    def __init__(self, dtype=None, **kwargs):
+        super().__init__(**kwargs)
         self.output_dtype = dtype

     def forward(self, inp):
@@ -28,7 +28,7 @@ class QuantStub(QuantizedModule):
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
-        return cls(qat_module.get_activation_dtype())
+        return cls(qat_module.get_activation_dtype(), name=qat_module.name)


 class DequantStub(QuantizedModule):
@@ -46,4 +46,4 @@ class DequantStub(QuantizedModule):
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
-        return cls()
+        return cls(name=qat_module.name)
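On the quantized side, Parameters are rebuilt from the QAT tensors, and the `name=` argument to `Parameter` carries any user-set parameter name across that rebuild. A sketch mirroring the tests below (assuming the default qconfig from `quantize_qat`):

```python
import megengine.module as M
from megengine.quantization.quantize import quantize, quantize_qat

class Net(M.Module):
    def __init__(self):
        super().__init__(name="net")
        self.quant = M.QuantStub()
        self.linear = M.Linear(3, 3, bias=True)
        self.dequant = M.DequantStub()
        self.linear.weight.name = "w"  # user-named Parameter

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))

net = Net()
quantize_qat(net)
quantize(net)
# The quantized Linear rebuilds its weight as
# Parameter(weight.numpy(), name=qat_module.weight.name),
# so the user-set name should survive the conversion:
assert net.linear.weight.name == "w"
```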
@@ -17,6 +17,7 @@ import megengine.utils.comp_graph_tools as cgtools
 from megengine import Parameter, Tensor
 from megengine.core.tensor import megbrain_graph as G
 from megengine.jit.tracing import trace
+from megengine.quantization.quantize import quantize, quantize_qat
 from megengine.utils.naming import auto_naming
@@ -29,14 +30,14 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
     func.dump(
         file,
         optimize_for_inference=False,
-        arg_names="x",
+        arg_names=("x",),
         keep_opr_name=keep_opr_name,
         keep_var_name=2,
     )
     file.seek(0)
     *_, outputs = G.load_graph(file)
-    op = cgtools.get_oprs_seq(outputs)[-1]
-    return op
+    ops = cgtools.get_oprs_seq(outputs)
+    return ops


 @pytest.mark.parametrize("symbolic", [False, True])
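The helper now returns the full topological operator sequence instead of only the last operator, so the quantization tests can inspect every op; existing tests index into the result. `arg_names` is also passed as a tuple so it pairs one name per traced input. Usage, in terms of the helper above:

```python
ops = _dump_and_load(m, symbolic)  # full operator sequence
last_op = ops[-1]                  # what the helper used to return
```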
@@ -50,7 +51,7 @@ def test_auto_naming(symbolic):
         return x + x

     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
+    op = _dump_and_load(m, symbolic)[-1]
     assert op.name == "simple.ADD"
     assert op.outputs[0].name == "simple.ADD"
@@ -70,7 +71,7 @@ def test_user_named_tensor(symbolic):
     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
+    op = _dump_and_load(m, symbolic)[-1]
     assert op.name == "simple.ADD"
     assert op.outputs[0].name == "o_x"
@@ -88,7 +89,7 @@ def test_user_named_param(symbolic):
     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
+    op = _dump_and_load(m, symbolic)[-1]
     assert op.inputs[0].name == "x"
     assert op.inputs[1].name == "simple.k"
@@ -98,7 +99,7 @@ def test_without_module(symbolic):
     def f(x):
         return 2 * x

-    op = _dump_and_load(f, symbolic)
+    op = _dump_and_load(f, symbolic)[-1]
     assert op.name == "MUL"
@@ -116,10 +117,10 @@ def test_with_submodule(symbolic):
     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
-    assert op.name == "simple.linear.ADD"
-    assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
-    assert op.outputs[0].name == "simple.linear.ADD"
+    ops = _dump_and_load(m, symbolic)
+    assert ops[-1].name == "simple.linear.ADD"
+    assert ops[-2].name == "simple.linear.MatrixMul"
+    assert ops[-1].outputs[0].name == "simple.linear.ADD"


 @pytest.mark.parametrize("symbolic", [False, True])
@@ -136,10 +137,10 @@ def test_named_submodule(symbolic):
     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
-    assert op.name == "simple.x.ADD"
-    assert op.inputs[0].owner.name == "simple.x.MatrixMul"
-    assert op.outputs[0].name == "simple.x.ADD"
+    ops = _dump_and_load(m, symbolic)
+    assert ops[-1].name == "simple.x.ADD"
+    assert ops[-2].name == "simple.x.MatrixMul"
+    assert ops[-1].outputs[0].name == "simple.x.ADD"


 @pytest.mark.parametrize("symbolic", [False, True])
@@ -156,14 +157,111 @@ def test_with_same_operators(symbolic):
     m = Simple("simple")
-    op = _dump_and_load(m, symbolic)
-    assert op.name == "simple.RELU[1]"
-    assert op.inputs[0].owner.name == "simple.RELU[0]"
+    ops = _dump_and_load(m, symbolic)
+    assert ops[-1].name == "simple.RELU[1]"
+    assert ops[-2].name == "simple.RELU[0]"


 def test_not_keep_opr_name():
     def f(x):
         return 2 * x

-    op = _dump_and_load(f, True, False)
+    op = _dump_and_load(f, True, False)[-1]
     assert op.name == "MUL(x,2[2])[4]"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_auto_naming(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.linear.MatrixMul",
+        "simple.linear.ADD",
+        "simple.linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
+            self.dequant = M.DequantStub()
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    ops_name = (
+        "x",
+        "simple.quant.TypeCvt",
+        "simple.user-linear.MatrixMul",
+        "simple.user-linear.ADD",
+        "simple.user-linear.TypeCvt",
+        "simple.dequant.TypeCvt",
+    )
+    for op, name in zip(ops, ops_name):
+        assert op.name == name
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_quantized_module_user_naming_param(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__(name=name)
+            self.quant = M.QuantStub()
+            self.linear = M.Linear(3, 3, bias=True)
+            self.dequant = M.DequantStub()
+            self.linear.weight.name = "user-weight"
+            self.linear.bias.name = "user-bias"
+
+        def forward(self, x):
+            out = self.quant(x)
+            out = self.linear(out)
+            out = self.dequant(out)
+            return out
+
+    m = Simple("simple")
+    quantize_qat(m)
+    quantize(m)
+    m.eval()
+
+    ops = _dump_and_load(m, symbolic)
+    (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
+    for var in matrix_mul_op.inputs:
+        assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
+    # BUG: the bias name does not meet expectations because of the astype
+    # operator inserted after quantization