diff --git a/python_module/megengine/module/sequential.py b/python_module/megengine/module/sequential.py index 03afd48a7e3f0b4012e2fd59e2b6bff4d66b602f..42d43d84467f967b699cb1b055ca211f33659c4a 100644 --- a/python_module/megengine/module/sequential.py +++ b/python_module/megengine/module/sequential.py @@ -19,7 +19,7 @@ class Sequential(Module): To make it easier to understand, here is a small example: .. testcode:: - + from collections import OrderedDict import numpy as np import megengine.nn as nn import megengine.nn.functional as F @@ -29,34 +29,35 @@ class Sequential(Module): label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)) data = data.reshape(batch_size, -1) - net = nn.Sequential( + + net0 = nn.Sequential( nn.Linear(28 * 28, 320), - nn.Linear(320, 500), - nn.Linear(500, 320), nn.Linear(320, 10) ) - pred = net(data) - loss = F.cross_entropy_with_softmax(pred, label) + pred0 = net0(data) + modules = OrderedDict() + modules["fc0"] = nn.Linear(28 * 28, 320) + modules["fc1"] = nn.Linear(320, 10) + net1 = nn.Sequential(modules) + + pred1 = net1(data) """ def __init__(self, *args): super().__init__() self.layer_keys = [] - self.layer_values = [] if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): # self.add_module(key, module) setattr(self, key, module) self.layer_keys.append(key) - self.layer_values.append(module) else: for idx, module in enumerate(args): # self.add_module(str(idx), module) setattr(self, str(idx), module) self.layer_keys.append(str(idx)) - self.layer_values.append(module) def __getitem__(self, idx): if isinstance(idx, slice): @@ -64,11 +65,10 @@ class Sequential(Module): OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx])) ) else: - return self.layer_values[idx] + return getattr(self, self.layer_keys[idx]) def __setitem__(self, idx, module): key = self.layer_keys[idx] - self.layer_values[idx] = module return setattr(self, key, module) def __delitem__(self, idx): @@ -76,11 +76,9 @@ class Sequential(Module): for key in self.layer_keys[idx]: delattr(self, key) del self.layer_keys[idx] - del self.layer_values[idx] else: delattr(self, self.layer_keys[idx]) del self.layer_keys[idx] - del self.layer_values[idx] def __len__(self): return len(self.layer_keys) @@ -88,6 +86,10 @@ class Sequential(Module): def __iter__(self): return iter(self.layer_values) + @property + def layer_values(self): + return [getattr(self, key) for key in self.layer_keys] + def forward(self, inp): for layer in self.layer_values: inp = layer(inp) diff --git a/python_module/test/unit/module/test_module.py b/python_module/test/unit/module/test_module.py index 3790d2c4361761b3cf7798ac85ea3a89821832e8..1b0a194ebb94c59591f8937c7244c7ba6aab1d5c 100644 --- a/python_module/test/unit/module/test_module.py +++ b/python_module/test/unit/module/test_module.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import tempfile +from collections import OrderedDict from io import BytesIO import numpy as np @@ -16,7 +17,14 @@ from helpers import MLP import megengine as mge import megengine._internal as mgb from megengine.core import Buffer, Parameter, Tensor, tensor -from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential +from megengine.module import ( + BatchNorm1d, + BatchNorm2d, + Conv2d, + Linear, + Module, + Sequential, +) from megengine.quantization.quantize import quantize, quantize_qat from megengine.test import assertTensorClose @@ -238,6 +246,18 @@ def test_module_api_with_sequential(): ] +def test_sequential_named_children(): + modules = OrderedDict() + modules["name0"] = Linear(20, 10) + modules["name1"] = Linear(10, 5) + modules["name2"] = Linear(5, 1) + m = Sequential(modules) + l = list(m.named_children()) + assert l[0][0] == "name0" + assert l[1][0] == "name1" + assert l[2][0] == "name2" + + def test_state_dict(): data_shape = (2, 28) data = tensor()