未验证 提交 ed19d37f 编写于 作者: M mengziheng 提交者: GitHub

Add Unsqueeze op composite rule (#51527)

* first test

* add unsqueeze_op
上级 b76ab792
......@@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS
test_elementwise_pow_op
test_transpose_op
test_reshape_op
test_mean_op)
test_mean_op
test_unsqueeze2_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -36,9 +36,12 @@ class TestUnsqueezeOp(OpTest):
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float64"),
}
self.prim_op_type = "comp"
def test_check_output(self):
self.check_output(no_check_set=["XShape"], check_eager=True)
self.check_output(
no_check_set=["XShape"], check_eager=True, check_prim=True
)
def test_check_grad(self):
self.check_grad(["X"], "Out", check_eager=True)
......@@ -89,6 +92,7 @@ class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
self.ori_shape = ()
self.axes = (-1,)
self.new_shape = 1
self.enable_cinn = False
class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
......@@ -96,6 +100,7 @@ class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
self.ori_shape = ()
self.axes = (-1, 1)
self.new_shape = (1, 1)
self.enable_cinn = False
class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
......@@ -103,6 +108,7 @@ class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
self.ori_shape = ()
self.axes = (0, 1, 2)
self.new_shape = (1, 1, 1)
self.enable_cinn = False
# axes is a list(with tensor)
......
......@@ -371,3 +371,23 @@ def relu_composite(x):
"""define composite rule of op relu."""
# relu(x) = max(x, 0)
return maximum(x, zeros_like(x))
@REGISTER_COMPOSITE('unsqueeze2')
def unsqueeze_composite(x, axis):
"""define composite rule of op unsqueeze"""
"""using reshape to implement unsqueeze op"""
x_shape = list(x.shape)
axis_list = list(axis)
for i in axis_list:
if i < 0:
i += len(x_shape) + 1
x_shape = (
x_shape[:i]
+ [
1,
]
+ x_shape[i:]
)
out = reshape(x, x_shape)
return [out, None]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册