From c93451f42d4f3cc45f5f2a523ca912726d2ff8e8 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 18 Aug 2022 14:22:09 +0800 Subject: [PATCH] [OpAttr]Squeeze axes support Tensor (#45189) * [OpAttr]Squeeze axes support Tensor * add support_tensor * fix unittest * fix coverage --- paddle/fluid/operators/squeeze_op.cc | 3 +- paddle/phi/api/yaml/legacy_api.yaml | 2 +- paddle/phi/api/yaml/legacy_backward.yaml | 8 +- paddle/phi/infermeta/unary.cc | 37 ++++++--- paddle/phi/infermeta/unary.h | 10 ++- .../kernels/impl/squeeze_grad_kernel_impl.h | 3 +- paddle/phi/kernels/impl/squeeze_kernel_impl.h | 10 ++- paddle/phi/kernels/squeeze_grad_kernel.h | 3 +- paddle/phi/kernels/squeeze_kernel.h | 5 +- python/paddle/fluid/layers/nn.py | 14 +++- .../fluid/tests/unittests/test_squeeze2_op.py | 82 ++++++++++++++++++- python/paddle/tensor/manipulation.py | 15 +++- 12 files changed, 157 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index b3c70e2fe99..2d9e1e121c3 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -171,7 +171,8 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("axes", "(std::vector). List of integers," " indicating the dimensions to squeeze.") - .SetDefault({}); + .SetDefault({}) + .SupportTensor(); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 897944894d3..8f9cc3c4118 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2536,7 +2536,7 @@ backward : squared_l2_norm_grad - api : squeeze - args : (Tensor x, int[] axes) + args : (Tensor x, IntArray axes) output : Tensor(out), Tensor(xshape) infer_meta : func : SqueezeWithXShapeInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index ec0f22b39f8..bb2ee9448d9 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2347,14 +2347,14 @@ func : squared_l2_norm_grad - backward_api : squeeze_double_grad - forward : squeeze_grad(Tensor xshape, Tensor grad_out, int[] axes) -> Tensor(grad_x) - args : (Tensor grad_x_grad, int[] axes) + forward : squeeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x) + args : (Tensor grad_x_grad, IntArray axes) output : Tensor(grad_out_grad) invoke: squeeze(grad_x_grad, axes) - backward_api : squeeze_grad - forward : squeeze(Tensor x, int[] axes) -> Tensor(out), Tensor(xshape) - args : (Tensor xshape, Tensor out_grad, int[] axes) + forward : squeeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape) + args : (Tensor xshape, Tensor out_grad, IntArray axes) output : Tensor(x_grad) infer_meta : func : KernelWithXShapeInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1382cb2e660..f8798982581 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3122,8 +3122,9 @@ void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out) { } void SqueezeInferMeta(const MetaTensor& x, - const std::vector& axes, - MetaTensor* out) { + const IntArray& axes, + MetaTensor* out, + MetaConfig config) { const auto& x_dims = x.dims(); // Check input tensor dims (<6) Eigen limit. PADDLE_ENFORCE_LE(x_dims.size(), @@ -3135,22 +3136,34 @@ void SqueezeInferMeta(const MetaTensor& x, x_dims.size(), x_dims)); - auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, false); - out->set_dims(out_dims); - if (x_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - out->share_lod(x); + if (!config.is_runtime && axes.FromTensor()) { + // compile time infershape, set all elements to -1. + int output_size = x.dims().size() - axes.GetData().size(); + std::vector vec_out_dims(output_size, -1); + out->set_dims(phi::make_ddim(vec_out_dims)); + } else { + std::vector tmp; + tmp.reserve(axes.GetData().size()); + std::for_each(axes.GetData().begin(), + axes.GetData().end(), + [&tmp](const int64_t& t) { tmp.push_back(t); }); + auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, false); + out->set_dims(out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + out->share_lod(x); + } } - out->set_dtype(x.dtype()); } void SqueezeWithXShapeInferMeta(const MetaTensor& x, - const std::vector& axes, + const IntArray& axes, MetaTensor* out, - MetaTensor* xshape) { - SqueezeInferMeta(x, axes, out); + MetaTensor* xshape, + MetaConfig config) { + SqueezeInferMeta(x, axes, out, config); const auto& x_dims = x.dims(); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d81c8ea7a43..6350ef4a970 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -438,13 +438,15 @@ void SplitInferMeta(const MetaTensor& x_meta, void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out); void SqueezeInferMeta(const MetaTensor& x, - const std::vector& axes, - MetaTensor* out); + const IntArray& axes, + MetaTensor* out, + MetaConfig config = MetaConfig()); void SqueezeWithXShapeInferMeta(const MetaTensor& x, - const std::vector& axes, + const IntArray& axes, MetaTensor* out, - MetaTensor* xshape); + MetaTensor* xshape, + MetaConfig config = MetaConfig()); void StridedSliceRawInferMeta(const MetaTensor& x, const std::vector& axes, diff --git a/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h index 1e3dfd66ece..790de83050f 100644 --- a/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" @@ -21,7 +22,7 @@ template void SqueezeGradKernel(const Context& dev_ctx, const DenseTensor& xshape, const DenseTensor& dout, - const std::vector& axes, + const IntArray& axes, DenseTensor* dx) { auto xshape_dims = xshape.dims(); auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size()); diff --git a/paddle/phi/kernels/impl/squeeze_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_kernel_impl.h index 156a71973a7..cb5eed521cd 100644 --- a/paddle/phi/kernels/impl/squeeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_kernel_impl.h @@ -21,20 +21,22 @@ namespace phi { template void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axes, + const IntArray& axes, DenseTensor* out) { auto x_dims = x.dims(); - auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true); + std::vector tmp(axes.GetData().begin(), axes.GetData().end()); + auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); + out->Resize(out_dims); dev_ctx.template Alloc(out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); - out->Resize(out_dims); + out->Resize(out_dims); // copy will reset the dims. } template void SqueezeWithXShapeKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axes, + const IntArray& axes, DenseTensor* out, DenseTensor* xshape) { SqueezeKernel(dev_ctx, x, axes, out); diff --git a/paddle/phi/kernels/squeeze_grad_kernel.h b/paddle/phi/kernels/squeeze_grad_kernel.h index 52b02bdbb95..8582012d2f6 100644 --- a/paddle/phi/kernels/squeeze_grad_kernel.h +++ b/paddle/phi/kernels/squeeze_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,6 +24,6 @@ template void SqueezeGradKernel(const Context& dev_ctx, const DenseTensor& xshape, const DenseTensor& dout, - const std::vector& axes, + const IntArray& axes, DenseTensor* dx); } // namespace phi diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index 1c6aeedbe51..7e5a1b0775a 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -15,6 +15,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -22,13 +23,13 @@ namespace phi { template void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axes, + const IntArray& axes, DenseTensor* out); template void SqueezeWithXShapeKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axes, + const IntArray& axes, DenseTensor* out, DenseTensor* xshape); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e1b0168dd18..53ab7ac05ca 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6612,12 +6612,22 @@ def squeeze(input, axes, name=None): 'float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64', 'complex64', 'complex128' ], 'squeeze') - check_type(axes, 'axis/axes', (list, tuple), 'squeeze') + check_type(axes, 'axis/axes', (list, tuple, Variable), 'squeeze') + + attrs = {} + if isinstance(axes, Variable): + axes.stop_gradient = True + attrs["axes"] = axes + elif isinstance(axes, (list, tuple)): + if utils._contain_var(axes): + attrs["axes"] = utils._convert_to_tensor_list(axes) + else: + attrs["axes"] = axes out = helper.create_variable_for_type_inference(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op(type="squeeze2", inputs={"X": input}, - attrs={"axes": axes}, + attrs=attrs, outputs={ "Out": out, "XShape": x_shape diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index 711373165fd..9d3cd7fecc9 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -16,9 +16,12 @@ from __future__ import print_function import unittest import numpy as np - +import os from op_test import OpTest import paddle +from paddle.fluid.framework import program_guard, Program + +from test_attribute_var import UnittestBase paddle.enable_static() @@ -82,5 +85,82 @@ class TestSqueezeOp3(TestSqueezeOp): self.new_shape = (6, 5, 1, 4) +class TestSqueeze2AxesTensor(UnittestBase): + + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') + + def test_static(self): + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10] + # axes is a Variable + axes = paddle.assign([0, 2]) + out = paddle.squeeze(feat, axes) + out2 = paddle.fluid.layers.squeeze(feat, axes) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue("Var[" in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[feat, out, out2]) + self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) + self.assertEqual(res[1].shape, (2, 3, 10)) + self.assertEqual(res[2].shape, (2, 3, 10)) + + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + # Test for Inference Predictor + infer_out = self.infer_prog() + self.assertEqual(infer_out.shape, (2, 3, 10)) + + +class TestSqueeze2AxesTensorList(UnittestBase): + + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor') + + def test_static(self): + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10] + # axes is a list[Variable] + axes = [ + paddle.full([1], 0, dtype='int32'), + paddle.full([1], 2, dtype='int32') + ] + out = paddle.squeeze(feat, axes) + out2 = paddle.fluid.layers.squeeze(feat, axes) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue("Vars[" in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[feat, out, out2]) + self.assertEqual(res[0].shape, (1, 2, 1, 3, 10)) + self.assertEqual(res[1].shape, (2, 3, 10)) + self.assertEqual(res[2].shape, (2, 3, 10)) + + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + # Test for Inference Predictor + infer_out = self.infer_prog() + self.assertEqual(infer_out.shape, (2, 3, 10)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index cd2e303fd9c..43837d03d3a 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2014,12 +2014,23 @@ def squeeze(x, axis=None, name=None): 'float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64', 'complex64', 'complex128' ], 'squeeze') - check_type(axes, 'axis/axes', (list, tuple), 'squeeze') + + check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'squeeze') + attrs = {} + if isinstance(axes, Variable): + axes.stop_gradient = True + attrs["axes"] = axes + elif isinstance(axes, (list, tuple)): + if utils._contain_var(axes): + attrs["axes"] = utils._convert_to_tensor_list(axes) + else: + attrs["axes"] = axes + out = helper.create_variable_for_type_inference(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op(type="squeeze2", inputs={"X": input}, - attrs={"axes": axes}, + attrs=attrs, outputs={ "Out": out, "XShape": x_shape -- GitLab