Commit cb5e0d20 authored by zhupengyang, committed by hong19860320

skip cases with small shape (#22318)

Parent: 8cb04664
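This commit enlarges the default prelu test shape to (2, 100, 3, 4) in 'channel' mode and marks the 'all'-mode case with skip_check_grad_ci, since Input(Alpha) is necessarily a single element there; with that, 'prelu' can be dropped from the global NOT_CHECK_OP_LIST exemption. A minimal sketch of how a decorator like skip_check_grad_ci can work, assuming it simply tags the class with a flag that the test harness consults before enforcing its minimum-shape rule (the real op_test implementation may differ, and the flag name below is an assumption):

    def skip_check_grad_ci(reason=None):
        # A written reason is required so that every skipped gradient
        # check stays auditable in the test source.
        if not isinstance(reason, str):
            raise AssertionError("The reason for skipping check_grad is required.")

        def wrapper(cls):
            # Assumed flag name: the harness would look for this attribute
            # before applying the minimum-input-size rule to check_grad.
            cls.no_need_check_grad = True
            return cls

        return wrapper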
@@ -17,15 +17,16 @@ from __future__ import print_function
 import unittest
 import numpy as np
 import six
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 
 
 class PReluTest(OpTest):
     def setUp(self):
+        self.init_input_shape()
+        self.init_attr()
         self.op_type = "prelu"
-        self.initTestCase()
-        x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
+        x_np = np.random.uniform(-1, 1, self.x_shape).astype("float32")
 
         # Since zero point in prelu is not differentiable, avoid randomize
         # zero.
         x_np[np.abs(x_np) < 0.005] = 0.02
@@ -37,7 +38,7 @@ class PReluTest(OpTest):
             alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
         else:
-            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \
+            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2],
                                       x_np.shape[3]).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
@@ -47,7 +48,10 @@ class PReluTest(OpTest):
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
 
-    def initTestCase(self):
+    def init_input_shape(self):
+        self.x_shape = (2, 100, 3, 4)
+
+    def init_attr(self):
         self.attrs = {'mode': "channel"}
 
     def test_check_output(self):
@@ -66,16 +70,21 @@ class PReluTest(OpTest):
 # TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
 if six.PY2:
 
-    class TestCase1(PReluTest):
-        def initTestCase(self):
+    @skip_check_grad_ci(
+        reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+    )
+    class TestModeAll(PReluTest):
+        def init_input_shape(self):
+            self.x_shape = (2, 3, 4, 5)
+
+        def init_attr(self):
             self.attrs = {'mode': "all"}
 
-    class TestCase2(PReluTest):
-        def initTestCase(self):
-            self.attrs = {'mode': "channel"}
+    class TestModeElt(PReluTest):
+        def init_input_shape(self):
+            self.x_shape = (3, 2, 5, 10)
 
-    class TestCase3(PReluTest):
-        def initTestCase(self):
+        def init_attr(self):
             self.attrs = {'mode': "element"}
......
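Splitting setUp into init_input_shape/init_attr hooks means a new case only overrides what differs from the base test. A hypothetical extra case following the same pattern (the class name and shape below are illustrative, not part of this commit):

    class TestModeChannelBigShape(PReluTest):
        def init_input_shape(self):
            # 'channel' mode broadcasts Alpha across dims 0, 2 and 3, so
            # only the channel dimension carries per-element alphas; keep
            # the input large enough to satisfy the CI shape check.
            self.x_shape = (2, 100, 7, 7)

        def init_attr(self):
            self.attrs = {'mode': "channel"}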
...@@ -24,8 +24,6 @@ NOT_CHECK_OP_LIST = [ ...@@ -24,8 +24,6 @@ NOT_CHECK_OP_LIST = [
'elementwise_min', 'elementwise_min',
'elementwise_pow', 'elementwise_pow',
'fused_elemwise_activation', 'fused_elemwise_activation',
# prelu op's input alpha must be 1-d and only has one data in 'all' mode
'prelu'
] ]
NEED_TO_FIX_OP_LIST = [ NEED_TO_FIX_OP_LIST = [
......