diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 768c33d4f4ad6990d079614c11b709666119d16b..c3b96b813b8c3b7e90b1b65555e95ee60a3ee6c1 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -563,8 +563,8 @@ void DiagInferMeta(const MetaTensor& x,
                    MetaTensor* out) {
   auto x_dims = x.dims();
 
-  if (x_dims.size() == 1UL) {
-    int64_t size_ = x_dims[0] + std::abs(offset);
+  if (x_dims.size() <= 1) {
+    int64_t size_ = (x_dims.size() == 1UL ? x_dims[0] : 1) + std::abs(offset);
     out->set_dims({size_, size_});
     out->set_dtype(x.dtype());
   } else if (x_dims.size() == 2UL) {
diff --git a/paddle/phi/kernels/cpu/diag_grad_kernel.cc b/paddle/phi/kernels/cpu/diag_grad_kernel.cc
index 616ea753ef1bac40b0f33ef5eb35e8607a6f7936..13d3d679ff006aa61bb10b1c07c116335f6416eb 100644
--- a/paddle/phi/kernels/cpu/diag_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/diag_grad_kernel.cc
@@ -32,9 +32,9 @@ void DiagGradKernel(const Context& dev_ctx,
   auto dx_dims = x_grad->dims();
   auto dout_dims = out_grad.dims();
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
+    int dx_stride = 1;
     auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
     auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
 
diff --git a/paddle/phi/kernels/cpu/diag_kernel.cc b/paddle/phi/kernels/cpu/diag_kernel.cc
index 4b060f0372a5bf50d9378239dae635e5723d0c7a..1576d80b15206b9fd3c1a657e6858ce8d7359e69 100644
--- a/paddle/phi/kernels/cpu/diag_kernel.cc
+++ b/paddle/phi/kernels/cpu/diag_kernel.cc
@@ -33,12 +33,12 @@ void DiagKernel(const Context& dev_ctx,
   auto out_dims = out->dims();
 
   int64_t i;
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
+    const int& x_stride = 1;
     auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
     auto out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
 
diff --git a/paddle/phi/kernels/gpu/diag_grad_kernel.cu b/paddle/phi/kernels/gpu/diag_grad_kernel.cu
index 5a579ecc27b7ffd9a691c1bd3d0eb9c7a013fb83..39ac78dae021618097cd5860db1d2d6cd0b03ab1 100644
--- a/paddle/phi/kernels/gpu/diag_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/diag_grad_kernel.cu
@@ -73,10 +73,10 @@ void DiagGradKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
     auto size = (offset > 0) ? dx_length + offset : dx_length - offset;
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+    int dx_stride = 1;
     if (size > 0) {
       auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
       auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
diff --git a/paddle/phi/kernels/gpu/diag_kernel.cu b/paddle/phi/kernels/gpu/diag_kernel.cu
index 95d3d3365d91be61013e2016d06334f0498d866a..588bb17b79a0d7af9d88c99337449da295863b2a 100644
--- a/paddle/phi/kernels/gpu/diag_kernel.cu
+++ b/paddle/phi/kernels/gpu/diag_kernel.cu
@@ -77,13 +77,13 @@ void DiagKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
     auto size = (offset > 0) ? x_length + offset : x_length - offset;
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    const int& x_stride = 1;
     if (size > 0) {
       const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
       const auto& out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index e7381350624b9b374d7eeff0a9d865c32b5277e6..eae5528fba244c5ae3539acba4383d3fcee354f9 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -653,6 +653,35 @@ class TestSundryAPI(unittest.TestCase):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
         self.assertEqual(out.grad.shape, [2, 3])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -796,6 +825,26 @@ class TestSundryAPIStatic(unittest.TestCase):
         for i in range(3):
             self.assertEqual(res[0][1][i], 4)
 
+    @prog_scope()
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        out1 = paddle.diagflat(x1, 1)
+        paddle.static.append_backward(out1)
+
+        x2 = paddle.rand([])
+        out2 = paddle.diagflat(x2, -1)
+        paddle.static.append_backward(out2)
+
+        x3 = paddle.rand([])
+        out3 = paddle.diagflat(x3)
+        paddle.static.append_backward(out3)
+
+        prog = paddle.static.default_main_program()
+        res1, res2, res3 = self.exe.run(prog, fetch_list=[out1, out2, out3])
+        self.assertEqual(res1.shape, (2, 2))
+        self.assertEqual(res2.shape, (2, 2))
+        self.assertEqual(res3.shape, (1, 1))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
index b07043689f7feffba1e23d06284b77fed4213514..a6f91e5df4c66ecf9e9fb535aef350b128fb3bd3 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
@@ -475,6 +475,35 @@ class TestSundryAPI(unittest.TestCase):
         for i in range(3):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 134e27eef9df6e82ced105cbb0d443885ee7a301..d597ff6a1317f77b0fbaa620480216e95d450975 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -1479,7 +1479,7 @@ def diagflat(x, offset=0, name=None):
     """
     padding_value = 0
     if in_dygraph_mode():
-        if len(x.shape) == 1:
+        if len(x.shape) <= 1:
             return _C_ops.diag(x, offset, padding_value)
         else:
             y = _C_ops.flatten(x, 0, -1)
@@ -1509,7 +1509,7 @@ def diagflat(x, offset=0, name=None):
     out1_shape = helper.create_variable_for_type_inference(x.dtype)
     out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
 
-    if len(x.shape) == 1:
+    if len(x.shape) <= 1:
         helper.append_op(
             type='diag_v2',
             inputs={'X': x},
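For reference, a minimal sketch (not part of the patch) of the dynamic-graph behavior the `<= 1` rank checks enable, mirroring the new `test_diagflat` cases above: a 0-dim (scalar) input to `paddle.diagflat` is treated as a length-1 vector, so the output is a square matrix of side `1 + |offset|`, and the gradient propagated back to the input keeps its 0-dim shape.

import paddle

x = paddle.rand([])                 # 0-dim (scalar) tensor
x.stop_gradient = False

out = paddle.diagflat(x, offset=1)  # scalar handled as a length-1 vector
print(out.shape)                    # [2, 2], i.e. square of side 1 + |offset|

out.backward()
print(x.grad.shape)                 # [], the gradient stays 0-dim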