未验证 提交 9fb2c603 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN support unary op, fix test_activation_op (#54216)

* [0D-Tensor] Support unary op, fix test_activation_op

* resolve conflict

* restore TestGelu

* polish codes according to comments
上级 f4a7b162
......@@ -77,7 +77,42 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
"greater_than",
"greater_equal",
"less_than",
"less_equal"};
"less_equal",
"tanh",
"relu",
"gelu",
"sigmoid",
"exp",
"erf",
"rsqrt",
"log",
"log2",
"log10",
"floor",
"ceil",
"round",
"trunc",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"acosh",
"atanh",
"isnan",
"isfinite",
"isinf",
"negative",
"sign",
"abs",
"reciprocal",
"logical_not",
"bitwise_not"};
std::unordered_set<std::string> white_tensor_name;
// enable white_op_list only when graph_node_size = 1, which means single op
// test
......
......@@ -143,9 +143,6 @@ class TestExpPrim_ZeroDim(TestExpFp32_Prim):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestExpm1(TestActivation):
def setUp(self):
......@@ -277,9 +274,6 @@ class TestSigmoid_ZeroDim(TestSigmoid):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
......@@ -362,9 +356,6 @@ class TestSilu_ZeroDim(TestSilu):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
......@@ -527,9 +518,6 @@ class TestTanh_ZeroDim(TestTanh):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestTanhAPI(unittest.TestCase):
# test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
......@@ -1237,9 +1225,6 @@ class TestSqrt_ZeroDim(TestSqrt):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
......@@ -1428,9 +1413,6 @@ class TestAbs_ZeroDim(TestAbs):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestCeil(TestActivation):
def setUp(self):
......@@ -1509,9 +1491,6 @@ class TestFloor_ZeroDim(TestFloor):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestCos(TestActivation):
def setUp(self):
......@@ -1521,8 +1500,7 @@ class TestCos(TestActivation):
self.prim_op_type = "prim"
self.init_dtype()
self.init_shape()
# prim not support now
self.enable_cinn = False
self.if_enable_cinn()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......@@ -1539,6 +1517,9 @@ class TestCos(TestActivation):
return
self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestCos_ZeroDim(TestCos):
def init_shape(self):
......@@ -1659,8 +1640,7 @@ class TestSin(TestActivation, TestParameter):
self.prim_op_type = "prim"
self.init_dtype()
self.init_shape()
# prim not support now
self.enable_cinn = False
self.if_enable_cinn()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......@@ -1677,6 +1657,9 @@ class TestSin(TestActivation, TestParameter):
return
self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestSin_ZeroDim(TestSin):
def init_shape(self):
......@@ -1862,9 +1845,6 @@ class TestRelu_ZeroDim(TestRelu):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestReluAPI(unittest.TestCase):
# test paddle.nn.ReLU, paddle.nn.functional.relu
......@@ -2141,9 +2121,6 @@ class TestGelu_ZeroDim(TestGelu):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestGELUAPI(unittest.TestCase):
# test paddle.nn.GELU, paddle.nn.functional.gelu
......@@ -2396,7 +2373,6 @@ class TestHardSwish(TestActivation):
self.outputs = {'Out': out}
self.convert_input_output()
self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
self.enable_cinn = False
def init_shape(self):
self.shape = [10, 12]
......@@ -2417,10 +2393,6 @@ class TestHardSwish(TestActivation):
class TestHardSwish_ZeroDim(TestHardSwish):
def setUp(self):
super().setUp()
self.enable_cinn = False
def init_shape(self):
self.shape = []
......@@ -2831,9 +2803,6 @@ class TestLog_ZeroDim(TestLog):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestLog2(TestActivation):
def setUp(self):
......@@ -3131,9 +3100,6 @@ class TestPow_ZeroDim(TestPow):
def init_shape(self):
self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestPow_factor_tensor(TestActivation):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册