未验证 提交 9fb2c603 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN support unary op, fix test_activation_op (#54216)

* [0D-Tensor] Support unary op, fix test_activation_op

* resolve conflict

* restore TestGelu

* polish codes according to comments
上级 f4a7b162
@@ -77,7 +77,42 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
       "greater_than",
       "greater_equal",
       "less_than",
       "less_equal",
"tanh",
"relu",
"gelu",
"sigmoid",
"exp",
"erf",
"rsqrt",
"log",
"log2",
"log10",
"floor",
"ceil",
"round",
"trunc",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"asin",
"acos",
"atan",
"asinh",
"acosh",
"atanh",
"isnan",
"isfinite",
"isinf",
"negative",
"sign",
"abs",
"reciprocal",
"logical_not",
"bitwise_not"};
  std::unordered_set<std::string> white_tensor_name;
  // enable white_op_list only when graph_node_size = 1, which means single op
  // test
......
@@ -143,9 +143,6 @@ class TestExpPrim_ZeroDim(TestExpFp32_Prim):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestExpm1(TestActivation):
    def setUp(self):
@@ -277,9 +274,6 @@ class TestSigmoid_ZeroDim(TestSigmoid):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf(
    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
@@ -362,9 +356,6 @@ class TestSilu_ZeroDim(TestSilu):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestSiluAPI(unittest.TestCase):
    # test paddle.nn.Silu, paddle.nn.functional.silu
@@ -527,9 +518,6 @@ class TestTanh_ZeroDim(TestTanh):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestTanhAPI(unittest.TestCase):
    # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
@@ -1237,9 +1225,6 @@ class TestSqrt_ZeroDim(TestSqrt):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf(
    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
@@ -1428,9 +1413,6 @@ class TestAbs_ZeroDim(TestAbs):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestCeil(TestActivation):
    def setUp(self):
@@ -1509,9 +1491,6 @@ class TestFloor_ZeroDim(TestFloor):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestCos(TestActivation):
    def setUp(self):
@@ -1521,8 +1500,7 @@ class TestCos(TestActivation):
        self.prim_op_type = "prim"
        self.init_dtype()
        self.init_shape()
# prim not support now self.if_enable_cinn()
self.enable_cinn = False
        np.random.seed(1024)
        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -1539,6 +1517,9 @@ class TestCos(TestActivation):
            return
        self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestCos_ZeroDim(TestCos):
    def init_shape(self):
@@ -1659,8 +1640,7 @@ class TestSin(TestActivation, TestParameter):
        self.prim_op_type = "prim"
        self.init_dtype()
        self.init_shape()
# prim not support now self.if_enable_cinn()
self.enable_cinn = False
        np.random.seed(1024)
        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
@@ -1677,6 +1657,9 @@ class TestSin(TestActivation, TestParameter):
            return
        self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestSin_ZeroDim(TestSin):
    def init_shape(self):
@@ -1862,9 +1845,6 @@ class TestRelu_ZeroDim(TestRelu):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestReluAPI(unittest.TestCase):
    # test paddle.nn.ReLU, paddle.nn.functional.relu
@@ -2141,9 +2121,6 @@ class TestGelu_ZeroDim(TestGelu):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestGELUAPI(unittest.TestCase):
    # test paddle.nn.GELU, paddle.nn.functional.gelu
@@ -2396,7 +2373,6 @@ class TestHardSwish(TestActivation):
        self.outputs = {'Out': out}
        self.convert_input_output()
        self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
self.enable_cinn = False
    def init_shape(self):
        self.shape = [10, 12]
@@ -2417,10 +2393,6 @@ class TestHardSwish(TestActivation):
class TestHardSwish_ZeroDim(TestHardSwish):
def setUp(self):
super().setUp()
self.enable_cinn = False
    def init_shape(self):
        self.shape = []
@@ -2831,9 +2803,6 @@ class TestLog_ZeroDim(TestLog):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestLog2(TestActivation):
    def setUp(self):
@@ -3131,9 +3100,6 @@ class TestPow_ZeroDim(TestPow):
    def init_shape(self):
        self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestPow_factor_tensor(TestActivation):
    def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册