提交 82692159 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/module): fix named_children of Sequential

GitOrigin-RevId: d3220fb361f018042f5c5a8d085e037397e7ecef
上级 eed54081
......@@ -19,7 +19,7 @@ class Sequential(Module):
To make it easier to understand, here is a small example:
.. testcode::
from collections import OrderedDict
import numpy as np
import megengine.nn as nn
import megengine.nn.functional as F
......@@ -29,34 +29,35 @@ class Sequential(Module):
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))
data = data.reshape(batch_size, -1)
net = nn.Sequential(
net0 = nn.Sequential(
nn.Linear(28 * 28, 320),
nn.Linear(320, 500),
nn.Linear(500, 320),
nn.Linear(320, 10)
)
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
pred0 = net0(data)
modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)
pred1 = net1(data)
"""
def __init__(self, *args):
super().__init__()
self.layer_keys = []
self.layer_values = []
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
# self.add_module(key, module)
setattr(self, key, module)
self.layer_keys.append(key)
self.layer_values.append(module)
else:
for idx, module in enumerate(args):
# self.add_module(str(idx), module)
setattr(self, str(idx), module)
self.layer_keys.append(str(idx))
self.layer_values.append(module)
def __getitem__(self, idx):
if isinstance(idx, slice):
......@@ -64,11 +65,10 @@ class Sequential(Module):
OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx]))
)
else:
return self.layer_values[idx]
return getattr(self, self.layer_keys[idx])
def __setitem__(self, idx, module):
key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module)
def __delitem__(self, idx):
......@@ -76,11 +76,9 @@ class Sequential(Module):
for key in self.layer_keys[idx]:
delattr(self, key)
del self.layer_keys[idx]
del self.layer_values[idx]
else:
delattr(self, self.layer_keys[idx])
del self.layer_keys[idx]
del self.layer_values[idx]
def __len__(self):
return len(self.layer_keys)
......@@ -88,6 +86,10 @@ class Sequential(Module):
def __iter__(self):
return iter(self.layer_values)
@property
def layer_values(self):
return [getattr(self, key) for key in self.layer_keys]
def forward(self, inp):
for layer in self.layer_values:
inp = layer(inp)
......
......@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import tempfile
from collections import OrderedDict
from io import BytesIO
import numpy as np
......@@ -16,7 +17,14 @@ from helpers import MLP
import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Linear,
Module,
Sequential,
)
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose
......@@ -238,6 +246,18 @@ def test_module_api_with_sequential():
]
def test_sequential_named_children():
modules = OrderedDict()
modules["name0"] = Linear(20, 10)
modules["name1"] = Linear(10, 5)
modules["name2"] = Linear(5, 1)
m = Sequential(modules)
l = list(m.named_children())
assert l[0][0] == "name0"
assert l[1][0] == "name1"
assert l[2][0] == "name2"
def test_state_dict():
data_shape = (2, 28)
data = tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册