From a7c1872206cf11ba968a932a0fc880a03e8a4c28 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 12 Sep 2017 21:05:54 +0800 Subject: [PATCH] Refine test_conv2d_op.py --- .../v2/framework/tests/test_conv2d_op.py | 36 ++----------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 01513be66e0..29a637a3822 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -26,6 +26,9 @@ class TestConv2dOp(OpTest): output = np.ndarray( (batch_size, output_channels, output_height, output_width)) + self.inputs = {'Input': input, 'Filter': filter} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + for batchid in xrange(batch_size): for channelid in xrange(output_channels): for rowid in xrange(output_height): @@ -50,44 +53,11 @@ class TestConv2dOp(OpTest): output_value += input_value * filter_value output[batchid][channelid][rowid][colid] = output_value - self.inputs = {'Input': input, 'Filter': filter} self.outputs = {'Output': output} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} def test_check_output(self): self.check_output() - -class TestConv2dGradOp(OpTest): - def setUp(self): - batch_size = 2 - input_channels = 3 - input_height = 5 - input_width = 5 - output_channels = 6 - filter_height = 3 - filter_width = 3 - stride = 1 - padding = 0 - output_height = (input_height - filter_height + 2 * padding - ) / stride + 1 - output_width = (input_width - filter_width + 2 * padding) / stride + 1 - input = np.random.random((batch_size, input_channels, input_height, - input_width)).astype("float32") - filter = np.random.random( - (output_channels, input_channels, filter_height, - filter_width)).astype("float32") - - self.op_type = 'conv2d' - self.inputs = {'Input': input, 'Filter': filter} - output = np.ndarray( - (batch_size, output_channels, output_height, output_width)) - self.outputs = {'Output': output} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} - - #def test_compare_grad(self): - # self.compare_grad(self.op, self.inputs) - def test_check_grad(self): self.check_grad(set(['Input', 'Filter']), 'Output') -- GitLab