From 7a79fd8839b9927523d7789c5d2b4745e661b0cc Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Mon, 20 Mar 2023 13:21:57 +0800 Subject: [PATCH] Fix unsqueeze with empty axis bug (#51828) --- paddle/phi/infermeta/unary.cc | 11 +++-------- paddle/phi/kernels/funcs/unsqueeze.h | 2 +- paddle/phi/kernels/unsqueeze_kernel.cc | 7 +------ .../fluid/tests/unittests/test_unsqueeze_op.py | 16 ++++++++++++++++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 42f8575450e..d8dd9759efa 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4791,15 +4791,10 @@ void UnsqueezeInferMeta(const MetaTensor& x, std::vector vec_out_dims(output_size, -1); out->set_dtype(x.dtype()); out->set_dims(phi::make_ddim(vec_out_dims)); - } else if (!axes.GetData().empty()) { - std::vector tmp; - tmp.reserve(axes.GetData().size()); - std::for_each(axes.GetData().begin(), - axes.GetData().end(), - [&tmp](const int64_t& t) { tmp.push_back(t); }); - auto out_dims = funcs::GetUnsqueezeShape(tmp, x_dims); + } else { + auto out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims); out->set_dims(out_dims); - if (x_dims[0] == out_dims[0]) { + if (x_dims.size() > 0 && x_dims[0] == out_dims[0]) { out->share_lod(x); } out->set_dtype(x.dtype()); diff --git a/paddle/phi/kernels/funcs/unsqueeze.h b/paddle/phi/kernels/funcs/unsqueeze.h index 6dd69c8212a..136fb9b2924 100644 --- a/paddle/phi/kernels/funcs/unsqueeze.h +++ b/paddle/phi/kernels/funcs/unsqueeze.h @@ -103,7 +103,7 @@ inline DDim GetOutputSqueezeShape(const std::vector squeeze_dims, return phi::make_ddim(output_shape); } -inline DDim GetUnsqueezeShape(const std::vector unsqz_dims, +inline DDim GetUnsqueezeShape(const std::vector unsqz_dims, const DDim& in_dims) { int output_size = in_dims.size() + static_cast(unsqz_dims.size()); int cur_output_size = in_dims.size(); diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc index 159e7a4ce17..4008d7883d4 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_kernel.cc @@ -28,12 +28,7 @@ void UnsqueezeInferKernel(const Context& dev_ctx, auto x_dims = x.dims(); auto out_dims = out->dims(); if (axes.FromTensor()) { - std::vector tmp; - tmp.reserve(axes.GetData().size()); - std::for_each(axes.GetData().begin(), - axes.GetData().end(), - [&tmp](const int64_t& t) { tmp.push_back(t); }); - out_dims = funcs::GetUnsqueezeShape(tmp, x_dims); + out_dims = funcs::GetUnsqueezeShape(axes.GetData(), x_dims); } out->Resize(out_dims); dev_ctx.template Alloc(out); diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index fdb68a27795..9e65d1d5117 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -108,6 +108,22 @@ class TestUnsqueezeOp4(TestUnsqueezeOp): self.new_shape = (10, 1, 1, 2, 5, 1) +# axis is empty, x is ND +class TestUnsqueezeOp5(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = () + self.axes = () + self.new_shape = () + + +# axis is empty, x is 0D +class TestUnsqueezeOp6(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = () + self.new_shape = (10, 2, 5) + + class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp): def init_test_case(self): self.ori_shape = () -- GitLab