未验证 提交 89ff0d59 编写于 作者: warrentdrew's avatar warrentdrew 提交者: GitHub

add composite rules for squeeze op (#51539)

* add composite rule for squeeze

* fix pre commit

* fix pre commit

* simplify rules

* arrange code

* fix int axis

* simplify squeeze axis rules

* bugfix

* fix pre commit
上级 2c543193
......@@ -1206,6 +1206,7 @@ set(TEST_CINN_OPS
test_elementwise_div_op
test_elementwise_mul_op
test_gather_nd_op
test_squeeze2_op
test_elementwise_pow_op
test_elementwise_max_op
test_transpose_op
......
......@@ -29,6 +29,7 @@ paddle.enable_static()
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze2"
self.prim_op_type = "comp"
self.python_api = paddle.squeeze
self.python_out_sig = [
"Out"
......@@ -42,10 +43,12 @@ class TestSqueezeOp(OpTest):
}
def test_check_output(self):
self.check_output(no_check_set=['XShape'], check_eager=True)
self.check_output(
no_check_set=['XShape'], check_eager=True, check_prim=True
)
def test_check_grad(self):
self.check_grad(["X"], "Out", check_eager=True)
self.check_grad(["X"], "Out", check_eager=True, check_prim=True)
def init_test_case(self):
self.ori_shape = (1, 3, 1, 40)
......@@ -66,6 +69,22 @@ class TestSqueezeOp1(TestSqueezeOp):
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def setUp(self):
self.op_type = "squeeze2"
self.prim_op_type = "comp"
self.python_api = paddle.squeeze
self.enable_cinn = False
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float64")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float64"),
}
def init_test_case(self):
self.ori_shape = (1, 20, 1, 5)
self.axes = ()
......
......@@ -437,6 +437,29 @@ def fill_any_like(x, fill_value, dtype, place=None):
return val
@REGISTER_COMPOSITE('squeeze2')
def squeeze2_composite(x, axis):
"""define composite rule of squeeze"""
"""
canonicalize dim within range 0 to rank and
determine new shape after squeeze op
if axis not specified, remove all dims equal to 1
otherwise, remove dims equal to 1 in axis
axis can only be list, not int
"""
rank = len(x.shape)
if len(axis) == 0:
dims = set(range(rank))
else:
dims = set([ax % rank for ax in axis])
new_shape = []
for d, s in enumerate(x.shape):
if not (s == 1 and (d in dims)):
new_shape.append(s)
out = reshape(x, new_shape)
return [out, None]
@REGISTER_COMPOSITE('sqrt')
def sqrt_composite(x):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册