diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 62e45468abb3cda01f83ce0961072c5138ad944a..d950da6a75885cf6a2ad1b88acf7ed8e118a4020 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -28,9 +28,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { "Output(Out) of UnsqueezeOp should not be null."); const auto &axes = ctx->Attrs().Get>("axes"); - PADDLE_ENFORCE(!axes.empty(), - "The unsqueeze axes information must be set by Attr(axes)."); - const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, @@ -123,6 +120,9 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "(std::vector). List of positive integers," " indicate the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { + PADDLE_ENFORCE( + !axes.empty(), + "The unsqueeze axes information must be set by Attr(axes)."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, "Invalid dimensions, dynamic dimensions should within " diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 62dc6fcb9e691c75d179ca862a2023795cde6394..d19d4e525a8a77379e66352c28c61382d4207080 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest): self.init_test_case() self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} - self.attrs = {"axes": self.axes, "inplace": False} + self.init_attrs() self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} def test_check_output(self): @@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest): self.axes = (1, 2) self.new_shape = (3, 1, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + # Correct: Single input index. class TestUnsqueezeOp1(TestUnsqueezeOp): @@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp): self.axes = (0, 2) self.new_shape = (1, 3, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + # Correct: Inplace. There is mins index. class TestUnsqueezeOpInplace2(TestUnsqueezeOp): @@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp): self.axes = (0, -2) self.new_shape = (1, 3, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + # Correct: Inplace. There is duplicated axis. class TestUnsqueezeOpInplace3(TestUnsqueezeOp): @@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp): self.axes = (0, 3, 3) self.new_shape = (1, 3, 2, 1, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + if __name__ == "__main__": unittest.main()