diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc
index 16ca2cf09ec0b34a65157f45258f02983f15edd4..3ac5c6e21434d5777c6e3d98e3bf1a84344ed306 100644
--- a/paddle/fluid/operators/matrix_rank_op.cc
+++ b/paddle/fluid/operators/matrix_rank_op.cc
@@ -27,7 +27,7 @@ namespace detail {
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index c6ca65e273ad07939b2a3c9dcbf60969e56fb6be..efa271e8e1582c1a28e944bf3ce1155360284b19 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
@@ -990,7 +990,7 @@ void DistInferMeta(const MetaTensor& x,
                         "The Input(Y) has not been initialized properly. The "
                         "shape of Input(Y) = [%s].",
                         y_dims));
-  out->set_dims({1});
+  out->set_dims(phi::make_ddim({}));
   out->set_dtype(x.dtype());
 }
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index ea9588f0e455a526a9fd2e243db61576f9699578..19c8326a23cdf5a59c266ad6d49da53bcffc5a4d 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2344,7 +2344,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
   // If the last tensor is 1D of size n view it as a column vector (n, 1)
   if (last_dim.size() == 1) {
     last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
-    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
+    out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
   } else {
     out_dim = is_vector ? phi::make_ddim({last_dim[1]})
                         : phi::make_ddim({first_dim[0], last_dim[1]});
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index cdca81ec040c6ffe62694f9dc6295366f2dea1a5..14ae02246babb573f1ceb74e73d9bc9a802ed590 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -38,7 +38,7 @@ namespace detail {
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
@@ -4405,7 +4405,6 @@ void TraceInferMeta(
   auto sizes = vectorize(x_dims);
   if (x_dims.size() == 2) {
     sizes.clear();
-    sizes.push_back(1);
   } else {
     sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
     sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
diff --git a/paddle/phi/kernels/gpu/trace_kernel.cu b/paddle/phi/kernels/gpu/trace_kernel.cu
index 671ca490e136a2e5adb6e61eee7360f0ab9ff835..304bf778094d3de33b59b43ddaf7f886d2d4ceea 100644
--- a/paddle/phi/kernels/gpu/trace_kernel.cu
+++ b/paddle/phi/kernels/gpu/trace_kernel.cu
@@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx,
   auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
     std::vector<int> reduce_dims;
-    reduce_dims.push_back(out->dims().size());
+    // Adapt to 0D output
+    auto out_dim_size = out->dims().size();
+    if (out_dim_size == 0) out_dim_size = 1;
+    reduce_dims.push_back(out_dim_size);
     funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
         ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
   } else {
diff --git a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
index 3f463e1d9e0644f62a0b3e8508f333731c466922..2e5625e3d8fc845d18cc61c8067d9ceafe7b7448 100644
--- a/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/determinant_grad_kernel_impl.h
@@ -90,10 +90,10 @@ void DeterminantGradKernel(const Context& dev_ctx,
                            " input tensor's, but here differ %d",
                            input_dims_size - out_grad.dims().size()));
   } else if (input_dims_size == 2) {
-    // input dims size 2 and grad dims size 1 is possible
+    // input dims size 2 and grad dims size 0 is possible
     PADDLE_ENFORCE_EQ(
         out_grad.dims().size(),
-        1,
+        0,
         phi::errors::InvalidArgument(
             "The grad tensor of det dims size should be 2 less than"
             " input tensor's, but here differ %d",
diff --git a/paddle/phi/kernels/impl/determinant_kernel_impl.h b/paddle/phi/kernels/impl/determinant_kernel_impl.h
index 36e47c78c832c10fc8613604c47239508f44e725..410bf90c2899c1e71373ef31ccc7dd683aa42e4c 100644
--- a/paddle/phi/kernels/impl/determinant_kernel_impl.h
+++ b/paddle/phi/kernels/impl/determinant_kernel_impl.h
@@ -116,7 +116,7 @@ void DeterminantKernel(const Context& dev_ctx,
     out->Resize(output_dims);
   } else {
     // when input is a two-dimension matrix, The det value is a number.
-    out->Resize({1});
+    out->Resize(phi::make_ddim({}));
   }
   VLOG(10) << "output dim:" << out->dims();
 }
diff --git a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
index 90a2327ef3e204fe0cda3cc281407926e0a61ba3..1099f27f3622e6156d9765f4af9dd60947de8714 100644
--- a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
@@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx,
   auto input_dims = in_grad->dims();
   auto input_stride = phi::stride(input_dims);
   auto output_dims = out_grad.dims();
-  auto output_stride = phi::stride(output_dims);
+  auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims)
+                                               : phi::stride(output_dims);
 
   auto* out_data = out_grad.data<T>();
   T* x_data = ctx.template Alloc<T>(in_grad);
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 25252def81f70d1f047a89b02643ada978f683ac..f3bea4cf2467b6666fee195f11415ee6d2bc7260 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -2311,6 +2311,128 @@ class TestSundryAPI(unittest.TestCase):
         self.assertTrue(out1.shape, [2, 3])
         self.assertTrue(x1.grad.shape, [3, 3, 3])
 
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        self.assertEqual(a.grad.shape, [4])
+        self.assertEqual(b.grad.shape, [4, 5])
+        self.assertEqual(c.grad.shape, [5])
+
+    def test_dist(self):
+        x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
+        y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32")
+        x.stop_gradient = False
+        y.stop_gradient = False
+        out = paddle.dist(x, y, 0)
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_allclose(out, np.array(1))
+        self.assertEqual(x.grad.shape, [2, 2])
+        self.assertEqual(y.grad.shape, [2, 2])
+
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_allclose(out, np.array(12))
+        self.assertEqual(x.grad.shape, [2, 2])
+
+    def test_cond(self):
+        pass
+        # def assert_shape(out):
+        #     self.assertEqual(out.shape, [])
+
+        # x = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
+        # x.stop_gradient = False
+        # p = 2 : use paddle.sum, paddle.max, paddle.min
+        # out = paddle.linalg.cond(x)
+        # assert_shape(out)
+
+        # p = fro : use paddle.sum
+        # out_fro = paddle.linalg.cond(x, p='fro')
+        # assert_shape(out_fro)
+
+        # p = nuc : use paddle.sum, paddle.max, paddle.min
+        # out_nuc = paddle.linalg.cond(x, p='nuc')
+        # assert_shape(out_nuc)
+
+        # p in (-1, 1) : use paddle.sum, paddle.max, paddle.min
+        # out_1 = paddle.linalg.cond(x, p=1)
+        # assert_shape(out_1)
+        # out_minus_1 = paddle.linalg.cond(x, p=-1)
+        # assert_shape(out_minus_1)
+
+        # p in (-2, 2) :use paddle.max, paddle.min
+        # out_2 = paddle.linalg.cond(x, p=2)
+        # assert_shape(out_2)
+        # out_minus_2 = paddle.linalg.cond(x, p=-2)
+        # assert_shape(out_minus_2)
+
+        # p in (-inf, inf):use paddle.sum, paddle.max, paddle.min
+        # out_inf = paddle.linalg.cond(x, p=float("inf"))
+        # assert_shape(out_inf)
+        # out_minus_inf = paddle.linalg.cond(x, p=-float("inf"))
+        # assert_shape(out_minus_inf)
+        # out_minus_inf.backward()
+        # self.assertTrue(x.grad.shape, [3, 3])
+
+        # a = paddle.randn([2, 4, 4])
+        # a.stop_gradient = False
+        # a_cond_fro = paddle.linalg.cond(a, p='fro')
+        # a_cond_fro.backward()
+        # self.assertEqual(len(a_cond_fro.shape), 1)
+        # self.assertEqual(a.grad.shape, [2, 4, 4])
+
+    def test_cov(self):
+        xt = paddle.randn((3, 4))
+        xt.stop_gradient = False
+        xt_1 = paddle.randn((12,))
+        xt_1.stop_gradient = False
+
+        xt_out = paddle.linalg.cov(xt)
+        xt_out.retain_grads()
+        xt_out.backward()
+        self.assertEqual(xt_out.shape, [3, 3])
+        self.assertEqual(xt.grad.shape, [3, 4])
+
+        xt_1_out = paddle.linalg.cov(xt_1)
+        xt_1.retain_grads()
+        xt_1_out.backward()
+        self.assertEqual(xt_1_out.shape, [])
+        self.assertEqual(xt_1.grad.shape, [12])
+
+    def test_det(self):
+        xt = paddle.randn([3, 3, 3])
+        xt.stop_gradient = False
+        xt_1 = paddle.randn([3, 3])
+        xt_1.stop_gradient = False
+
+        xt_out = paddle.linalg.det(xt)
+        xt.retain_grads()
+        xt_out.backward()
+        self.assertEqual(xt_out.shape, [3])
+        self.assertEqual(xt.grad.shape, [3, 3, 3])
+
+        xt_1_out = paddle.linalg.det(xt_1)
+        xt_1.retain_grads()
+        xt_1_out.backward()
+        self.assertEqual(xt_1_out.shape, [])
+        self.assertEqual(xt_1.grad.shape, [3, 3])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -4122,6 +4244,100 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[0].shape, (2, 3))
         self.assertEqual(res[1].shape, (3, 3, 3))
 
+    @prog_scope()
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        paddle.static.append_backward(out.sum())
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (4,))
+        self.assertEqual(res[2].shape, (4, 5))
+        self.assertEqual(res[3].shape, (5,))
+
+    @prog_scope()
+    def test_dist(self):
+        x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
+        y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32")
+        x.stop_gradient = False
+        y.stop_gradient = False
+        out = paddle.dist(x, y)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name, y.grad_name])
+
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2, 2))
+        self.assertEqual(res[2].shape, (2, 2))
+        np.testing.assert_array_equal(res[0], np.array(2).astype(np.float32))
+
+    @prog_scope()
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
+
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2, 2))
+        np.testing.assert_allclose(res[0], np.array(12))
+
+    @prog_scope()
+    def test_cond(self):
+        pass
+        # use paddle.sum, paddle.max, paddle.min
+        # x = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
+        # x.stop_gradient = False
+        # out = paddle.linalg.cond(x)
+        # paddle.static.append_backward(out)
+
+        # prog = paddle.static.default_main_program()
+        # res = self.exe.run(prog, fetch_list=[out, x.grad_name])
+
+        # self.assertTrue(res[0].shape, ())
+        # self.assertTrue(res[1].shape, (3, 3))
+        # np.testing.assert_allclose(out, np.array(1.41421342))
+
+    @prog_scope()
+    def test_cov(self):
+        xt_1 = paddle.randn((12,))
+        xt_1.stop_gradient = False
+
+        out = paddle.linalg.cov(xt_1)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+
+        res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (12,))
+
+    @prog_scope()
+    def test_det(self):
+        xt_1 = paddle.randn((3, 3))
+        xt_1.stop_gradient = False
+
+        out = paddle.linalg.det(xt_1)
+        paddle.static.append_backward(out.sum())
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, xt_1.grad_name])
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (3, 3))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -4313,6 +4529,38 @@ class TestNoBackwardAPI(unittest.TestCase):
         self.assertEqual(inverse.shape, [1])
         self.assertEqual(counts.shape, [1])
 
+    def test_matrix_rank(self):
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_equal(out, np.array(10))
+
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        self.assertEqual(out_c.shape, [3])
+        np.testing.assert_equal(out_c, np.array([1, 1, 1]))
+
+        # 2D, tol->float : OUTPUT 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        self.assertEqual(out_tol.shape, [])
+
+        # 3D, tol->float : OUTPUT 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        self.assertEqual(out_c_tol.shape, [3])
+
+        tol_2 = paddle.randn([2])
+        # 2D, tol->Tensor[1,2] : OUTPUT 1D
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        self.assertEqual(out_d.shape, [2])
+
 
 class TestNoBackwardAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -4547,6 +4795,51 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
         self.assertEqual(res[2].shape, (1,))
         self.assertEqual(res[3].shape, (1,))
 
+    @prog_scope()
+    def test_static_matrix_rank(self):
+        # 2D : OUTPUT 0D
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D : OUTPUT 1D
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c])
+        self.assertEqual(res[0].shape, (3,))
+
+        # 2D, tol->float : OUTPUT 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out_tol])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D, tol->float : OUTPUT 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c_tol])
+        self.assertEqual(res[0].shape, (3,))
+
+        tol_2 = paddle.randn([2])
+        # 2D, tol->Tensor[1,2] : OUTPUT 1D
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_d])
+        self.assertEqual(res[0].shape, (2,))
+
 
 unary_apis_with_complex_input = [
     paddle.real,