diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 8f453b059fdc25496a73c39c65ae074e28b63508..639480aba41783a5e830270733e175f502087b8f 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -33,11 +33,12 @@ class SqueezeOp : public framework::OperatorWithKernel { "Output(Out) of SqueezeOp should not be null."); const auto& x_dims = ctx->GetInputDim("X"); - // TODO(chenweihang): need check input tensor dims (<9). + // Check input tensor dims (<9). + PADDLE_ENFORCE(x_dims.size() <= 9, + "Invalid dimnesions, dynamic dimensions must have " + "between [1, 9] dimensions."); const auto& axes = ctx->Attrs().Get>("axes"); - // TODO(chenweihang): need check axes is valid. - // PADDLE_ENFORCE(); for (int a : axes) { PADDLE_ENFORCE_LT(a, x_dims.size(), "The axis must be less than input tensor's rank."); @@ -45,7 +46,12 @@ class SqueezeOp : public framework::OperatorWithKernel { auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); - // TODO(chenweihang): need other check. + // TODO(chenweihang): This share option is necessary? + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } } static framework::DDim GetOutputShape(const std::vector squeeze_dims, @@ -67,12 +73,17 @@ class SqueezeOp : public framework::OperatorWithKernel { for (int idx = 0; idx < num_squeeze_dims; ++idx) { int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() : squeeze_dims[idx]; - // TODO(chenweihang): shoude use PADALE_ENFORCE ? or if. - PADDLE_ENFORCE_GE(current, 0, "Invalid axis is given."); - PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); - PADDLE_ENFORCE_EQ(in_dims[current], 1, "Invalid axis is given."); - - if (!(should_squeeze[current])) ++cnt_squeezed_dims; + // Check current index. + PADDLE_ENFORCE(current >= 0, + "Invalid axis, negative axis is out of range."); + // PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); + PADDLE_ENFORCE( + in_dims[current] == 1, + "Invalid axis index, the axis will be squeezed should be 1."); + + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } should_squeeze[current] = true; } } @@ -92,13 +103,14 @@ class SqueezeOp : public framework::OperatorWithKernel { class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor), Tensors with at least max(dims) dimensions."); - AddOutput("Out", "(Tensor), Reshaped tensor with same data as input."); + AddInput("X", "(Tensor). The input tensor of squeeze operator."); + AddOutput("Out", "(Tensor). The output tensor of squeeze operator."); AddAttr>("axes", - "List of positive integers," - " indicate the dimensions to squeeze."); + "(std::vector). List of positive integers," + " indicate the dimensions to squeeze.") + .SetDefault({}); AddAttr("inplace", - "(default: false) Change the source tensor's shape without " + "(default: false) Squeeze the source tensor's shape without " "memory copy. When Attr(inplace) is set true, the output " "tensor shares memory with Input(X), otherwise, a new output " "tensor is created, and its data are copied from Input(x).") @@ -110,6 +122,21 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { Takes a parameter axes with a list of axes to squeeze. If axes is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. + + Examples: + Case 1: + Given + X.shape = (1, 3, 1, 5) + and + axes = [0] + we get: + Out.shape = (3, 1, 5) + + Case 2: + Given + X.shape = (1, 3, 1, 5) + we get: + Out.shape = (3, 5) )DOC"); } }; @@ -120,9 +147,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SqueezeOp should not be null."); + "Input(X) of SqueezeGradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Output(Out@GRAD/) of SqueezeOp should not be null."); + "Output(Out@GRAD) of SqueezeGradOp should not be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } diff --git a/paddle/fluid/operators/squeeze_op.h b/paddle/fluid/operators/squeeze_op.h index ce6f40e7a4f8b276869ff54d83bbe33ed76504c0..44ef324c7dc5a702fcc8e7846f3870a94c4aa953 100644 --- a/paddle/fluid/operators/squeeze_op.h +++ b/paddle/fluid/operators/squeeze_op.h @@ -33,7 +33,6 @@ class SqueezeKernel : public framework::OpKernel { framework::DDim out_dims = out->dims(); - // TODO(chenweihang): Where is this attr be add. bool inplace = ctx.Attr("inplace"); out->Resize(out_dims); if (!inplace) { diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..58c87ea3c16a8a0d88814093667eb4129af4d968 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -0,0 +1,174 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestSqueezeOp1(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = (0, 2) + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: There is mins axis. +class TestSqueezeOp2(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = (0, -2) + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: No axes input. +class TestSqueezeOp3(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = () + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Just part of axes be squeezed. +class TestSqueezeOp4(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5, 1, 4, 1) + axes = (2, 6) + new_shape = (1, 3, 5, 1, 4) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inplace. +class TestSqueezeOpInplace1(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = (0, 2) + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inplace": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inplace. There is mins axis. +class TestSqueezeOpInplace2(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = (0, -2) + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inplace. No axes input. +class TestSqueezeOpInplace3(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5) + axes = () + new_shape = (3, 5) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inpalce. Just part of axes be squeezed. +class TestSqueezeOpInplace4(OpTest): + def setUp(self): + ori_shape = (1, 3, 1, 5, 1, 4, 1) + axes = (2, 6) + new_shape = (1, 3, 5, 1, 4) + + self.op_type = "squeeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +if __name__ == "__main__": + unittest.main()