diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6bde7643a98a971b1bc723f7d7a62116301fca5c..ddb96476f5c50d06dcf68d79fc852a6f67f5f530 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1206,6 +1206,7 @@ set(TEST_CINN_OPS test_elementwise_div_op test_elementwise_mul_op test_gather_nd_op + test_squeeze2_op test_elementwise_pow_op test_elementwise_max_op test_transpose_op diff --git a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py index 166864bd5e3df31eb90dd010f109af10a6fbd73f..31c22ef123f3e8e2c474ca07035dc701b1f1bef1 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze2_op.py @@ -29,6 +29,7 @@ paddle.enable_static() class TestSqueezeOp(OpTest): def setUp(self): self.op_type = "squeeze2" + self.prim_op_type = "comp" self.python_api = paddle.squeeze self.python_out_sig = [ "Out" @@ -42,10 +43,12 @@ class TestSqueezeOp(OpTest): } def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_eager=True) + self.check_output( + no_check_set=['XShape'], check_eager=True, check_prim=True + ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_eager=True) + self.check_grad(["X"], "Out", check_eager=True, check_prim=True) def init_test_case(self): self.ori_shape = (1, 3, 1, 40) @@ -66,6 +69,22 @@ class TestSqueezeOp1(TestSqueezeOp): # Correct: No axes input. class TestSqueezeOp2(TestSqueezeOp): + def setUp(self): + self.op_type = "squeeze2" + self.prim_op_type = "comp" + self.python_api = paddle.squeeze + self.enable_cinn = False + self.python_out_sig = [ + "Out" + ] # python out sig is customized output signature. + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float64"), + } + def init_test_case(self): self.ori_shape = (1, 20, 1, 5) self.axes = () diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index c479416357dbeba9e210e7aabcc75a6d4c6d7c8b..be513eeb789b6980234b2c204a31586f4eb5dbe2 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -437,6 +437,29 @@ def fill_any_like(x, fill_value, dtype, place=None): return val +@REGISTER_COMPOSITE('squeeze2') +def squeeze2_composite(x, axis): + """define composite rule of squeeze""" + """ + canonicalize dim within range 0 to rank and + determine new shape after squeeze op + if axis not specified, remove all dims equal to 1 + otherwise, remove dims equal to 1 in axis + axis can only be list, not int + """ + rank = len(x.shape) + if len(axis) == 0: + dims = set(range(rank)) + else: + dims = set([ax % rank for ax in axis]) + new_shape = [] + for d, s in enumerate(x.shape): + if not (s == 1 and (d in dims)): + new_shape.append(s) + out = reshape(x, new_shape) + return [out, None] + + @REGISTER_COMPOSITE('sqrt') def sqrt_composite(x): """