# -*- coding: utf-8 -*-
"""Tests for ``megengine.module.Module``: structure enumeration, forward hooks,
state-dict / pickle round trips, parameter sharing and ``__repr__``."""
from collections import OrderedDict
from io import BytesIO

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.module import (
    BatchNorm1d,
    BatchNorm2d,
    Conv1d,
    Conv2d,
    Dropout,
    Linear,
    MaxPool2d,
    Module,
    Sequential,
    Softmax,
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.traced_module import TracedModule, trace_module
from megengine.utils.module_utils import get_expand_structure, set_expand_structure


def _pickle_roundtrip(obj):
    """Save ``obj`` with ``mge.save`` into an in-memory buffer and return the
    result of ``mge.load`` on that buffer."""
    with BytesIO() as buf:
        mge.save(obj, buf)
        buf.seek(0)
        return mge.load(buf)


class MLP(Module):
    """Two-layer perceptron mapping 28 features to 20 (Linear-ReLU-Linear)."""

    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        x = self.dense0(x)
        x = F.relu(x)
        x = self.dense1(x)
        return x


class MyModule(Module):
    """Module with a nested submodule, a direct BatchNorm2d, one Parameter and
    one plain Tensor attribute, used to check enumeration order."""

    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.i = self.InnerModule()
        self.bn = BatchNorm2d(4)
        self.param = Parameter(np.ones(1, dtype=np.float32))
        self.buff = Tensor(np.ones(1, dtype=np.float32))

    def forward(self, x):
        x = self.i(x)
        x = self.bn(x)
        return x


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api(test_traced_module):
    """Check enumeration order of children/modules/buffers/parameters/tensors
    and how ``train()`` / ``eval()`` / ``apply()`` update ``training``."""
    m = MyModule()
    if test_traced_module:
        buff = m.buff
        param = m.param
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
        # The traced module does not carry these attributes; re-attach them so
        # the expectations below are the same for both variants.
        assert "buff" not in m.__dict__
        assert "param" not in m.__dict__
        m.buff = buff
        m.param = param

    assert list(m.children()) == [m.bn, m.i]
    assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("i", m.i),
        ("i.bn", m.i.bn),
    ]
    assert list(m.named_modules(prefix="x")) == [
        ("x", m),
        ("x.bn", m.bn),
        ("x.i", m.i),
        ("x.i.bn", m.i.bn),
    ]
    assert list(m.buffers()) == [
        m.bn.running_mean,
        m.bn.running_var,
        m.buff,
        m.i.bn.running_mean,
        m.i.bn.running_var,
    ]
    assert list(m.buffers(recursive=False)) == [m.buff]
    assert list(m.named_buffers()) == [
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("buff", m.buff),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
    ]
    assert list(m.parameters()) == [
        m.bn.bias,
        m.bn.weight,
        m.i.bn.bias,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.named_parameters()) == [
        ("bn.bias", m.bn.bias),
        ("bn.weight", m.bn.weight),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    assert list(m.tensors()) == [
        m.bn.bias,
        m.bn.running_mean,
        m.bn.running_var,
        m.bn.weight,
        m.buff,
        m.i.bn.bias,
        m.i.bn.running_mean,
        m.i.bn.running_var,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.named_tensors()) == [
        ("bn.bias", m.bn.bias),
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("bn.weight", m.bn.weight),
        ("buff", m.buff),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]

    m.eval()
    assert not any((m.training, m.bn.training, m.i.training, m.i.bn.training))

    m.bn.train()
    assert not m.training and m.bn.training and not m.i.bn.training

    m.eval()
    m.i.train()
    assert (
        not m.training
        and not m.bn.training
        and m.i.training
        and m.i.bn.training
    )

    m.eval()
    m.train()
    assert m.training and m.bn.training and m.i.bn.training

    def fn(module):
        module.training = False

    m.apply(fn)
    assert not m.bn.training and not m.i.bn.training


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_reuse_submodule(test_traced_module):
    """A submodule bound to two attributes is enumerated once, under the
    name expected below ("h")."""
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
    m.h = m.i  # pylint: disable=attribute-defined-outside-init
    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("h", m.i),
        ("h.bn", m.i.bn),
    ]


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_iterable_stability(test_traced_module):
    """Repeated calls to ``modules()`` yield the same order."""
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
    expected = list(m.modules())
    for _ in range(100):
        assert list(m.modules()) == expected


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_hooks(test_traced_module):
    """Forward pre/post hooks fire once per module, can rewrite inputs and
    outputs, and stop firing once their handles are removed."""
    net = MyModule()
    if test_traced_module:
        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
    pre_hook_num = 0
    post_hook_num = 0
    hooks = []

    def pre_hook(_, inputs):
        nonlocal pre_hook_num
        pre_hook_num += 1
        modified_inputs = tuple(inp + 1 for inp in inputs)
        return modified_inputs

    def post_hook(_, __, outputs):
        nonlocal post_hook_num
        post_hook_num += 1
        outputs += 1
        return outputs

    net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
    net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))

    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    y = net(x)
    assert pre_hook_num == 4
    assert post_hook_num == 4

    # Hooks sit on net, net.i, net.i.bn and net.bn. The input reaches net.i.bn
    # after three pre-hooks (+3); net.i.bn's and net.i's post-hooks plus
    # net.bn's pre-hook add another 3 before net.bn; net.bn's and net's
    # post-hooks add the final 2.
    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
    bn1 = F.batch_norm(
        x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
    )
    np.testing.assert_allclose(net.i.bn.running_mean.numpy(), mean1.numpy())
    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
    bn2 = F.batch_norm(
        bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
    )
    np.testing.assert_allclose(net.bn.running_mean.numpy(), mean2.numpy())
    np.testing.assert_allclose((bn2 + 2).numpy(), y.numpy())

    assert len(hooks) == 8
    for handler in hooks:
        handler.remove()
    net(x)
    # Counters are unchanged: removed hooks no longer run.
    assert pre_hook_num == 4
    assert post_hook_num == 4


class MyModule2(Module):
    """Module whose submodules are nested inside a list, a dict and a tuple."""

    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)
            self.test_bool_key = {True: 1, False: 0}

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.a = [
            BatchNorm2d(4),
            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
            (self.InnerModule(),),
        ]

    def forward(self, x):
        return x


def test_expand_structure():
    """Dotted names from ``named_modules`` resolve through nested containers
    with ``get_expand_structure`` and can be overwritten with
    ``set_expand_structure``."""
    m = MyModule2()
    rst = [
        ("", m),
        ("a.0", m.a[0]),
        ("a.1.x", m.a[1]["x"]),
        ("a.1.y.0", m.a[1]["y"][0]),
        ("a.1.y.1", m.a[1]["y"][1]),
        ("a.1.y.1.bn", m.a[1]["y"][1].bn),
        ("a.2.0", m.a[2][0]),
        ("a.2.0.bn", m.a[2][0].bn),
        ("bn", m.bn),
    ]
    assert list(m.named_modules()) == rst
    for name, module in rst[1:]:
        assert get_expand_structure(m, name) is module
    # Walk in reverse (children before their parents) so overwriting an entry
    # never removes a path that is still to be visited.
    for name, _ in reversed(rst[1:]):
        # Skip entries held directly in a tuple (presumably the container is the
        # callback's first argument), since tuple items cannot be reassigned.
        if _access_structure(m, name, lambda p, k, o: isinstance(p, tuple)):
            continue
        set_expand_structure(m, name, "TEST_VALUE")
        assert get_expand_structure(m, name) == "TEST_VALUE"


def test_flatten_others():
    """Flattening with a predicate that rejects Tensors and Modules yields nothing."""

    def be_others(obj):
        return not isinstance(obj, (Tensor, Module))

    m = MyModule2()
    assert list(m._flatten(with_key=True, predicate=be_others)) == []


def test_flatten_with_parent():
    """``with_parent=True`` pairs each module with its owning Module."""
    m = MyModule2()
    assert list(m.named_modules(with_parent=True)) == [
        ("", m, None),
        ("a.0", m.a[0], m),
        ("a.1.x", m.a[1]["x"], m),
        ("a.1.y.0", m.a[1]["y"][0], m),
        ("a.1.y.1", m.a[1]["y"][1], m),
        ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
        ("a.2.0", m.a[2][0], m),
        ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
        ("bn", m.bn, m),
    ]
    assert list(m.modules(with_parent=True)) == [
        (m, None),
        (m.a[0], m),
        (m.a[1]["x"], m),
        (m.a[1]["y"][0], m),
        (m.a[1]["y"][1], m),
        (m.a[1]["y"][1].bn, m.a[1]["y"][1]),
        (m.a[2][0], m),
        (m.a[2][0].bn, m.a[2][0]),
        (m.bn, m),
    ]


class MyModule3(Module):
    """Module holding submodules inside a ``Sequential``."""

    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)

    def forward(self, x):
        return x


def test_module_api_with_sequential():
    """Sequential children are named by their integer position."""
    m = MyModule3()
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("seq", m.seq),
        ("seq.0", m.seq[0]),
        ("seq.1", m.seq[1]),
        ("seq.1.bn", m.seq[1].bn),
    ]


def test_sequential_named_children():
    """A Sequential built from an OrderedDict keeps the dict's keys as names."""
    modules = OrderedDict()
    modules["name0"] = Linear(20, 10)
    modules["name1"] = Linear(10, 5)
    modules["name2"] = Linear(5, 1)
    m = Sequential(modules)
    names = [name for name, _ in m.named_children()]
    assert names == ["name0", "name1", "name2"]


def test_state_dict():
    """A saved/loaded state dict restores predictions; loading without
    ``strict=False`` raises KeyError on unexpected or missing keys."""
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape))
    mlp = MLP()
    pred0 = mlp(data)

    state_dict = _pickle_roundtrip(mlp.state_dict())
    state_dict["extra"] = None
    mlp1 = MLP()
    mlp1.load_state_dict(state_dict, strict=False)
    pred1 = mlp1(data)
    np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)

    # Unexpected key "extra".
    with pytest.raises(KeyError):
        mlp1.load_state_dict(state_dict)
    del state_dict["extra"]
    del state_dict["dense0.bias"]
    # Missing key "dense0.bias".
    with pytest.raises(KeyError):
        mlp1.load_state_dict(state_dict)


class AssertModule(Module):
    """Module with a non-str (bool) key mapping to a Tensor, which is invalid."""

    def __init__(self):
        super().__init__()
        self.error_tensor_key = {True: tensor([]), False: 0}

    def forward(self, x):
        return x


def test_assert_message():
    """Flattening a module with a non-str Tensor key reports the offending key."""
    with pytest.raises(
        AssertionError, match="keys for Tensor and Module must be str, error key: True"
    ):
        m = AssertModule()
        list(m._flatten())


class Simple(Module):
    """Two Conv2d layers sharing a single weight Parameter."""

    def __init__(self):
        super().__init__()
        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        x = self.conv0(inputs)
        y = self.conv1(inputs)
        return x + y


def _check_shared_weight(net, data):
    """Assert ``net.conv0`` and ``net.conv1`` share one weight, that the sharing
    survives pickling the whole net, and that pickling each conv separately
    yields distinct weight objects producing equal outputs."""
    assert net.conv0.weight is net.conv1.weight
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())

    net1 = _pickle_roundtrip(net)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())

    conv0 = _pickle_roundtrip(net.conv0)
    conv1 = _pickle_roundtrip(net.conv1)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_shared_param(test_traced_module):
    net = Simple()
    if test_traced_module:
        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
    _check_shared_weight(net, data)


class Simple2(Module):
    """Two Conv1d layers sharing a single weight Parameter (declared in
    reverse order relative to ``Simple``)."""

    def __init__(self):
        super().__init__()
        self.conv1 = Conv1d(1, 1, kernel_size=3, bias=False)
        self.conv0 = Conv1d(1, 1, kernel_size=3, bias=False)
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        pass


def test_shared_param_1d():
    net = Simple2()
    data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
    _check_shared_weight(net, data)


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_pickle_module(test_traced_module):
    """Pickling a module before and after a forward pass preserves its outputs
    (and its exact type for traced modules)."""
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape))
    mlp = MLP()
    pred_gt = mlp(data)
    if test_traced_module:
        mlp = trace_module(mlp, data)

    # pickle before forward
    mlp1 = _pickle_roundtrip(mlp)
    if test_traced_module:
        assert type(mlp1) is TracedModule
    pred0 = mlp1(data)

    pred1 = mlp(data)

    # pickle after forward
    mlp1 = _pickle_roundtrip(mlp)
    if test_traced_module:
        assert type(mlp1) is TracedModule
    pred2 = mlp1(data)

    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


def test_repr_basic():
    # test whether __repr__ can output correct information
    class ConvModel(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False)
            self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False)
            self.bn1 = BatchNorm1d(128)
            self.bn2 = BatchNorm2d(128)
            self.pooling = MaxPool2d(kernel_size=2, padding=0)
            modules = OrderedDict()
            modules["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,)
            modules["pointwise"] = Conv2d(
                256, 256, kernel_size=1, stride=1, padding=0, bias=True,
            )
            self.submodule1 = Sequential(modules)
            self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]]
            self.tuple1 = (
                Dropout(drop_prob=0.1),
                (Softmax(axis=100), Dropout(drop_prob=0.2)),
            )
            self.dict1 = {"Dropout": Dropout(drop_prob=0.1)}
            self.fc1 = Linear(512, 1024)

        def forward(self, inputs):
            pass

    # Children are indented by two spaces per nesting level.
    ground_truth = (
        "ConvModel(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        "  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
        "  (submodule1): Sequential(\n"
        "    (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n"
        "    (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n"
        "  )\n"
        "  (list1.0): Dropout(drop_prob=0.1)\n"
        "  (list1.1.0): Softmax(axis=100)\n"
        "  (tuple1.0): Dropout(drop_prob=0.1)\n"
        "  (tuple1.1.0): Softmax(axis=100)\n"
        "  (tuple1.1.1): Dropout(drop_prob=0.2)\n"
        "  (dict1.Dropout): Dropout(drop_prob=0.1)\n"
        "  (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
        ")"
    )
    net = ConvModel()
    assert repr(net) == ground_truth


def test_repr_module_reassign():
    # test whether __repr__ can deal with module reassign
    class ConvModel1(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
            self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False)

        def forward(self, inputs):
            pass

    # conv1 keeps its original position but shows the reassigned module.
    ground_truth = (
        "ConvModel1(\n"
        "  (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        ")"
    )
    net = ConvModel1()
    assert repr(net) == ground_truth


def test_repr_module_rereference():
    # test whether __repr__ can deal with module re-reference
    class ConvModel2(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.conv2 = self.conv1
            self.conv3 = self.conv1

        def forward(self, inputs):
            pass

    # Every attribute name is listed, even when they alias the same module.
    ground_truth = (
        "ConvModel2(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        "  (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    net = ConvModel2()
    assert repr(net) == ground_truth


def test_repr_module_delete():
    # test whether __repr__ can deal with module delete
    class ConvModel3(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.softmax = Softmax(100)

        def forward(self, inputs):
            pass

    ground_truth = (
        "ConvModel3(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    net = ConvModel3()
    del net.softmax
    assert repr(net) == ground_truth


def test_repr_module_reset_attr():
    """Replacing ``None`` with a module (or vice versa) is reflected in repr."""

    class ResetAttrModule(Module):
        def __init__(self, flag):
            super().__init__()
            if flag:
                self.a = None
                self.a = Linear(3, 5)
            else:
                self.a = Linear(3, 5)
                self.a = None

        def forward(self, x):
            if self.a is not None:
                x = self.a(x)
            return x

    ground_truth = [
        (
            "ResetAttrModule(\n"
            "  (a): Linear(in_features=3, out_features=5, bias=True)\n"
            ")"
        ),
        ("ResetAttrModule()"),
    ]
    m0 = ResetAttrModule(True)
    m1 = ResetAttrModule(False)
    output = [repr(m0), repr(m1)]
    assert output == ground_truth


def test_module_compatible():
    """Guard the instance attribute set of Module so that pickled modules stay
    loadable across versions."""

    class Empty(Module):
        def forward(self):
            pass

    empty_module = Empty()
    old_attributes = {
        "_modules",
        "name",
        "training",
        "quantize_disabled",
        "_forward_pre_hooks",
        "_forward_hooks",
        "_name",
        "_short_name",
    }
    current_attributes = set(empty_module.__dict__)
    assert (
        old_attributes == current_attributes
    ), "Add or delete attributes in Module class may break compatibility of pickle serialization"