未验证 提交 5394194e 编写于 作者: W Wenyu 提交者: GitHub

support minus-int idx to LayerList (#31750)

* support minus-int idx to LayerList
* update layerlist test
上级 ef8323d4
...@@ -213,13 +213,25 @@ class LayerList(Layer): ...@@ -213,13 +213,25 @@ class LayerList(Layer):
for idx, layer in enumerate(sublayers): for idx, layer in enumerate(sublayers):
self.add_sublayer(str(idx), layer) self.add_sublayer(str(idx), layer)
def _get_abs_idx(self, idx):
if isinstance(idx, int):
if not (-len(self) <= idx < len(self)):
raise IndexError(
'index {} is out of range, should be an integer in range [{}, {})'.
format(idx, -len(self), len(self)))
if idx < 0:
idx += len(self)
return idx
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
return self.__class__(list(self._sub_layers.values())[idx]) return self.__class__(list(self._sub_layers.values())[idx])
else: else:
idx = self._get_abs_idx(idx)
return self._sub_layers[str(idx)] return self._sub_layers[str(idx)]
def __setitem__(self, idx, sublayer): def __setitem__(self, idx, sublayer):
idx = self._get_abs_idx(idx)
return setattr(self, str(idx), sublayer) return setattr(self, str(idx), sublayer)
def __delitem__(self, idx): def __delitem__(self, idx):
...@@ -227,6 +239,7 @@ class LayerList(Layer): ...@@ -227,6 +239,7 @@ class LayerList(Layer):
for k in range(len(self._sub_layers))[idx]: for k in range(len(self._sub_layers))[idx]:
delattr(self, str(k)) delattr(self, str(k))
else: else:
idx = self._get_abs_idx(idx)
delattr(self, str(idx)) delattr(self, str(idx))
str_indices = [str(i) for i in range(len(self._sub_layers))] str_indices = [str(i) for i in range(len(self._sub_layers))]
self._sub_layers = OrderedDict( self._sub_layers = OrderedDict(
...@@ -275,10 +288,15 @@ class LayerList(Layer): ...@@ -275,10 +288,15 @@ class LayerList(Layer):
another = paddle.nn.Linear(10, 10) another = paddle.nn.Linear(10, 10)
linears.insert(3, another) linears.insert(3, another)
print(linears[3] is another) # True print(linears[3] is another) # True
another = paddle.nn.Linear(10, 10)
linears.insert(-1, another)
print(linears[-2] is another) # True
""" """
assert isinstance(index, int) and \ assert isinstance(index, int) and \
0 <= index < len(self._sub_layers), \ -len(self._sub_layers) <= index < len(self._sub_layers), \
"index should be an integer in range [0, len(self))" "index should be an integer in range [{}, {})".format(-len(self), len(self))
index = self._get_abs_idx(index)
for i in range(len(self._sub_layers), index, -1): for i in range(len(self._sub_layers), index, -1):
self._sub_layers[str(i)] = self._sub_layers[str(i - 1)] self._sub_layers[str(i)] = self._sub_layers[str(i - 1)]
self._sub_layers[str(index)] = sublayer self._sub_layers[str(index)] = sublayer
......
...@@ -84,6 +84,18 @@ class TestImperativeContainer(unittest.TestCase): ...@@ -84,6 +84,18 @@ class TestImperativeContainer(unittest.TestCase):
self.assertListEqual(res8.shape, [5, 3**3]) self.assertListEqual(res8.shape, [5, 3**3])
res8.backward() res8.backward()
model4 = MyLayer(layerlist[:3])
model4.layerlist[-1] = fluid.dygraph.Linear(4, 5)
res9 = model4(x)
self.assertListEqual(res9.shape, [5, 5])
del model4.layerlist[-1]
res10 = model4(x)
self.assertListEqual(res10.shape, [5, 4])
model4.layerlist.insert(-1, fluid.dygraph.Linear(2, 2))
res11 = model4(x)
self.assertListEqual(res11.shape, [5, 4])
res11.backward()
def test_layer_list(self): def test_layer_list(self):
self.layer_list(True) self.layer_list(True)
self.layer_list(False) self.layer_list(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册