From 4f57da5fa6866a81f47ba90a8c9573648bdff11d Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Thu, 17 Nov 2022 11:19:58 +0800 Subject: [PATCH] [Zero-Dim] temporarily revert create_scalar due to input 0D is not fully supported (#48058) --- python/paddle/fluid/layers/math_op_patch.py | 3 ++- python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index fb397943434..f9ba6498671 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -99,7 +99,8 @@ def monkey_patch_variable(): return var def create_scalar(block, value, dtype): - return create_tensor(block, value, dtype, shape=[]) + # TODO(zhouwei): will change to [] which is 0-D Tensor + return create_tensor(block, value, dtype, shape=[1]) def create_tensor_with_batchsize(ref_var, value, dtype): assert isinstance(ref_var, Variable) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index c85f5aec42e..174172b026f 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -350,7 +350,7 @@ class TestBinaryAPI(unittest.TestCase): paddle.enable_static() - def test_static_unary(self): + def test_static_binary(self): paddle.enable_static() for api in binary_api_list + binary_api_list_without_grad: main_prog = fluid.Program() @@ -377,15 +377,19 @@ class TestBinaryAPI(unittest.TestCase): # Test runtime shape self.assertEqual(out_np.shape, ()) + # TODO(zhouwei): will open when create_scalar is [] # 2) x is 0D , y is scalar + ''' x = paddle.rand([]) y = 0.5 x.stop_gradient = False + print(api) if isinstance(api, dict): out = getattr(paddle.static.Variable, api['cls_method'])( x, y ) self.assertEqual(out.shape, ()) + ''' for api in binary_int_api_list_without_grad: main_prog = fluid.Program() -- GitLab