import unittest
import numpy as np
from op_test import OpTest


def _conv2d_forward_naive(input_data, filter_data, stride, padding, groups):
    """Reference grouped 2-D convolution (cross-correlation) in NumPy.

    Args:
        input_data: array indexed as (batch, in_channel, row, col).
        filter_data: array indexed as
            (out_channel, in_channel_within_group, filter_row, filter_col);
            its second axis therefore has ``in_channels // groups`` entries.
        stride: step applied to both spatial axes.
        padding: number of zero rows/columns added on every spatial border.
        groups: number of channel groups. Output channel block ``g`` only
            sees input channel block ``g``. Both channel counts are assumed
            to be divisible by ``groups``.

    Returns:
        float64 array of shape (batch, out_channels, out_h, out_w), where
        ``out_h = (in_h - filter_h + 2 * padding) // stride + 1`` (likewise
        for ``out_w``).
    """
    batch_size, input_channels, input_height, input_width = input_data.shape
    output_channels, _, filter_height, filter_width = filter_data.shape
    # Floor division keeps these integral on both Python 2 and Python 3.
    input_group_channels = input_channels // groups
    output_group_channels = output_channels // groups
    output_height = (input_height - filter_height + 2 * padding) // stride + 1
    output_width = (input_width - filter_width + 2 * padding) // stride + 1

    # Zero padding implements "reads outside the input contribute 0".
    padded = np.pad(
        input_data.astype(np.float64),
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode='constant')
    output = np.zeros(
        (batch_size, output_channels, output_height, output_width))

    for group in range(groups):
        in_channels = slice(group * input_group_channels,
                            (group + 1) * input_group_channels)
        out_channels = slice(group * output_group_channels,
                             (group + 1) * output_group_channels)
        group_filter = filter_data[out_channels].astype(np.float64)
        for row in range(output_height):
            row_start = row * stride
            for col in range(output_width):
                col_start = col * stride
                window = padded[:, in_channels,
                                row_start:row_start + filter_height,
                                col_start:col_start + filter_width]
                # Contract (channel, filter_row, filter_col) of the window
                # with the group's filters -> (batch, group_out_channels).
                output[:, out_channels, row, col] = np.tensordot(
                    window, group_filter, axes=([1, 2, 3], [1, 2, 3]))
    return output


class TestConv2dOp(OpTest):
    """OpTest case for the ``conv2d`` operator with a NumPy reference output.

    Subclasses override ``init_groups`` to exercise grouped convolution.
    """

    def setUp(self):
        """Build random inputs, op attributes and the expected output."""
        self.init_groups()
        self.op_type = "conv2d"
        batch_size = 2
        input_channels = 3
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0

        input_data = np.random.random((batch_size, input_channels,
                                       input_height,
                                       input_width)).astype("float32")
        filter_data = np.random.random(
            (output_channels, input_channels // self.groups, filter_height,
             filter_width)).astype("float32")
        output = _conv2d_forward_naive(input_data, filter_data, stride,
                                       padding, self.groups)

        self.inputs = {'Input': input_data, 'Filter': filter_data}
        # Derive the attributes from the same locals that drive the
        # reference computation so the two can never drift apart.
        self.attrs = {
            'strides': [stride, stride],
            'paddings': [padding, padding],
            'groups': self.groups
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        """Run OpTest's forward check (presumably against self.outputs)."""
        self.check_output()

    def test_check_grad(self):
        """Run OpTest's gradient check for 'Input' and 'Filter' via 'Output'."""
        self.check_grad(set(['Input', 'Filter']), 'Output')

    def init_groups(self):
        """Set the number of convolution groups (1 = ordinary convolution)."""
        self.groups = 1


class TestWithGroup(TestConv2dOp):
    """Same conv2d checks as the base case, but with grouped channels."""

    def init_groups(self):
        """Split channels into three groups.

        With the base case's 3 input and 6 output channels, this gives
        1 input and 2 output channels per group.
        """
        group_count = 3
        self.groups = group_count

# When executed as a script, run every unittest.TestCase defined in this module.
if __name__ == '__main__':
    unittest.main()