From b4ba35caeb248136461b33c7d47977e09dfb4286 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 17:11:34 +0800 Subject: [PATCH] Add groups test. --- .../v2/framework/tests/test_conv2d_op.py | 58 +++++++++++-------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 29a637a38..660eb3196 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -15,43 +15,53 @@ class TestConv2dOp(OpTest): filter_width = 3 stride = 1 padding = 0 + groups = 3 output_height = (input_height - filter_height + 2 * padding ) / stride + 1 output_width = (input_width - filter_width + 2 * padding) / stride + 1 input = np.random.random((batch_size, input_channels, input_height, input_width)).astype("float32") + filter = np.random.random( - (output_channels, input_channels, filter_height, + (output_channels, input_channels / groups, filter_height, filter_width)).astype("float32") output = np.ndarray( (batch_size, output_channels, output_height, output_width)) self.inputs = {'Input': input, 'Filter': filter} - self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} + self.attrs = {'strides': [1, 1], 'paddings': [0, 0], 'groups': groups} + output_group_channels = output_channels / groups + input_group_channels = input_channels / groups for batchid in xrange(batch_size): - for channelid in xrange(output_channels): - for rowid in xrange(output_height): - for colid in xrange(output_width): - start_h = (rowid * stride) - padding - start_w = (colid * stride) - padding - output_value = 0.0 - for inchannelid in xrange(input_channels): - for frowid in xrange(filter_height): - for fcolid in xrange(filter_width): - input_value = 0.0 - inrowid = start_h + frowid - incolid = start_w + fcolid - if ((inrowid >= 0 and - inrowid < input_height) and - (incolid >= 0 and - incolid < input_width)): - input_value = input[batchid][ - inchannelid][inrowid][incolid] - filter_value = filter[channelid][ - inchannelid][frowid][fcolid] - output_value += input_value * filter_value - output[batchid][channelid][rowid][colid] = output_value + for group in xrange(groups): + for outchannelid in range(group * output_group_channels, + (group + 1) * output_group_channels): + for rowid in xrange(output_height): + for colid in xrange(output_width): + start_h = (rowid * stride) - padding + start_w = (colid * stride) - padding + output_value = 0.0 + for inchannelid in range( + group * input_group_channels, + (group + 1) * input_group_channels): + for frowid in xrange(filter_height): + for fcolid in xrange(filter_width): + input_value = 0.0 + inrowid = start_h + frowid + incolid = start_w + fcolid + if ((inrowid >= 0 and + inrowid < input_height) and + (incolid >= 0 and + incolid < input_width)): + input_value = input[batchid][ + inchannelid][inrowid][incolid] + filter_value = filter[outchannelid][ + inchannelid % input_group_channels][ + frowid][fcolid] + output_value += input_value * filter_value + output[batchid][outchannelid][rowid][ + colid] = output_value self.outputs = {'Output': output} -- GitLab