diff --git a/test/xpu/test_activation_op_xpu.py b/test/xpu/test_activation_op_xpu.py index 50d9aec5b66ff362ca39e0dd0676fcf899f48132..d6ece4edf810796841dae695850507ec612c778a 100644 --- a/test/xpu/test_activation_op_xpu.py +++ b/test/xpu/test_activation_op_xpu.py @@ -364,31 +364,50 @@ class XPUTestGeluOP(XPUOpTestWrapper): self.op_name = 'gelu' self.use_dynamic_create_class = False - class XPUTestGelu(TestActivationOPBase): + class XPUTestGeluBase(TestActivationOPBase): def set_case(self): self.op_type = "gelu" self.dtype = self.in_type - approximate = False - x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - out = gelu(x, approximate) + self.init_config() + out = gelu(self.x, self.approximate) - self.inputs = {'X': x} + self.inputs = {'X': self.x} self.outputs = {'Out': out} - self.attrs = {"approximate": approximate, 'use_xpu': True} + self.attrs = {"approximate": self.approximate, 'use_xpu': True} - class XPUTestGeluApproximate(TestActivationOPBase): - def set_case(self): - self.op_type = "gelu" - self.dtype = self.in_type + def init_config(self): + self.approximate = False + self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - approximate = True - x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) - out = gelu(x, approximate) + class XPUTestGelu_ZeroDim(XPUTestGeluBase): + def init_config(self): + self.approximate = False + self.x = np.random.uniform(-2, 2, []).astype(self.dtype) - self.inputs = {'X': x} - self.outputs = {'Out': out} - self.attrs = {"approximate": approximate, 'use_xpu': True} + class XPUTestGelu1(XPUTestGeluBase): + def init_config(self): + self.approximate = True + self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + + class XPUTestGelu2(XPUTestGeluBase): + def init_config(self): + self.approximate = False + self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype) + + class XPUTestGelu3(XPUTestGeluBase): + def init_config(self): + self.approximate = True + self.x = np.random.uniform(-2, 2, [4, 512, 15, 15]).astype( + self.dtype + ) + + class XPUTestGelu4(XPUTestGeluBase): + def init_config(self): + self.approximate = False + self.x = np.random.uniform(-2, 2, [4, 256, 22, 22]).astype( + self.dtype + ) support_types = get_xpu_op_support_types('gelu')