未验证 提交 a1a2054e 编写于 作者: Z zhangbo9674 提交者: GitHub

[FIx bug]layer to 'NoneType' object has no attribute 'place' (#43597)

* refine layer to

* refine code

* add ut
上级 c701e114
...@@ -1594,7 +1594,8 @@ class Layer(object): ...@@ -1594,7 +1594,8 @@ class Layer(object):
blocking) blocking)
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():
self._buffers[key] = func(buf, device, dtype, blocking) if buf is not None:
self._buffers[key] = func(buf, device, dtype, blocking)
self._dtype = dtype self._dtype = dtype
......
...@@ -559,16 +559,25 @@ class TestLayerTo(unittest.TestCase): ...@@ -559,16 +559,25 @@ class TestLayerTo(unittest.TestCase):
else: else:
self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase)) self.assertTrue(isinstance(p, paddle.fluid.framework.ParamBase))
def func_test_to_api_none_buffer(self):
model = paddle.nn.Linear(2, 4)
buffer = None
model.register_buffer("buf_name", buffer, persistable=True)
model.to(dtype='float64')
self.assertEqual(model._buffers['buf_name'], None)
def test_main(self): def test_main(self):
with _test_eager_guard(): with _test_eager_guard():
self.funcsetUp() self.funcsetUp()
self.func_test_to_api() self.func_test_to_api()
self.func_test_to_api_paddle_dtype() self.func_test_to_api_paddle_dtype()
self.func_test_to_api_numpy_dtype() self.func_test_to_api_numpy_dtype()
self.func_test_to_api_none_buffer()
self.funcsetUp() self.funcsetUp()
self.func_test_to_api() self.func_test_to_api()
self.func_test_to_api_paddle_dtype() self.func_test_to_api_paddle_dtype()
self.func_test_to_api_numpy_dtype() self.func_test_to_api_numpy_dtype()
self.func_test_to_api_none_buffer()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册