From f84ac449ad5912350a9a567438d0f4eea4dfa60a Mon Sep 17 00:00:00 2001
From: wangfengsheng1999 <129241980+wangfengsheng1999@users.noreply.github.com>
Date: Thu, 27 Apr 2023 05:00:44 -0700
Subject: [PATCH] [Cherry-Pick]Support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy (#53199)

* support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy

* test_dot_py

* test_dot_py
---
 paddle/phi/infermeta/binary.cc                 |   5 +-
 paddle/phi/infermeta/multiary.cc               |   2 +-
 paddle/phi/infermeta/ternary.cc                |   6 +-
 paddle/phi/infermeta/unary.cc                  |   2 +-
 paddle/phi/kernels/gpu/dot_kernel.cu           |   4 +-
 .../phi/kernels/impl/dot_grad_kernel_impl.h    |  12 +-
 .../tests/unittests/check_nan_inf_base.py      |   2 +-
 .../fluid/tests/unittests/test_dot_op.py       |  19 +-
 .../fluid/tests/unittests/test_layers.py       |   3 +-
 .../fluid/tests/unittests/test_nan_inf.py      |   8 +-
 .../tests/unittests/test_zero_dim_tensor.py    | 383 +++++++++++++++++-
 python/paddle/tensor/manipulation.py           |   3 -
 python/paddle/tensor/math.py                   |   3 +-
 test/cpp/phi/core/test_custom_kernel.cc        |   2 +-
 ...t2_int8_image_classification_comparison.py  |   2 +-
 15 files changed, 414 insertions(+), 42 deletions(-)

diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 054563559fb..c6ca65e273a 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -1152,8 +1152,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
                         "with input tensor Y: %s",
                         x_dims.to_str(),
                         y_dims.to_str()));
-
-  x_dims[x_dims.size() - 1] = 1;
+  std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
+  std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 1);
+  x_dims = phi::make_ddim(x_dims_vec_cut);
   out->set_dims(x_dims);
   out->set_dtype(x.dtype());
   out->set_layout(x.layout());
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 31aa85245eb..ea9588f0e45 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -479,7 +479,7 @@ void AucInferMeta(const MetaTensor& input,
       0,
       phi::errors::InvalidArgument("slide_steps must be natural number"));

-  auc->set_dims({1});
+  auc->set_dims(phi::make_ddim({}));
   auc->set_dtype(DataType::INT64);

   if (slide_steps) {
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 4877e16675e..64676770315 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -66,11 +66,11 @@ void AccuracyInferMeta(const MetaTensor& out,
                           label_dim[0]));
   }

-  accuracy->set_dims({1});
+  accuracy->set_dims(phi::make_ddim({}));
+  correct->set_dims(phi::make_ddim({}));
+  total->set_dims(phi::make_ddim({}));
   accuracy->set_dtype(out.dtype());
-  correct->set_dims({1});
   correct->set_dtype(out.dtype());
-  total->set_dims({1});
   total->set_dtype(out.dtype());
   accuracy->share_lod(out);
 }
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index c1ee2b5d4ec..4b0bc319fc4 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1839,7 +1839,7 @@ void InverseInferMeta(const MetaTensor& x, MetaTensor* out) {
 }

 void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out) {
-  out->set_dims(phi::make_ddim({1}));
+  out->set_dims(phi::make_ddim({}));
   out->set_dtype(DataType::BOOL);
 }

diff --git a/paddle/phi/kernels/gpu/dot_kernel.cu b/paddle/phi/kernels/gpu/dot_kernel.cu
index 5005f6390d2..72679b51899 100644
--- a/paddle/phi/kernels/gpu/dot_kernel.cu
+++ b/paddle/phi/kernels/gpu/dot_kernel.cu
@@ -32,7 +32,7 @@ void DotKernel(const Context& dev_ctx,
                const DenseTensor& y,
                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
-  if (1 == out->dims().size()) {
+  if (out->dims().size() == 0) {
     auto eigen_out = phi::EigenScalar<T>::From(*out);
     auto eigen_x = phi::EigenVector<T>::Flatten(x);
     auto eigen_y = phi::EigenVector<T>::Flatten(y);
@@ -40,7 +40,7 @@ void DotKernel(const Context& dev_ctx,

     auto& dev = *dev_ctx.eigen_device();
     eigen_out.device(dev) = (eigen_x * eigen_y).sum();
   } else {
-    auto eigen_out = phi::EigenMatrix<T>::From(*out);
+    auto eigen_out = phi::EigenVector<T>::From(*out);
     auto eigen_x = phi::EigenMatrix<T>::From(x);
     auto eigen_y = phi::EigenMatrix<T>::From(y);
diff --git a/paddle/phi/kernels/impl/dot_grad_kernel_impl.h b/paddle/phi/kernels/impl/dot_grad_kernel_impl.h
index 6bfe6d65245..add72749d39 100644
--- a/paddle/phi/kernels/impl/dot_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/dot_grad_kernel_impl.h
@@ -46,7 +46,7 @@ struct DotGradFunction> {
                   DenseTensor* tensor_dy) {
     VLOG(1) << "enable route";
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == tensor_dout->dims().size()) {
+    if (1 >= tensor_dout->dims().size()) {
       auto dout = EigenVector<T>::Flatten(*tensor_dout);

       if (tensor_dx) {
@@ -144,7 +144,7 @@ struct DotGradFunction> {
                   DenseTensor* tensor_dx,
                   DenseTensor* tensor_dy) {
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == tensor_dout->dims().size()) {
+    if (1 >= tensor_dout->dims().size()) {
       auto dout = EigenVector<T>::Flatten(*tensor_dout);
       if (tensor_dx) {
         auto y = EigenVector<T>::Flatten(*tensor_y);
@@ -236,7 +236,7 @@ struct DotDoubleGradFunction> {
     const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
     const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == tensor_dout->dims().size()) {
+    if (1 >= tensor_dout->dims().size()) {
       DenseTensor tensor_dout_help;
       auto& dev = *ctx.eigen_device();
       if (tensor_dx || tensor_dy) {
@@ -431,7 +431,7 @@ struct DotDoubleGradFunction> {
     const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
     const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == tensor_dout->dims().size()) {
+    if (1 >= tensor_dout->dims().size()) {
       auto& dev = *ctx.eigen_device();
       auto x = EigenVector<T>::Flatten(*tensor_x);
       auto y = EigenVector<T>::Flatten(*tensor_y);
@@ -621,7 +621,7 @@ struct DotTripleGradFunction> {
     const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
     const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == in_tensor_dout->dims().size()) {
+    if (1 >= in_tensor_dout->dims().size()) {
       auto& dev = *ctx.eigen_device();
       DenseTensor in_tensor_x_help = Conj(ctx, *in_tensor_x);
       DenseTensor in_tensor_y_help = Conj(ctx, *in_tensor_y);
@@ -1015,7 +1015,7 @@ struct DotTripleGradFunction> {
     const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
     const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (1 == in_tensor_dout->dims().size()) {
+    if (1 >= in_tensor_dout->dims().size()) {
       auto& dev = *ctx.eigen_device();
       bool d_dout_flag = false;
       bool d_ddx_flag = false;
diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py
index 13ab70165a3..4237987f997 100644
--- a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py
+++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py
@@ -97,7 +97,7 @@ def check(use_cuda):
             step += 1
             print(
                 'iter={:.0f},cost={},acc1={}'.format(
-                    step, outs[1][0], outs[2][0]
+                    step, outs[1][0], outs[2]
                 )
             )

diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py
index 5cb061c368b..9d0e0b1b15d 100644
--- a/python/paddle/fluid/tests/unittests/test_dot_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dot_op.py
@@ -106,8 +106,7 @@ class DotOpEmptyInput(unittest.TestCase):
         x = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
         y = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
         pd_out = paddle.dot(x, y)
-
-        self.assertEqual(pd_out.shape, (0, 1))
+        self.assertEqual(pd_out.shape, (0,))

     def test_3d_input_error(self):
         data = np.array([], dtype=np.float32)
@@ -127,7 +126,7 @@ class DotOpBatch(DotOp):
         self.y = (
             np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
         )
-        self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+        self.out = np.sum(self.x * self.y, axis=1)

     def test_check_grad_normal(self):
         self.check_grad(['X', 'Y'], 'Out')
@@ -180,7 +179,7 @@ class TestDygraph(unittest.TestCase):
             np.array([[2, 5], [6, 8]]).astype(np.float32)
         )
         np.testing.assert_array_equal(
-            paddle.dot(x1, y1).numpy(), np.array([[17], [58]])
+            paddle.dot(x1, y1).numpy(), np.array([17, 58])
         )


@@ -211,7 +210,7 @@ class TestComplexDotOp(OpTest):
         self.out = np.dot(self.x, self.y)

     def init_grad_input_output(self):
-        self.grad_out = np.ones(1, self.dtype) + 1j * np.ones(1, self.dtype)
+        self.grad_out = np.ones([], self.dtype) + 1j * np.ones([], self.dtype)
         self.grad_x = self.grad_out * np.conj(self.y)
         self.grad_y = self.grad_out * np.conj(self.x)

@@ -269,12 +268,10 @@ class TestComplexDotOp2D(OpTest):
         self.y = np.random.random((2, 100)).astype(
             self.dtype
         ) + 1j * np.random.random((2, 100)).astype(self.dtype)
-        self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)
+        self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)

     def init_grad_input_output(self):
-        self.grad_out = np.ones((2, 1), self.dtype) + 1j * np.ones(
-            (2, 1), self.dtype
-        )
+        self.grad_out = np.ones((2), self.dtype) + 1j * np.ones((2), self.dtype)
         self.grad_x = self._get_grad(self.grad_out, self.y)
         self.grad_y = self._get_grad(self.grad_out, self.x)

@@ -381,7 +378,7 @@ class DotFP16OpBatch(TestDotFP16Op):
         self.y = (
             np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
         )
-        self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+        self.out = np.sum(self.x * self.y, axis=1)


 @unittest.skipIf(
@@ -468,7 +465,7 @@ class DotBF16OpBatch(TestDotBF16Op):
         self.y = (
             np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12])
         )
-        self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+        self.out = np.sum(self.x * self.y, axis=1)

     def test_check_grad_normal(self):
         if core.is_compiled_with_cuda():
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index b68ee02d207..8d7726794fe 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1426,8 +1426,9 @@ class TestLayer(LayerTest):
             exe.run(fluid.default_startup_program())
             # x = np.random.rand(3, 32, 32).astype("float32")
             # y = np.array([[1], [0], [1]])
+
             static_out = exe.run(
-                feed={"input": x, "label": y}, fetch_list=result[0]
+                feed={"input": x, "label": y}, fetch_list=result
             )

         with self.dynamic_graph(force_to_use_cpu=True):
diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py
index 08bea2afa65..5e233e26a62 100644
--- a/python/paddle/fluid/tests/unittests/test_nan_inf.py
+++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py
@@ -50,11 +50,15 @@ class TestNanInf(unittest.TestCase):
         assert (out + err).find(b'There are NAN or INF') != -1

     def test_nan_inf_in_static_mode(self):
-        self._python_interp += " check_nan_inf_base.py"
+        self._python_interp += (
+            " " + os.path.dirname(__file__) + "/check_nan_inf_base.py"
+        )
         self.check_nan_inf()

     def test_nan_inf_in_dynamic_mode(self):
-        self._python_interp += " check_nan_inf_base_dygraph.py"
+        self._python_interp += (
+            " " + os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py"
+        )
         self.check_nan_inf()

diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 3b909b78225..586de505bed 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -991,6 +991,158 @@ class TestSundryAPI(unittest.TestCase):

         self.assertEqual(x1.grad.shape, [5])

+    def test_is_empty(self):
+        # 1) x is 0D
+        x = paddle.rand([])
+        out = paddle.is_empty(x)
+        self.assertFalse(out)
+        self.assertEqual(out.shape, [])
+
+        # 2) x is 1D
+        x = paddle.rand([5])
+        out = paddle.is_empty(x)
+        self.assertFalse(out)
+        self.assertEqual(out.shape, [])
+
+        # 3) x is ND
+        x = paddle.rand([3, 5])
+        out = paddle.is_empty(x)
+        self.assertFalse(out)
+        self.assertEqual(out.shape, [])
+
+        x = paddle.rand([3, 0, 5])
+        out = paddle.is_empty(x)
+        self.assertTrue(out)
+        self.assertEqual(out.shape, [])
+
+    def test_squeeze_(self):
+        # 1) x is 0D
+        x = paddle.rand([])
+        x.squeeze_(0)
+        self.assertEqual(x.shape, [])
+
+        # 2) x is 1D
+        x = paddle.rand([1])
+        x.squeeze_(0)
+        self.assertEqual(x.shape, [])
+
+        # 3)x is ND
+        x = paddle.rand([2, 1])
+        x.squeeze_(1)
+        self.assertEqual(x.shape, [2])
+
+    def test_as_complex(self):
+        x = paddle.rand([2])
+        x.stop_gradient = False
+        out = paddle.as_complex(x)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.shape, [2])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(x.grad.shape, [2])
+        self.assertEqual(out.grad.shape, [])
+
+    def test_dot(self):
+        # 1) x is 1D
+        x = paddle.rand([2])
+        x.stop_gradient = False
+        y = paddle.rand([2])
+        y.stop_gradient = False
+        out = paddle.dot(x, y)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [2])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        # 2) x is 2D
+        x1 = paddle.rand([2, 2])
+        x1.stop_gradient = False
+        y1 = paddle.rand([2, 2])
+        y1.stop_gradient = False
+        out1 = paddle.dot(x1, y1)
+        out1.retain_grads()
+        out1.backward()
+
+        self.assertEqual(x1.grad.shape, [2, 2])
+        self.assertEqual(out1.shape, [2])
+        self.assertEqual(out1.grad.shape, [2])
+
+    def test_inner(self):
+        # 0) input is 0D
+        x = paddle.rand([])
+        x.stop_gradient = False
+        y = paddle.rand([])
+        y.stop_gradient = False
+        out = paddle.inner(x, y)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        # 1) input is 1D
+        x = paddle.rand([2])
+        x.stop_gradient = False
+        y = paddle.rand([2])
+        y.stop_gradient = False
+        out = paddle.inner(x, y)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [2])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        # 2) input is 2D
+        x = paddle.rand([2, 3])
+        x.stop_gradient = False
+        y = paddle.rand([3, 3])
+        y.stop_gradient = False
+        out = paddle.inner(x, y)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [2, 3])
+        self.assertEqual(out.shape, [2, 3])
+        self.assertEqual(out.grad.shape, [2, 3])
+
+    def test_tensordot(self):
+
+        # 1) input is 1D
+        x = paddle.arange(10, dtype='float64')
+        x.stop_gradient = False
+        y = paddle.arange(10, dtype='float64')
+        y.stop_gradient = False
+        out = paddle.tensordot(x, y, axes=1)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [10])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+        # 2) input is 2D
+        x = paddle.arange(6, dtype='float64').reshape([2, 3])
+        y = paddle.arange(6, dtype='float64').reshape([2, 3])
+        x.stop_gradient = False
+        out = paddle.tensordot(x, y, axes=2)
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(x.grad.shape, [2, 3])
+        self.assertEqual(out.shape, [])
+        self.assertEqual(out.grad.shape, [])
+
+    def test_metric_accuracy(self):
+        x = paddle.full(shape=[2, 4], fill_value=0.25)
+        y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
+        out = paddle.metric.accuracy(input=x, label=y, k=1)
+        self.assertEqual(out.shape, [])
+
     def test_std(self):
         x = paddle.rand([])
         x.stop_gradient = False
@@ -1098,10 +1250,6 @@ class TestSundryAPI(unittest.TestCase):
     def test_is_tensor(self):
         self.assertTrue(paddle.is_tensor(self.x))

-    def test_is_empty(self):
-        x = paddle.rand([3, 0, 5])
-        self.assertTrue(paddle.is_empty(x))
-
     def test_isfinite(self):
         out = paddle.isfinite(self.x)
         np.testing.assert_array_equal(out.numpy(), np.array(True))
@@ -1160,7 +1308,8 @@ class TestSundryAPI(unittest.TestCase):

     def test_rank(self):
         # 1) x is 0D
-        out = paddle.rank(self.x)
+        x = paddle.rand([])
+        out = paddle.rank(x)
         self.assertEqual(out.shape, [])
         np.testing.assert_array_equal(out.numpy(), np.array(0))

@@ -2456,6 +2605,230 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[1].shape, ())
         self.assertEqual(res[2].shape, (5,))

+    @prog_scope()
+    def test_is_empty(self):
+        # 1) x is 0D
+        x1 = paddle.rand([])
+        out1 = paddle.is_empty(x1)
+
+        # 2) x is 1D
+        x2 = paddle.rand([5])
+        out2 = paddle.is_empty(x2)
+
+        # 3) x is ND
+        x3 = paddle.rand([3, 5])
+        out3 = paddle.is_empty(x3)
+
+        x4 = paddle.rand([3, 0, 5])
+        out4 = paddle.is_empty(x4)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[out1, out2, out3, out4],
+        )
+
+        self.assertEqual(res[0].shape, ())
+        self.assertFalse(bool(res[0]))
+        self.assertEqual(res[1].shape, ())
+        self.assertFalse(bool(res[1]))
+        self.assertEqual(res[2].shape, ())
+        self.assertFalse(bool(res[2]))
+        self.assertEqual(res[3].shape, ())
+        self.assertTrue(bool(res[3]))
+
+    @prog_scope()
+    def test_as_complex(self):
+        x = paddle.rand([2])
+        x.stop_gradient = False
+        out = paddle.as_complex(x)
+        self.assertEqual(x.shape, (2,))
+        self.assertEqual(out.shape, ())
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[x, out, x.grad_name, out.grad_name],
+        )
+
+        self.assertEqual(res[0].shape, (2,))
+        self.assertEqual(res[1].shape, ())
+        self.assertEqual(res[2].shape, (2,))
+        self.assertEqual(res[3].shape, ())
+
+    @prog_scope()
+    def test_dot(self):
+        # 1) x is 1d
+        x = paddle.rand([2])
+        x.stop_gradient = False
+        y = paddle.rand([2])
+        y.stop_gradient = False
+        out = paddle.dot(x, y)
+
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[x, x.grad_name, out, out.grad_name],
+        )
+
+        self.assertEqual(res[0].shape, (2,))
+        self.assertEqual(res[1].shape, (2,))
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+
+        # 2) x is 2D
+        x1 = paddle.rand([2, 2])
+        x1.stop_gradient = False
+        y1 = paddle.rand([2, 2])
+        y1.stop_gradient = False
+        out1 = paddle.dot(x1, y1)
+
+        paddle.static.append_backward(out1.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[x1, x1.grad_name, out1, out1.grad_name],
+        )
+
+        self.assertEqual(res[0].shape, (2, 2))
+        self.assertEqual(res[1].shape, (2, 2))
+        self.assertEqual(res[2].shape, (2,))
+        self.assertEqual(res[3].shape, (2,))
+
+    @prog_scope()
+    def test_inner(self):
+        # 1) input is 1D
+        x1 = paddle.rand([2])
+        x1.stop_gradient = False
+        y1 = paddle.rand([2])
+        y1.stop_gradient = False
+        out1 = paddle.inner(x1, y1)
+        paddle.static.append_backward(out1.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                x1,
+                x1.grad_name,
+                out1,
+                out1.grad_name,
+            ],
+        )
+        self.assertEqual(res[0].shape, (2,))
+        self.assertEqual(res[1].shape, (2,))
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+
+        # 2) input is 2D
+        x = paddle.rand([2, 3])
+        x.stop_gradient = False
+        y = paddle.rand([2, 3])
+        y.stop_gradient = False
+        out = paddle.inner(x, y)
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[
+                x,
+                x.grad_name,
+                out,
+                out.grad_name,
+            ],
+        )
+
+        self.assertEqual(res[0].shape, (2, 3))
+        self.assertEqual(res[1].shape, (2, 3))
+        self.assertEqual(res[2].shape, (2, 2))
+        self.assertEqual(res[3].shape, (2, 2))
+
+    @prog_scope()
+    def test_tensordot(self):
+        x = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
+        x.stop_gradient = False
+        y = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
+        y.stop_gradient = False
+        out = paddle.tensordot(x, y, axes=1)
+
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[x, x.grad_name, out, out.grad_name],
+        )
+
+        self.assertEqual(res[0].shape, (10,))
+        self.assertEqual(res[1].shape, (10,))
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+
+        x = paddle.arange(6, dtype='float64').reshape([2, 3])
+        y = paddle.arange(6, dtype='float64').reshape([2, 3])
+        x.stop_gradient = False
+        out = paddle.tensordot(x, y, axes=2)
+
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[x, x.grad_name, out, out.grad_name],
+        )
+
+        self.assertEqual(res[0].shape, (2, 3))
+        self.assertEqual(res[1].shape, (2, 3))
+        self.assertEqual(res[2].shape, ())
+        self.assertEqual(res[3].shape, ())
+
+    @prog_scope()
+    def test_metric_accuracy(self):
+        x = paddle.full(shape=[2, 4], fill_value=0.25)
+        y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
+        out = paddle.metric.accuracy(input=x, label=y, k=1)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[out],
+        )
+
+        self.assertEqual(res[0].shape, ())
+
+    @prog_scope()
+    def test_static_accuracy(self):
+        x = paddle.full(shape=[2, 4], fill_value=0.25)
+        y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
+        out = paddle.static.accuracy(input=x, label=y, k=1)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[out],
+        )
+
+        self.assertEqual(res[0].shape, ())
+
+    @prog_scope()
+    def test_static_auc(self):
+        x = paddle.full(shape=[3, 2], fill_value=0.25)
+        y = paddle.full(shape=[3], fill_value=1, dtype="int64")
+        out = paddle.static.auc(input=x, label=y)[0]
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog,
+            fetch_list=[out],
+        )
+
+        self.assertEqual(res[0].shape, ())
+
     @prog_scope()
     def test_std(self):
         x = paddle.rand([])
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 32e05851c95..fee880f5bf6 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -4197,9 +4197,6 @@ def tensordot(x, y, axes=2, name=None):
             shape_out.append(shape_y[i])
             not_contraction_size_y *= shape_y[i]

-    if not shape_out:
-        shape_out = [1]
-
     x = x.transpose(perm=perm_x).reshape(
         [not_contraction_size_x, contraction_size]
     )
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index ebbcbad581a..efa8cbfc54e 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -2070,8 +2070,7 @@ def inner(x, y, name=None):
         xshape = x.shape
         yshape = y.shape
         dstshape = list(xshape[:-1]) + list(yshape[:-1])
-        if len(dstshape) == 0:
-            dstshape = [1]
+
         nx = x.reshape((-1, xshape[-1]))
         ny = y.reshape((-1, yshape[-1]))

diff --git a/test/cpp/phi/core/test_custom_kernel.cc b/test/cpp/phi/core/test_custom_kernel.cc
index 78ae9c6e959..a8dcd89c87c 100644
--- a/test/cpp/phi/core/test_custom_kernel.cc
+++ b/test/cpp/phi/core/test_custom_kernel.cc
@@ -281,7 +281,7 @@ TEST(CustomKernel, custom_kernel_dot) {
   kernel(&kernel_context);

   // 8.check result
-  ASSERT_EQ(dense_out->dims().size(), 2);
+  ASSERT_EQ(dense_out->dims().size(), 1);
   ASSERT_EQ(dense_out->dims()[0], 2);
   ASSERT_EQ(dense_out->numel(), 2);
   ASSERT_EQ(dense_out->dtype(), phi::DataType::UINT8);
diff --git a/test/quantization/quant2_int8_image_classification_comparison.py b/test/quantization/quant2_int8_image_classification_comparison.py
index d3bd3a48dda..34d91851eaf 100644
--- a/test/quantization/quant2_int8_image_classification_comparison.py
+++ b/test/quantization/quant2_int8_image_classification_comparison.py
@@ -263,7 +263,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
                 fetch_list=fetch_targets,
             )
             batch_time = (time.time() - start) * 1000  # in miliseconds
-            batch_acc1, batch_acc5 = out[1][0], out[2][0]
+            batch_acc1, batch_acc5 = out[1], out[2]
             outputs.append(batch_acc1)
         else:  # Quant INT8 models do not have accuracy measuring layers
-- 
GitLab
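
For reference, here is a minimal dygraph sketch of the output-shape behavior this patch introduces. It is not part of the patch itself; it assumes a Paddle build that already includes the change, and the tensor values and shapes are illustrative only, mirroring the assertions in test_zero_dim_tensor.py above.

import paddle

x = paddle.rand([2])
y = paddle.rand([2])

# dot of two 1-D tensors now returns a 0-D tensor instead of a shape-[1] tensor
assert paddle.dot(x, y).shape == []

# is_empty and rank likewise report 0-D results
assert paddle.is_empty(paddle.rand([3, 0, 5])).shape == []
assert paddle.rank(paddle.rand([])).shape == []

# metric.accuracy follows the same 0-D convention
logits = paddle.full(shape=[2, 4], fill_value=0.25)
labels = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
assert paddle.metric.accuracy(input=logits, label=labels, k=1).shape == []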