diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 6957dd8c5e30c428904940e5976be1143d75fbb9..757b1a2da95b99ce90faf82fa3f4f9eeef98ab39 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6619,6 +6619,9 @@ class EagerParamBase(_core_eager_eagertensor): name = kwargs.get('name', unique_name.generate('_eager_param_base')) + if isinstance(shape, core.eager.Tensor): + shape = shape.numpy() + super(EagerParamBase, self).__init__( dtype if dtype else core.VarDesc.VarType.FP32, list(shape) diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index 600a49b2332beed17954803e87cf57e6cd0bdbfc..bb8c6346eb5a54f4ab67b54f9be74df9603da108 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -279,6 +279,16 @@ class EagerVariablePropertiesAndMethodsTestCase(unittest.TestCase): "The type of trainable MUST be bool, but the type is /*"): eager_param.trainable = "False" + eager_param_2 = EagerParamBase( + shape=paddle.shape(paddle.to_tensor([1, 2, 3, 4])), dtype="float32") + self.assertTrue(eager_param_2.trainable) + eager_param_2.trainable = False + self.assertFalse(eager_param_2.trainable) + with self.assertRaisesRegexp( + ValueError, + "The type of trainable MUST be bool, but the type is /*"): + eager_param_2.trainable = "False" + def test_constructor(self): print("Test_constructor") paddle.set_device("cpu")