From cb36478a3662ec52fa7b59955d9738acd84cf2a7 Mon Sep 17 00:00:00 2001 From: songyouwei Date: Fri, 10 Apr 2020 14:29:49 +0800 Subject: [PATCH] add LayerList insert and extend (#23377) * add LayerList insert and extend test=develop * add index range check test=develop * add sample codes test=develop * refine sample code test=develop --- python/paddle/fluid/dygraph/container.py | 59 +++++++++++++++++++ ...=> test_imperative_container_layerlist.py} | 18 +++++- 2 files changed, 75 insertions(+), 2 deletions(-) rename python/paddle/fluid/tests/unittests/{test_imperative_layerlist.py => test_imperative_container_layerlist.py} (76%) diff --git a/python/paddle/fluid/dygraph/container.py b/python/paddle/fluid/dygraph/container.py index 9151e8e04e..8a8787da3a 100644 --- a/python/paddle/fluid/dygraph/container.py +++ b/python/paddle/fluid/dygraph/container.py @@ -236,6 +236,65 @@ class LayerList(Layer): Parameters: sublayer (Layer): sublayer to append + + Examples: + .. code-block:: python + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + linears = fluid.dygraph.LayerList([fluid.dygraph.Linear(10, 10) for i in range(10)]) + another = fluid.dygraph.Linear(10, 10) + linears.append(another) + print(len(linears)) # 11 """ self.add_sublayer(str(len(self)), sublayer) return self + + def insert(self, index, sublayer): + """ + Insert a sublayer before a given index in the list. + + Parameters: + index (int): index to insert. + sublayer (Layer): sublayer to insert + + Examples: + .. code-block:: python + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + linears = fluid.dygraph.LayerList([fluid.dygraph.Linear(10, 10) for i in range(10)]) + another = fluid.dygraph.Linear(10, 10) + linears.insert(3, another) + print(linears[3] is another) # True + """ + assert isinstance(index, int) and \ + 0 <= index < len(self._sub_layers), \ + "index should be an integer in range [0, len(self))" + for i in range(len(self._sub_layers), index, -1): + self._sub_layers[str(i)] = self._sub_layers[str(i - 1)] + self._sub_layers[str(index)] = sublayer + + def extend(self, sublayers): + """ + Appends sublayers to the end of the list. + + Parameters: + sublayers (iterable of Layer): iterable of sublayers to append + + Examples: + .. code-block:: python + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + linears = fluid.dygraph.LayerList([fluid.dygraph.Linear(10, 10) for i in range(10)]) + another_list = fluid.dygraph.LayerList([fluid.dygraph.Linear(10, 10) for i in range(5)]) + linears.extend(another_list) + print(len(linears)) # 15 + print(another_list[0] is linears[10]) # True + """ + offset = len(self) + for i, sublayer in enumerate(sublayers): + idx = str(offset + i) + self.add_sublayer(idx, sublayer) + return self diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layerlist.py b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py similarity index 76% rename from python/paddle/fluid/tests/unittests/test_imperative_layerlist.py rename to python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py index 57509692fc..610f50603d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layerlist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_container_layerlist.py @@ -30,8 +30,8 @@ class MyLayer(fluid.Layer): return x -class TestImperativeContainerParameterList(unittest.TestCase): - def test_paramter_list(self): +class TestImperativeContainer(unittest.TestCase): + def test_layer_list(self): data_np = np.random.uniform(-1, 1, [5, 1]).astype('float32') with fluid.dygraph.guard(): x = fluid.dygraph.to_variable(data_np) @@ -61,6 +61,20 @@ class TestImperativeContainerParameterList(unittest.TestCase): self.assertListEqual(res6.shape, [5, 2**(0 + 1)]) res6.backward() + model3 = MyLayer(layerlist[:-2]) + model3.layerlist.append(fluid.dygraph.Linear(3, 1)) + model3.layerlist.insert(size - 2, + fluid.dygraph.Linear(2**(size - 2), 3)) + res7 = model3(x) + self.assertListEqual(res7.shape, [5, 1]) + to_be_extended = [ + fluid.dygraph.Linear(3**i, 3**(i + 1)) for i in range(3) + ] + model3.layerlist.extend(to_be_extended) + res8 = model3(x) + self.assertListEqual(res8.shape, [5, 3**3]) + res8.backward() + if __name__ == '__main__': unittest.main() -- GitLab