diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c6e7dda9699baa47c47b8a81476ea5f211443ba0..e1b1e61c23cf20df5233871c81afa508bbc06e65 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3762,6 +3762,9 @@ void SqueezeInferMeta(const MetaTensor& x, if (!config.is_runtime && axes.FromTensor()) { // compile time infershape, set all elements to -1. int output_size = x.dims().size() - axes.GetData().size(); + if (x.dims().size() == 0 && output_size == -1) { + output_size = 0; + } std::vector vec_out_dims(output_size, -1); out->set_dims(phi::make_ddim(vec_out_dims)); } else { diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 37718c25c6c1f15915dfddcabbe9e8ffdb158cb4..1b45c73973ba34eecf1d6c191333818847bf1402 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -580,7 +580,7 @@ def _as_lodtensor(data, place, dtype=None): else dtype ) if np.isscalar(data): - data = np.array([data]).astype(dtype) + data = np.array(data).astype(dtype) elif isinstance(data, (list, tuple)): data = np.array(data) if data.dtype == np.object_: diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 9765ebda405b80d391460b6b6b0169ea3a01626d..0b8c4aa8eae4215545a3953fe96a05040e2f557a 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -96,7 +96,7 @@ def create_paddle_case(op_type, callback): paddle.enable_static() with program_guard(Program(), Program()): x = paddle.static.data(name='x', shape=[4], dtype='int64') - y = paddle.static.data(name='y', shape=[1], dtype='int64') + y = paddle.static.data(name='y', shape=[], dtype='int64') op = eval("paddle.%s" % (self.op_type)) out = op(x, y) exe = fluid.Executor(self.place) diff --git a/python/paddle/fluid/tests/unittests/test_data.py b/python/paddle/fluid/tests/unittests/test_data.py index 51f01ed8bef3303e5d397afe9222345a9b7f656e..763637fd0edb7d1c0df5cfb87fc0c494e2864a59 100644 --- a/python/paddle/fluid/tests/unittests/test_data.py +++ b/python/paddle/fluid/tests/unittests/test_data.py @@ -31,6 +31,13 @@ class TestApiStaticDataError(unittest.TestCase): x3 = paddle.static.data(name="x3", shape=[2, 25]) self.assertEqual(x3.dtype, core.VarDesc.VarType.FP64) + def test_0D(self): + with program_guard(Program(), Program()): + x1 = paddle.static.data(name="x1_0D", shape=[]) + self.assertEqual(x1.dtype, core.VarDesc.VarType.FP32) + x2 = paddle.static.data(name="x2_0D", shape=(), dtype="bool") + self.assertEqual(x2.dtype, core.VarDesc.VarType.BOOL) + def test_error(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_deg2rad.py b/python/paddle/fluid/tests/unittests/test_deg2rad.py index ad9e214cb3f5016604899d8a0287092b44e0f914..0f038e86f2522cdf4208e6515c8e48484dcb331a 100644 --- a/python/paddle/fluid/tests/unittests/test_deg2rad.py +++ b/python/paddle/fluid/tests/unittests/test_deg2rad.py @@ -66,7 +66,7 @@ class TestDeg2radAPI(unittest.TestCase): class TestDeg2radAPI2(TestDeg2radAPI): # Test input data type is int def setUp(self): - self.x_np = 180 + self.x_np = [180] self.x_shape = [1] self.out_np = np.pi self.x_dtype = 'int64' diff --git a/python/paddle/fluid/tests/unittests/test_executor_feed_non_tensor.py b/python/paddle/fluid/tests/unittests/test_executor_feed_non_tensor.py index eaf8857f6acf48ccfb87c4a0be6ac6648e309c32..c1389ac2e1a67438c5e439ebc9666baf5d6cbb72 100644 --- a/python/paddle/fluid/tests/unittests/test_executor_feed_non_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_executor_feed_non_tensor.py @@ -22,7 +22,7 @@ from paddle import fluid class TestExecutor(unittest.TestCase): def net(self): - lr = paddle.static.data(name="lr", shape=[1], dtype='float32') + lr = paddle.static.data(name="lr", shape=[], dtype='float32') x = paddle.static.data(name="x", shape=[None, 1], dtype='float32') y = paddle.static.data(name="y", shape=[None, 1], dtype='float32') y_predict = paddle.static.nn.fc(x, size=1) diff --git a/python/paddle/fluid/tests/unittests/test_gcd.py b/python/paddle/fluid/tests/unittests/test_gcd.py index 2ed8438ee56835d035afc24d65fda1e2e7a0e97d..738c040ea98908e79304682a51581490d10efa81 100644 --- a/python/paddle/fluid/tests/unittests/test_gcd.py +++ b/python/paddle/fluid/tests/unittests/test_gcd.py @@ -25,8 +25,8 @@ paddle.enable_static() class TestGcdAPI(unittest.TestCase): def setUp(self): - self.x_np = 12 - self.y_np = 20 + self.x_np = [12] + self.y_np = [20] self.x_shape = [1] self.y_shape = [1] @@ -81,14 +81,14 @@ class TestGcdAPI3(TestGcdAPI): def setUp(self): self.x_np = 0 self.y_np = 20 - self.x_shape = [1] - self.y_shape = [1] + self.x_shape = [] + self.y_shape = [] class TestGcdAPI4(TestGcdAPI): def setUp(self): - self.x_np = 0 - self.y_np = 0 + self.x_np = [0] + self.y_np = [0] self.x_shape = [1] self.y_shape = [1] @@ -97,5 +97,5 @@ class TestGcdAPI5(TestGcdAPI): def setUp(self): self.x_np = 12 self.y_np = -20 - self.x_shape = [1] - self.y_shape = [1] + self.x_shape = [] + self.y_shape = [] diff --git a/python/paddle/fluid/tests/unittests/test_lcm.py b/python/paddle/fluid/tests/unittests/test_lcm.py index bb846a80d6ab27a2751321f37e9502308d0b7217..478853d8bab8ff4262513e72225777477b237bc7 100644 --- a/python/paddle/fluid/tests/unittests/test_lcm.py +++ b/python/paddle/fluid/tests/unittests/test_lcm.py @@ -27,8 +27,8 @@ class TestLcmAPI(unittest.TestCase): def setUp(self): self.x_np = 12 self.y_np = 20 - self.x_shape = [1] - self.y_shape = [1] + self.x_shape = [] + self.y_shape = [] def test_static_graph(self): startup_program = fluid.Program() @@ -81,14 +81,14 @@ class TestLcmAPI3(TestLcmAPI): def setUp(self): self.x_np = 0 self.y_np = 20 - self.x_shape = [1] - self.y_shape = [1] + self.x_shape = [] + self.y_shape = [] class TestLcmAPI4(TestLcmAPI): def setUp(self): - self.x_np = 0 - self.y_np = 0 + self.x_np = [0] + self.y_np = [0] self.x_shape = [1] self.y_shape = [1] @@ -97,5 +97,5 @@ class TestLcmAPI5(TestLcmAPI): def setUp(self): self.x_np = 12 self.y_np = -20 - self.x_shape = [1] - self.y_shape = [1] + self.x_shape = [] + self.y_shape = [] diff --git a/python/paddle/fluid/tests/unittests/test_put_along_axis_op.py b/python/paddle/fluid/tests/unittests/test_put_along_axis_op.py index a2085cc416a1953319c8a0b9c6bfeda2249e825b..2834cda0be6966a4ee18587c76835988078572a8 100644 --- a/python/paddle/fluid/tests/unittests/test_put_along_axis_op.py +++ b/python/paddle/fluid/tests/unittests/test_put_along_axis_op.py @@ -141,7 +141,7 @@ class TestPutAlongAxisAPI(unittest.TestCase): self.place = [paddle.CPUPlace()] self.axis = 0 self.value_np = 99.0 - self.value_shape = [1] + self.value_shape = [] self.x_feed = copy.deepcopy(self.x_np) if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) @@ -240,7 +240,7 @@ class TestPutAlongAxisAPICase2(TestPutAlongAxisAPI): self.place = [paddle.CPUPlace()] self.axis = 0 self.value_np = 99.0 - self.value_shape = [1] + self.value_shape = [] self.x_feed = copy.deepcopy(self.x_np) if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) @@ -258,7 +258,7 @@ class TestPutAlongAxisAPICase3(TestPutAlongAxisAPI): self.place = [paddle.CPUPlace()] self.axis = 0 self.value_np = 99.0 - self.value_shape = [1] + self.value_shape = [] self.x_feed = copy.deepcopy(self.x_np) if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) diff --git a/python/paddle/fluid/tests/unittests/test_rad2deg.py b/python/paddle/fluid/tests/unittests/test_rad2deg.py index b6fceed10be06fcdb223c24903136aad5e00a4a8..710d77f0d9fc1b4525f663fe53a68bab06935b10 100644 --- a/python/paddle/fluid/tests/unittests/test_rad2deg.py +++ b/python/paddle/fluid/tests/unittests/test_rad2deg.py @@ -65,7 +65,7 @@ class TestRad2degAPI(unittest.TestCase): class TestRad2degAPI2(TestRad2degAPI): def setUp(self): - self.x_np = np.pi / 2 + self.x_np = [np.pi / 2] self.x_shape = [1] self.out_np = 90 self.x_dtype = 'float32' @@ -83,7 +83,7 @@ class TestRad2degAPI2(TestRad2degAPI): class TestRad2degAPI3(TestRad2degAPI): # Test input data type is int def setUp(self): - self.x_np = 1 + self.x_np = [1] self.x_shape = [1] self.out_np = 180 / np.pi self.x_dtype = 'int64' diff --git a/python/paddle/fluid/tests/unittests/test_trapezoid.py b/python/paddle/fluid/tests/unittests/test_trapezoid.py index 226f40db91ab3d67cc72ad70dc7f16b138d8723a..de18e75512717c328d17ad32baf160041ba625c7 100644 --- a/python/paddle/fluid/tests/unittests/test_trapezoid.py +++ b/python/paddle/fluid/tests/unittests/test_trapezoid.py @@ -83,7 +83,7 @@ class TestTrapezoidAPI(unittest.TestCase): ) if self.dx is not None: dx = paddle.static.data( - name="dx", shape=[1], dtype='float32' + name="dx", shape=[], dtype='float32' ) exe = paddle.static.Executor(place) diff --git a/python/paddle/fluid/tests/unittests/test_unbind_op.py b/python/paddle/fluid/tests/unittests/test_unbind_op.py index 4d5c4f9beefae323b8db9fa01614fac208e37496..989eb43b0504d0e0ee58531dbd20816a404883f3 100644 --- a/python/paddle/fluid/tests/unittests/test_unbind_op.py +++ b/python/paddle/fluid/tests/unittests/test_unbind_op.py @@ -29,7 +29,7 @@ class TestUnbind(unittest.TestCase): x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = tensor.unbind(input=x_1, axis=0) input_1 = np.random.random([2, 3]).astype("float32") - axis = paddle.static.data(shape=[1], dtype='int32', name='axis') + axis = paddle.static.data(shape=[], dtype='int32', name='axis') exe = fluid.Executor(place=fluid.CPUPlace()) [res_1, res_2] = exe.run( @@ -87,7 +87,7 @@ class TestLayersUnbind(unittest.TestCase): x_1 = paddle.static.data(shape=[2, 3], dtype='float32', name='x_1') [out_0, out_1] = paddle.unbind(input=x_1, axis=0) input_1 = np.random.random([2, 3]).astype("float32") - axis = paddle.static.data(shape=[1], dtype='int32', name='axis') + axis = paddle.static.data(shape=[], dtype='int32', name='axis') exe = fluid.Executor(place=fluid.CPUPlace()) [res_1, res_2] = exe.run( diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index fa6e7ce19db84738324a7fa9c23dd55644005a03..ca9f369d01cf697bd63ebc1a4bcedcbe37bda4fd 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -17,6 +17,7 @@ # 0D Tensor's shape is always [], numel is 1 # which can be created by paddle.rand([]) +import os import unittest import numpy as np @@ -1712,6 +1713,75 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [2]) + def test_gather_nd(self): + x1 = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) + x2 = paddle.to_tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False + ) + + index1 = paddle.full([1], 1, 'int64') + index2 = paddle.full([2], 1, 'int64') + + out1 = paddle.gather_nd(x1, index1) + out2 = paddle.gather_nd(x2, index2) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_array_equal(out1, np.array(3.0)) + np.testing.assert_array_equal(out2, np.array(5.0)) + self.assertEqual(x1.grad.shape, [5]) + self.assertEqual(x2.grad.shape, [2, 3]) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + + def test_einsum(self): + os.environ['FLAGS_new_einsum'] = "0" + x = paddle.rand([5]) + # sum + out1 = paddle.einsum('i->', x) + expect1 = np.einsum('i->', x) + # dot + out2 = paddle.einsum('i,i->', x, x) + expect2 = np.einsum('i,i->', x, x) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_allclose(out1, expect1, rtol=1e-03) + np.testing.assert_allclose(out2, expect2, rtol=1e-03) + + def test_einsum_V2(self): + os.environ['FLAGS_new_einsum'] = "1" + x = paddle.rand([5]) + # sum + out1 = paddle.einsum('i->', x) + expect1 = np.einsum('i->', x) + # dot + out2 = paddle.einsum('i,i->', x, x) + expect2 = np.einsum('i,i->', x, x) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_allclose(out1, expect1, rtol=1e-03) + np.testing.assert_allclose(out2, expect2, rtol=1e-03) + def test_scatter_1D(self): x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) index = paddle.full([], 2, 'int64') @@ -2324,6 +2394,56 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out1.shape, [2, 3, 12, 12]) self.assertEqual(input_x.grad.shape, [2, 3, 6, 6]) + def test_unstack(self): + x1 = paddle.full([1], 0) + x2 = paddle.full([2], 2) + x1.retain_grads() + x2.retain_grads() + x1.stop_gradient = False + x2.stop_gradient = False + + [out1] = paddle.unstack(x1, 0) + out1.retain_grads() + out1.backward() + [out2_1, out2_2] = paddle.unstack(x2, 0) + out2 = paddle.add_n([out2_1, out2_2]) + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1.numpy(), 0) + + self.assertEqual(out2_1.shape, []) + self.assertEqual(out2_1.numpy(), 2) + self.assertEqual(out2_2.shape, []) + self.assertEqual(out2_2.numpy(), 2) + self.assertEqual(x2.grad.shape, [2]) + + def test_unbind(self): + x1 = paddle.full([1], 0) + x2 = paddle.full([2], 2) + x1.retain_grads() + x2.retain_grads() + x1.stop_gradient = False + x2.stop_gradient = False + + [out1] = paddle.unbind(x1, 0) + out1.retain_grads() + out1.backward() + [out2_1, out2_2] = paddle.unbind(x2, 0) + out2 = paddle.add_n([out2_1, out2_2]) + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1.numpy(), 0) + + self.assertEqual(out2_1.shape, []) + self.assertEqual(out2_1.numpy(), 2) + self.assertEqual(out2_2.shape, []) + self.assertEqual(out2_2.numpy(), 2) + self.assertEqual(x2.grad.shape, [2]) + def test_maseked_select(self): x = paddle.rand([]) x.stop_gradient = False @@ -2338,6 +2458,26 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.numpy(), 1) + def test_squeeze(self): + x1 = paddle.full([], 2) + x1.stop_gradient = False + x1.retain_grads() + out1 = paddle.squeeze(x1, axis=0) + out1.retain_grads() + out1.backward() + self.assertEqual(out1.shape, []) + self.assertEqual(x1.grad.shape, []) + + x2 = paddle.full([], 3) + x3 = paddle.full([1], 0, dtype='int32') + x2.stop_gradient = False + x2.retain_grads() + out2 = paddle.squeeze(x2, axis=x3) + out2.retain_grads() + out2.backward() + self.assertEqual(out2.shape, []) + self.assertEqual(x2.grad.shape, []) + def test_unsqueeze(self): x1 = paddle.full([], 2) x1.stop_gradient = False @@ -3713,6 +3853,42 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[1].shape, (2, 3)) self.assertEqual(res[2].shape, (2,)) + @prog_scope() + def test_gather_nd(self): + x1 = paddle.full([10], 1.0, 'float32') + x1.stop_gradient = False + x2 = paddle.full([2, 3], 1.0, 'float32') + x2.stop_gradient = False + + index1 = paddle.full([1], 1, 'int64') + index2 = paddle.full([2], 1, 'int64') + + out1 = paddle.gather_nd(x1, index1) + out2 = paddle.gather_nd(x2, index2) + paddle.static.append_backward(out1.sum()) + paddle.static.append_backward(out2.sum()) + + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + fetch_list=[ + out1, + out2, + x1.grad_name, + x2.grad_name, + out1.grad_name, + out2.grad_name, + ], + ) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + np.testing.assert_array_equal(res[0], 1.0) + np.testing.assert_array_equal(res[1], 1.0) + self.assertEqual(res[2].shape, (10,)) + self.assertEqual(res[3].shape, (2, 3)) + self.assertEqual(res[4].shape, ()) + self.assertEqual(res[5].shape, ()) + @prog_scope() def test_scatter_1D(self): x = paddle.full([10], 1.0, 'float32') @@ -4329,6 +4505,50 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res1[0].shape, (2, 3, 12, 12)) self.assertEqual(res1[1].shape, (2, 3, 6, 6)) + @prog_scope() + def test_unstack(self): + x1 = paddle.full([1], 0, 'float32') + x1.stop_gradient = False + out1 = paddle.unstack(x1, 0) + out1 = paddle.add_n(out1) + paddle.static.append_backward(out1) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (1,)) + + x2 = paddle.full([2], 2, 'float32') + x2.stop_gradient = False + out2 = paddle.unstack(x2, 0) + out2_sum = paddle.add_n(out2) + paddle.static.append_backward(out2_sum) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (2,)) + + @prog_scope() + def test_unbind(self): + x1 = paddle.full([1], 0, 'float32') + x1.stop_gradient = False + out1 = paddle.unbind(x1, 0) + out1 = paddle.add_n(out1) + paddle.static.append_backward(out1) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out1, x1.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (1,)) + + x2 = paddle.full([2], 2, 'float32') + x2.stop_gradient = False + out2 = paddle.unbind(x2, 0) + out2_sum = paddle.add_n(out2) + paddle.static.append_backward(out2_sum) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, feed={}, fetch_list=[out2_sum, x2.grad_name]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (2,)) + @prog_scope() def test_maseked_select(self): x = paddle.rand([]) @@ -4345,6 +4565,34 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[3].shape, ()) self.assertEqual(res[3], 1) + @prog_scope() + def test_squeeze(self): + x1 = paddle.full([], 2) + x1.stop_gradient = False + out1 = paddle.squeeze(x1, axis=0) + paddle.static.append_backward(out1.sum()) + + x2 = paddle.full([], 3) + x3 = paddle.full([], 0, dtype='int32') + x2.stop_gradient = False + out2 = paddle.squeeze(x2, axis=x3) + paddle.static.append_backward(out2.sum()) + + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + fetch_list=[ + out1, + out2, + x1.grad_name, + x2.grad_name, + ], + ) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + self.assertEqual(res[3].shape, ()) + @prog_scope() def test_unsqueeze(self): x1 = paddle.full([], 2) @@ -4403,6 +4651,39 @@ class TestSundryAPIStatic(unittest.TestCase): res = self.exe.run(prog, feed={"x": x_tensor}, fetch_list=[out]) self.assertEqual(res[0].shape, (3, 4, 2)) + @prog_scope() + def test_static_data(self): + x1 = paddle.static.data(name="x1", shape=[]) + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + feed={ + "x1": np.array(1.0, dtype='float32'), + }, + fetch_list=[ + x1.name, + ], + ) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], np.array(1.0)) + + x2 = paddle.static.data(name="x2", shape=[]) + x3 = paddle.static.data(name="x3", shape=[]) + y = x2 + x3 + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, + feed={ + "x2": 100.5, + "x3": 200.5, + }, + fetch_list=[ + y.name, + ], + ) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 301.0) + @prog_scope() def test_prelu(self): x1 = paddle.full([], 1.0, 'float32') diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 7ab104dde94b0e171451be6431b534fe7dd054fe..082300763740a0d0e2484f6fd11acfc5c737f5e3 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -966,8 +966,8 @@ def einsum(equation, *operands): # dot print(paddle.einsum('i,i->', x, x)) - # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [1.45936954]) + # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # 1.45936954) # outer print(paddle.einsum("i,j->ij", x, y))