未验证 提交 80975d45 编写于 作者: H houj04 提交者: GitHub

[XPU] optimize gelu unittest. (#54737)

上级 64696b9b
...@@ -364,31 +364,50 @@ class XPUTestGeluOP(XPUOpTestWrapper): ...@@ -364,31 +364,50 @@ class XPUTestGeluOP(XPUOpTestWrapper):
self.op_name = 'gelu' self.op_name = 'gelu'
self.use_dynamic_create_class = False self.use_dynamic_create_class = False
class XPUTestGelu(TestActivationOPBase): class XPUTestGeluBase(TestActivationOPBase):
def set_case(self): def set_case(self):
self.op_type = "gelu" self.op_type = "gelu"
self.dtype = self.in_type self.dtype = self.in_type
approximate = False self.init_config()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) out = gelu(self.x, self.approximate)
out = gelu(x, approximate)
self.inputs = {'X': x} self.inputs = {'X': self.x}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.attrs = {"approximate": approximate, 'use_xpu': True} self.attrs = {"approximate": self.approximate, 'use_xpu': True}
class XPUTestGeluApproximate(TestActivationOPBase): def init_config(self):
def set_case(self): self.approximate = False
self.op_type = "gelu" self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
self.dtype = self.in_type
approximate = True class XPUTestGelu_ZeroDim(XPUTestGeluBase):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) def init_config(self):
out = gelu(x, approximate) self.approximate = False
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
self.inputs = {'X': x} class XPUTestGelu1(XPUTestGeluBase):
self.outputs = {'Out': out} def init_config(self):
self.attrs = {"approximate": approximate, 'use_xpu': True} self.approximate = True
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestGelu2(XPUTestGeluBase):
def init_config(self):
self.approximate = False
self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
class XPUTestGelu3(XPUTestGeluBase):
def init_config(self):
self.approximate = True
self.x = np.random.uniform(-2, 2, [4, 512, 15, 15]).astype(
self.dtype
)
class XPUTestGelu4(XPUTestGeluBase):
def init_config(self):
self.approximate = False
self.x = np.random.uniform(-2, 2, [4, 256, 22, 22]).astype(
self.dtype
)
support_types = get_xpu_op_support_types('gelu') support_types = get_xpu_op_support_types('gelu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册