未验证 提交 4d32f417 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support EagerParamBase init by 'shape'(Tensor) (#43045)

上级 6d78524c
...@@ -6619,6 +6619,9 @@ class EagerParamBase(_core_eager_eagertensor): ...@@ -6619,6 +6619,9 @@ class EagerParamBase(_core_eager_eagertensor):
name = kwargs.get('name', unique_name.generate('_eager_param_base')) name = kwargs.get('name', unique_name.generate('_eager_param_base'))
if isinstance(shape, core.eager.Tensor):
shape = shape.numpy()
super(EagerParamBase, self).__init__( super(EagerParamBase, self).__init__(
dtype if dtype else core.VarDesc.VarType.FP32, dtype if dtype else core.VarDesc.VarType.FP32,
list(shape) list(shape)
......
...@@ -279,6 +279,16 @@ class EagerVariablePropertiesAndMethodsTestCase(unittest.TestCase): ...@@ -279,6 +279,16 @@ class EagerVariablePropertiesAndMethodsTestCase(unittest.TestCase):
"The type of trainable MUST be bool, but the type is /*"): "The type of trainable MUST be bool, but the type is /*"):
eager_param.trainable = "False" eager_param.trainable = "False"
eager_param_2 = EagerParamBase(
shape=paddle.shape(paddle.to_tensor([1, 2, 3, 4])), dtype="float32")
self.assertTrue(eager_param_2.trainable)
eager_param_2.trainable = False
self.assertFalse(eager_param_2.trainable)
with self.assertRaisesRegexp(
ValueError,
"The type of trainable MUST be bool, but the type is /*"):
eager_param_2.trainable = "False"
def test_constructor(self): def test_constructor(self):
print("Test_constructor") print("Test_constructor")
paddle.set_device("cpu") paddle.set_device("cpu")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册