提交 2eed2a5d 编写于 作者: T Taehoon Lee 提交者: François Chollet

Increase test coverages by adding invalid CNTK usecases (#10236)

上级 bf1378f3
......@@ -1195,33 +1195,46 @@ class TestBackend(object):
assert np.abs(np.std(rand) - std * 0.87962) < 0.015
def test_conv_invalid_use(self):
dummy_x_1d = K.variable(np.ones((4, 8, 2)))
dummy_w_1d = K.variable(np.ones((3, 2, 3)))
dummy_x_2d = K.variable(np.ones((2, 3, 4, 5)))
dummy_w_2d = K.variable(np.ones((2, 2, 3, 4)))
dummy_x_3d = K.variable(np.ones((2, 3, 4, 5, 4)))
dummy_w_3d = K.variable(np.ones((2, 2, 2, 3, 4)))
dummy_w1x1_2d = K.variable(np.ones((1, 1, 12, 7)))
with pytest.raises(ValueError):
K.conv1d(K.variable(np.ones((4, 8, 2))),
K.variable(np.ones((3, 2, 3))),
data_format='channels_middle')
K.conv1d(dummy_x_1d, dummy_w_1d, data_format='channels_middle')
with pytest.raises(ValueError):
K.conv2d(K.variable(np.ones((2, 3, 4, 5))),
K.variable(np.ones((2, 2, 3, 4))),
data_format='channels_middle')
K.conv2d(dummy_x_2d, dummy_w_2d, data_format='channels_middle')
with pytest.raises(ValueError):
K.conv3d(K.variable(np.ones((2, 3, 4, 5, 4))),
K.variable(np.ones((2, 2, 2, 3, 4))),
data_format='channels_middle')
K.conv3d(dummy_x_3d, dummy_w_3d, data_format='channels_middle')
if K.backend() != 'theano':
with pytest.raises(ValueError):
K.separable_conv2d(K.variable(np.ones((2, 3, 4, 5))),
K.variable(np.ones((2, 2, 3, 4))),
K.variable(np.ones((1, 1, 12, 7))),
K.separable_conv2d(dummy_x_2d, dummy_w_2d, dummy_w1x1_2d,
data_format='channels_middle')
with pytest.raises(ValueError):
K.depthwise_conv2d(K.variable(np.ones((2, 3, 4, 5))),
K.variable(np.ones((2, 2, 3, 4))),
K.depthwise_conv2d(dummy_x_2d, dummy_w_2d,
data_format='channels_middle')
if K.backend() == 'cntk':
with pytest.raises(ValueError):
K.separable_conv2d(dummy_x_2d, dummy_w_2d, dummy_w1x1_2d,
dilation_rate=(1, 2))
with pytest.raises(ValueError):
K.separable_conv2d(dummy_x_2d, dummy_w_2d, dummy_w1x1_2d,
strides=(2, 2), dilation_rate=(1, 2))
with pytest.raises(ValueError):
K.depthwise_conv2d(dummy_x_2d, dummy_w_2d,
dilation_rate=(1, 2))
with pytest.raises(ValueError):
K.depthwise_conv2d(dummy_x_2d, dummy_w_2d,
strides=(2, 2), dilation_rate=(1, 2))
def test_pooling_invalid_use(self):
for (input_shape, pool_size) in zip([(5, 10, 12, 3), (5, 10, 12, 6, 3)], [(2, 2), (2, 2, 2)]):
x = K.variable(np.random.random(input_shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册