Commit cb5e0d20 authored by zhupengyang, committed by hong19860320

skip cases with small shape (#22318)

Parent 8cb04664
......@@ -231,12 +231,12 @@ class OpTest(unittest.TestCase):
                     "This test of %s op needs check_grad with fp64 precision." %
                     cls.op_type)
             if not get_numeric_gradient.is_large_shape \
                     and cls.op_type not in check_shape_white_list.NOT_CHECK_OP_LIST \
                     and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
                 raise AssertionError(
                     "Input's shape should be larger than or equal to 100 for " +
                     cls.op_type + " Op.")
 
     def try_call_once(self, data_type):
         if not self.call_once:
......
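
The hunk above is the gate this commit works around: unless an op appears in one of the two white lists, a gradient-checked input must now have at least 100 elements. The diff does not show the `skip_check_grad_ci` decorator that test_prelu_op.py imports below, so the following is only a minimal sketch of how such a decorator could mark a test class; the flag name `check_grad_ci_skipped` is an assumption, not taken from this commit.

```python
# Minimal sketch of a skip_check_grad_ci-style decorator (assumed, not shown
# in this diff): it demands a documented reason and flags the test class so
# the framework can bypass the small-shape assertion for justified cases.
def skip_check_grad_ci(reason=None):
    if not isinstance(reason, str):
        raise AssertionError("The reason for skipping check_grad is required.")

    def wrapper(cls):
        cls.check_grad_ci_skipped = True  # hypothetical flag read by OpTest
        cls.skip_reason = reason
        return cls

    return wrapper
```
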
......@@ -17,15 +17,16 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
 import six
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 
 
 class PReluTest(OpTest):
     def setUp(self):
+        self.init_input_shape()
+        self.init_attr()
         self.op_type = "prelu"
-        self.initTestCase()
-        x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
+        x_np = np.random.uniform(-1, 1, self.x_shape).astype("float32")
 
         # Since zero point in prelu is not differentiable, avoid randomize
         # zero.
         x_np[np.abs(x_np) < 0.005] = 0.02
......@@ -37,8 +38,8 @@ class PReluTest(OpTest):
             alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
         else:
-            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \
-                x_np.shape[3]).astype("float32")
+            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2],
+                                      x_np.shape[3]).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
 
         out_np = np.maximum(self.inputs['X'], 0.)
......@@ -47,7 +48,10 @@ class PReluTest(OpTest):
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
 
-    def initTestCase(self):
+    def init_input_shape(self):
+        self.x_shape = (2, 100, 3, 4)
+
+    def init_attr(self):
         self.attrs = {'mode': "channel"}
 
     def test_check_output(self):
......@@ -66,16 +70,21 @@ class PReluTest(OpTest):
 
 # TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
 if six.PY2:
 
-    class TestCase1(PReluTest):
-        def initTestCase(self):
+    @skip_check_grad_ci(
+        reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+    )
+    class TestModeAll(PReluTest):
+        def init_input_shape(self):
+            self.x_shape = (2, 3, 4, 5)
+
+        def init_attr(self):
             self.attrs = {'mode': "all"}
 
-    class TestCase2(PReluTest):
-        def initTestCase(self):
-            self.attrs = {'mode': "channel"}
+    class TestModeElt(PReluTest):
+        def init_input_shape(self):
+            self.x_shape = (3, 2, 5, 10)
 
-    class TestCase3(PReluTest):
-        def initTestCase(self):
+        def init_attr(self):
             self.attrs = {'mode': "element"}
......
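
The test builds `Alpha` with a shape that broadcasts against `X` for each mode, and the visible line `out_np = np.maximum(self.inputs['X'], 0.)` starts the reference computation; the elided lines presumably add the alpha-scaled negative part, since PReLU is defined as max(x, 0) + alpha * min(x, 0). A self-contained NumPy sketch of that reference, reusing the shapes from the diff:

```python
import numpy as np

def prelu_ref(x, alpha):
    # PReLU: identity on the positive part, alpha-scaled on the negative part.
    # alpha broadcasts against x, which is what all three modes rely on.
    return np.maximum(x, 0.) + alpha * np.minimum(x, 0.)

x = np.random.uniform(-1, 1, (2, 100, 3, 4)).astype("float32")
x[np.abs(x) < 0.005] = 0.02  # keep samples away from the non-differentiable zero

alpha_all = np.random.rand(1).astype("float32")                        # 'all' mode
alpha_channel = np.random.rand(1, x.shape[1], 1, 1).astype("float32")  # 'channel' mode
alpha_element = np.random.rand(1, *x.shape[1:]).astype("float32")      # 'element' mode

out = prelu_ref(x, alpha_channel)
assert out.shape == x.shape
```
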
......@@ -24,8 +24,6 @@ NOT_CHECK_OP_LIST = [
     'elementwise_min',
     'elementwise_pow',
     'fused_elemwise_activation',
-    # prelu op's input alpha must be 1-d and only has one data in 'all' mode
-    'prelu'
 ]
 
 NEED_TO_FIX_OP_LIST = [
......
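
With 'prelu' removed from NOT_CHECK_OP_LIST, its gradient tests become subject to the shape gate in op_test.py, which is why the base PReluTest shape grows to (2, 100, 3, 4) while the small-shape 'all'-mode case carries the decorator instead. Condensed, the gate from the first hunk reads as follows; the helper name is illustrative only, not part of the commit.

```python
# Illustrative restatement of the assertion logic in op_test.py above;
# the helper name small_shape_should_fail is hypothetical.
def small_shape_should_fail(op_type, input_is_large, white_list):
    # True exactly when the AssertionError in the first hunk would be raised.
    return (not input_is_large
            and op_type not in white_list.NOT_CHECK_OP_LIST
            and op_type not in white_list.NEED_TO_FIX_OP_LIST)
```
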