From 823f499a8ad374da79564849786e1a3757425468 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Sun, 7 Feb 2021 10:35:25 +0800 Subject: [PATCH] fix a bug of Sequential::__getitem__ (#30899) * fix a bug of Sequential::__getitem__, test=develop * add testcase, test=develop --- python/paddle/fluid/dygraph/container.py | 11 ++++- python/paddle/fluid/tests/test_sequential.py | 43 ++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/test_sequential.py diff --git a/python/paddle/fluid/dygraph/container.py b/python/paddle/fluid/dygraph/container.py index bfcb43f5f67..dd04b107204 100644 --- a/python/paddle/fluid/dygraph/container.py +++ b/python/paddle/fluid/dygraph/container.py @@ -67,7 +67,16 @@ class Sequential(Layer): self.add_sublayer(str(idx), layer) def __getitem__(self, name): - return self._sub_layers[str(name)] + if isinstance(name, slice): + return self.__class__(*(list(self._sub_layers.values())[name])) + else: + if name >= len(self._sub_layers): + raise IndexError('index {} is out of range'.format(name)) + elif name < 0 and name >= -len(self._sub_layers): + name += len(self._sub_layers) + elif name < -len(self._sub_layers): + raise IndexError('index {} is out of range'.format(name)) + return self._sub_layers[str(name)] def __setitem__(self, name, layer): assert isinstance(layer, Layer) diff --git a/python/paddle/fluid/tests/test_sequential.py b/python/paddle/fluid/tests/test_sequential.py new file mode 100644 index 00000000000..7446bb83841 --- /dev/null +++ b/python/paddle/fluid/tests/test_sequential.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle + + +class TestDataFeeder(unittest.TestCase): + def test_lod_level_1_converter(self): + sequential = paddle.nn.Sequential() + + for i in range(10): + sequential.add_sublayer(str(i), paddle.nn.Linear(i + 1, i + 1)) + + for item in sequential: + tmp = item + + tmp = sequential[3:5] + self.assertEqual(len(tmp), 2) + + tmp = sequential[-1] + self.assertEqual(tmp, sequential[9]) + + with self.assertRaises(IndexError): + tmp = sequential[10] + + with self.assertRaises(IndexError): + tmp = sequential[-11] + + +if __name__ == '__main__': + unittest.main() -- GitLab