From 4d32f417a435446d06541ae951edc2404e97e74c Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 27 May 2022 21:36:35 +0800 Subject: [PATCH] [Eager] Support EagerParamBase init by 'shape'(Tensor) (#43045) --- python/paddle/fluid/framework.py | 3 +++ .../fluid/tests/unittests/test_egr_python_api.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 6957dd8c5e..757b1a2da9 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6619,6 +6619,9 @@ class EagerParamBase(_core_eager_eagertensor): name = kwargs.get('name', unique_name.generate('_eager_param_base')) + if isinstance(shape, core.eager.Tensor): + shape = shape.numpy() + super(EagerParamBase, self).__init__( dtype if dtype else core.VarDesc.VarType.FP32, list(shape) diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index 600a49b233..bb8c6346eb 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -279,6 +279,16 @@ class EagerVariablePropertiesAndMethodsTestCase(unittest.TestCase): "The type of trainable MUST be bool, but the type is /*"): eager_param.trainable = "False" + eager_param_2 = EagerParamBase( + shape=paddle.shape(paddle.to_tensor([1, 2, 3, 4])), dtype="float32") + self.assertTrue(eager_param_2.trainable) + eager_param_2.trainable = False + self.assertFalse(eager_param_2.trainable) + with self.assertRaisesRegexp( + ValueError, + "The type of trainable MUST be bool, but the type is /*"): + eager_param_2.trainable = "False" + def test_constructor(self): print("Test_constructor") paddle.set_device("cpu") -- GitLab