提交 82692159 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/module): fix named_children of Sequential

GitOrigin-RevId: d3220fb361f018042f5c5a8d085e037397e7ecef
上级 eed54081
...@@ -19,7 +19,7 @@ class Sequential(Module): ...@@ -19,7 +19,7 @@ class Sequential(Module):
To make it easier to understand, here is a small example: To make it easier to understand, here is a small example:
.. testcode:: .. testcode::
from collections import OrderedDict
import numpy as np import numpy as np
import megengine.nn as nn import megengine.nn as nn
import megengine.nn.functional as F import megengine.nn.functional as F
...@@ -29,34 +29,35 @@ class Sequential(Module): ...@@ -29,34 +29,35 @@ class Sequential(Module):
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)) label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))
data = data.reshape(batch_size, -1) data = data.reshape(batch_size, -1)
net = nn.Sequential(
net0 = nn.Sequential(
nn.Linear(28 * 28, 320), nn.Linear(28 * 28, 320),
nn.Linear(320, 500),
nn.Linear(500, 320),
nn.Linear(320, 10) nn.Linear(320, 10)
) )
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label) pred0 = net0(data)
modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)
pred1 = net1(data)
""" """
def __init__(self, *args): def __init__(self, *args):
super().__init__() super().__init__()
self.layer_keys = [] self.layer_keys = []
self.layer_values = []
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items(): for key, module in args[0].items():
# self.add_module(key, module) # self.add_module(key, module)
setattr(self, key, module) setattr(self, key, module)
self.layer_keys.append(key) self.layer_keys.append(key)
self.layer_values.append(module)
else: else:
for idx, module in enumerate(args): for idx, module in enumerate(args):
# self.add_module(str(idx), module) # self.add_module(str(idx), module)
setattr(self, str(idx), module) setattr(self, str(idx), module)
self.layer_keys.append(str(idx)) self.layer_keys.append(str(idx))
self.layer_values.append(module)
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
...@@ -64,11 +65,10 @@ class Sequential(Module): ...@@ -64,11 +65,10 @@ class Sequential(Module):
OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx])) OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx]))
) )
else: else:
return self.layer_values[idx] return getattr(self, self.layer_keys[idx])
def __setitem__(self, idx, module): def __setitem__(self, idx, module):
key = self.layer_keys[idx] key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module) return setattr(self, key, module)
def __delitem__(self, idx): def __delitem__(self, idx):
...@@ -76,11 +76,9 @@ class Sequential(Module): ...@@ -76,11 +76,9 @@ class Sequential(Module):
for key in self.layer_keys[idx]: for key in self.layer_keys[idx]:
delattr(self, key) delattr(self, key)
del self.layer_keys[idx] del self.layer_keys[idx]
del self.layer_values[idx]
else: else:
delattr(self, self.layer_keys[idx]) delattr(self, self.layer_keys[idx])
del self.layer_keys[idx] del self.layer_keys[idx]
del self.layer_values[idx]
def __len__(self): def __len__(self):
return len(self.layer_keys) return len(self.layer_keys)
...@@ -88,6 +86,10 @@ class Sequential(Module): ...@@ -88,6 +86,10 @@ class Sequential(Module):
def __iter__(self): def __iter__(self):
return iter(self.layer_values) return iter(self.layer_values)
@property
def layer_values(self):
return [getattr(self, key) for key in self.layer_keys]
def forward(self, inp): def forward(self, inp):
for layer in self.layer_values: for layer in self.layer_values:
inp = layer(inp) inp = layer(inp)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import tempfile import tempfile
from collections import OrderedDict
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
...@@ -16,7 +17,14 @@ from helpers import MLP ...@@ -16,7 +17,14 @@ from helpers import MLP
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Linear,
Module,
Sequential,
)
from megengine.quantization.quantize import quantize, quantize_qat from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -238,6 +246,18 @@ def test_module_api_with_sequential(): ...@@ -238,6 +246,18 @@ def test_module_api_with_sequential():
] ]
def test_sequential_named_children():
modules = OrderedDict()
modules["name0"] = Linear(20, 10)
modules["name1"] = Linear(10, 5)
modules["name2"] = Linear(5, 1)
m = Sequential(modules)
l = list(m.named_children())
assert l[0][0] == "name0"
assert l[1][0] == "name1"
assert l[2][0] == "name2"
def test_state_dict(): def test_state_dict():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor() data = tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册