From 5394194e3ab5eb851fea5e5d50a4e49a1d596e8b Mon Sep 17 00:00:00 2001 From: Wenyu Date: Wed, 31 Mar 2021 10:40:51 +0800 Subject: [PATCH] support minus-int idx to LayerList (#31750) * support minus-int idx to LayerList * update layerlist test --- python/paddle/fluid/dygraph/container.py | 22 +++++++++++++++++-- .../test_imperative_container_layerlist.py | 12 ++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/container.py b/python/paddle/fluid/dygraph/container.py index dd04b107204..e80bc1245f9 100644 --- a/python/paddle/fluid/dygraph/container.py +++ b/python/paddle/fluid/dygraph/container.py @@ -213,13 +213,25 @@ class LayerList(Layer): for idx, layer in enumerate(sublayers): self.add_sublayer(str(idx), layer) + def _get_abs_idx(self, idx): + if isinstance(idx, int): + if not (-len(self) <= idx < len(self)): + raise IndexError( + 'index {} is out of range, should be an integer in range [{}, {})'. + format(idx, -len(self), len(self))) + if idx < 0: + idx += len(self) + return idx + def __getitem__(self, idx): if isinstance(idx, slice): return self.__class__(list(self._sub_layers.values())[idx]) else: + idx = self._get_abs_idx(idx) return self._sub_layers[str(idx)] def __setitem__(self, idx, sublayer): + idx = self._get_abs_idx(idx) return setattr(self, str(idx), sublayer) def __delitem__(self, idx): @@ -227,6 +239,7 @@ class LayerList(Layer): for k in range(len(self._sub_layers))[idx]: delattr(self, str(k)) else: + idx = self._get_abs_idx(idx) delattr(self, str(idx)) str_indices = [str(i) for i in range(len(self._sub_layers))] self._sub_layers = OrderedDict( @@ -275,10 +288,15 @@ class LayerList(Layer): another = paddle.nn.Linear(10, 10) linears.insert(3, another) print(linears[3] is another) # True + another = paddle.nn.Linear(10, 10) + linears.insert(-1, another) + print(linears[-2] is another) # True """ assert isinstance(index, int) and \ - 0 <= index < len(self._sub_layers), \ - "index should be an integer in range [0, len(self))" + -len(self._sub_layers) <= index < len(self._sub_layers), \ + "index should be an integer in range [{}, {})".format(-len(self), len(self)) + + index = self._get_abs_idx(index) for i in range(len(self._sub_layers), index, -1): self._sub_layers[str(i)] = self._sub_layers[str(i - 1)] self._sub_layers[str(index)] = sublayer diff --git a/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py index ef90dd04986..2e722b69c3e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py @@ -84,6 +84,18 @@ class TestImperativeContainer(unittest.TestCase): self.assertListEqual(res8.shape, [5, 3**3]) res8.backward() + model4 = MyLayer(layerlist[:3]) + model4.layerlist[-1] = fluid.dygraph.Linear(4, 5) + res9 = model4(x) + self.assertListEqual(res9.shape, [5, 5]) + del model4.layerlist[-1] + res10 = model4(x) + self.assertListEqual(res10.shape, [5, 4]) + model4.layerlist.insert(-1, fluid.dygraph.Linear(2, 2)) + res11 = model4(x) + self.assertListEqual(res11.shape, [5, 4]) + res11.backward() + def test_layer_list(self): self.layer_list(True) self.layer_list(False) -- GitLab