未验证 提交 ed19d37f 编写于 作者: M mengziheng 提交者: GitHub

Add Unsqueeze op composite rule (#51527)

* first test

* add unsqueeze_op
上级 b76ab792
...@@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS ...@@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS
test_elementwise_pow_op test_elementwise_pow_op
test_transpose_op test_transpose_op
test_reshape_op test_reshape_op
test_mean_op) test_mean_op
test_unsqueeze2_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -36,9 +36,12 @@ class TestUnsqueezeOp(OpTest): ...@@ -36,9 +36,12 @@ class TestUnsqueezeOp(OpTest):
"Out": self.inputs["X"].reshape(self.new_shape), "Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float64"), "XShape": np.random.random(self.ori_shape).astype("float64"),
} }
self.prim_op_type = "comp"
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=["XShape"], check_eager=True) self.check_output(
no_check_set=["XShape"], check_eager=True, check_prim=True
)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out", check_eager=True) self.check_grad(["X"], "Out", check_eager=True)
...@@ -89,6 +92,7 @@ class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp): ...@@ -89,6 +92,7 @@ class TestUnsqueezeOp_ZeroDim1(TestUnsqueezeOp):
self.ori_shape = () self.ori_shape = ()
self.axes = (-1,) self.axes = (-1,)
self.new_shape = 1 self.new_shape = 1
self.enable_cinn = False
class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp): class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
...@@ -96,6 +100,7 @@ class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp): ...@@ -96,6 +100,7 @@ class TestUnsqueezeOp_ZeroDim2(TestUnsqueezeOp):
self.ori_shape = () self.ori_shape = ()
self.axes = (-1, 1) self.axes = (-1, 1)
self.new_shape = (1, 1) self.new_shape = (1, 1)
self.enable_cinn = False
class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp): class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
...@@ -103,6 +108,7 @@ class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp): ...@@ -103,6 +108,7 @@ class TestUnsqueezeOp_ZeroDim3(TestUnsqueezeOp):
self.ori_shape = () self.ori_shape = ()
self.axes = (0, 1, 2) self.axes = (0, 1, 2)
self.new_shape = (1, 1, 1) self.new_shape = (1, 1, 1)
self.enable_cinn = False
# axes is a list(with tensor) # axes is a list(with tensor)
......
...@@ -371,3 +371,23 @@ def relu_composite(x): ...@@ -371,3 +371,23 @@ def relu_composite(x):
"""define composite rule of op relu.""" """define composite rule of op relu."""
# relu(x) = max(x, 0) # relu(x) = max(x, 0)
return maximum(x, zeros_like(x)) return maximum(x, zeros_like(x))
@REGISTER_COMPOSITE('unsqueeze2')
def unsqueeze_composite(x, axis):
"""define composite rule of op unsqueeze"""
"""using reshape to implement unsqueeze op"""
x_shape = list(x.shape)
axis_list = list(axis)
for i in axis_list:
if i < 0:
i += len(x_shape) + 1
x_shape = (
x_shape[:i]
+ [
1,
]
+ x_shape[i:]
)
out = reshape(x, x_shape)
return [out, None]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册