From 0b8793184695cc9b352140a5679fa8cf5ac77b99 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 22 Jun 2022 10:08:05 +0800 Subject: [PATCH] [FIx bug]layer to 'NoneType' object has no attribute 'place' (#43597) (#43717) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit bug: 当class Layer的_buffers中有参数为None的时候,调用to()方法将会报layer to 'NoneType' object has no attribute 'place'的错误。 修复方法: to()方法增加对_buffers中None类型参数的判断,如果为None,跳过该参数的处理。 --- python/paddle/fluid/dygraph/layers.py | 3 ++- python/paddle/fluid/tests/unittests/test_base_layer.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 088fed03c35..6392b0d1151 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1582,7 +1582,8 @@ class Layer(object): blocking) for key, buf in self._buffers.items(): - self._buffers[key] = func(buf, device, dtype, blocking) + if buf is not None: + self._buffers[key] = func(buf, device, dtype, blocking) self._dtype = dtype diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py index 3bdd03b3212..ab5ce774692 100644 --- a/python/paddle/fluid/tests/unittests/test_base_layer.py +++ b/python/paddle/fluid/tests/unittests/test_base_layer.py @@ -544,16 +544,25 @@ class TestLayerTo(unittest.TestCase): else: self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) + def func_test_to_api_none_buffer(self): + model = paddle.nn.Linear(2, 4) + buffer = None + model.register_buffer("buf_name", buffer, persistable=True) + model.to(dtype='float64') + self.assertEqual(model._buffers['buf_name'], None) + def test_main(self): with _test_eager_guard(): self.funcsetUp() self.func_test_to_api() self.func_test_to_api_paddle_dtype() self.func_test_to_api_numpy_dtype() + self.func_test_to_api_none_buffer() self.funcsetUp() self.func_test_to_api() self.func_test_to_api_paddle_dtype() self.func_test_to_api_numpy_dtype() + self.func_test_to_api_none_buffer() if __name__ == '__main__': -- GitLab