From 9d90738cdfa2b1547645bee62a65fe1109abcc84 Mon Sep 17 00:00:00 2001
From: GGBond8488 <33050871+GGBond8488@users.noreply.github.com>
Date: Mon, 24 Apr 2023 16:21:48 +0800
Subject: [PATCH] add 0D support for trace (#53208)

* add 0D support for trace, test=allcase

* fix trace gpu kernel 0d error, test=allcase

* fix windows error, test=allcase
---
 paddle/phi/infermeta/unary.cc                 |  1 -
 paddle/phi/kernels/gpu/trace_kernel.cu        |  5 +++-
 .../phi/kernels/impl/trace_grad_kernel_impl.h |  3 ++-
 .../tests/unittests/test_zero_dim_tensor.py   | 24 +++++++++++++++++++
 4 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index bfe744446a9..ea27eba5130 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -4402,7 +4402,6 @@ void TraceInferMeta(
   auto sizes = vectorize(x_dims);
   if (x_dims.size() == 2) {
     sizes.clear();
-    sizes.push_back(1);
   } else {
     sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
     sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
diff --git a/paddle/phi/kernels/gpu/trace_kernel.cu b/paddle/phi/kernels/gpu/trace_kernel.cu
index 671ca490e13..304bf778094 100644
--- a/paddle/phi/kernels/gpu/trace_kernel.cu
+++ b/paddle/phi/kernels/gpu/trace_kernel.cu
@@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx,
   auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
   if (diag.numel() > 0) {
     std::vector<int> reduce_dims;
-    reduce_dims.push_back(out->dims().size());
+    // Adapt to 0D output
+    auto out_dim_size = out->dims().size();
+    if (out_dim_size == 0) out_dim_size = 1;
+    reduce_dims.push_back(out_dim_size);
     funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
         ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
   } else {
diff --git a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
index 90a2327ef3e..1099f27f362 100644
--- a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h
@@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx,
   auto input_dims = in_grad->dims();
   auto input_stride = phi::stride(input_dims);
   auto output_dims = out_grad.dims();
-  auto output_stride = phi::stride(output_dims);
+  auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims)
+                                               : phi::stride(output_dims);
 
   auto* out_data = out_grad.data<T>();
   T* x_data = ctx.template Alloc<T>(in_grad);
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 7ea98f7c889..965bcae57d9 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -2437,6 +2437,16 @@ class TestSundryAPI(unittest.TestCase):
         self.assertEqual(b.grad.shape, [4, 5])
         self.assertEqual(c.grad.shape, [5])
 
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_allclose(out, np.array(12))
+        self.assertEqual(x.grad.shape, [2, 2])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -4426,6 +4436,20 @@ class TestSundryAPIStatic(unittest.TestCase):
         self.assertEqual(res[2].shape, (4, 5))
         self.assertEqual(res[3].shape, (5,))
 
+    @prog_scope()
+    def test_trace(self):
+        x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
+        x.stop_gradient = False
+        out = paddle.trace(x)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out, x.grad_name])
+
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (2, 2))
+        np.testing.assert_allclose(res[0], np.array(12))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
-- 
GitLab
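
A minimal usage sketch (not part of the patch) of the dygraph behavior this change enables, mirroring the test added above; the input values and printed comments are illustrative only:

# Sketch: after this change, paddle.trace on a 2D input returns a 0D tensor
# (shape []) instead of shape [1]; mirrors the dygraph test in the patch.
import numpy as np
import paddle

x = paddle.to_tensor([[3.0, 2.0], [1.0, 9.0]])
x.stop_gradient = False

out = paddle.trace(x)   # sum of the main diagonal: 3 + 9 = 12
out.backward()

print(out.shape)        # [] -- 0D output (previously [1])
np.testing.assert_allclose(out.numpy(), np.array(12.0))
print(x.grad.shape)     # [2, 2], same shape as the input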