test_conv2d_op.py 3.9 KB
Newer Older
1 2
import unittest
import numpy as np
H
hedaoyuan 已提交
3
from op_test import OpTest
4 5


H
hedaoyuan 已提交
6
def _conv2d_forward_naive(input_data, filter_data, groups, stride, padding):
    """Compute a reference grouped 2-D convolution with NumPy.

    Indexing follows the layout used by TestConv2dOp below:
    ``input_data[batch][channel][row][col]`` and
    ``filter_data[out_channel][in_channel_within_group][row][col]``.

    Args:
        input_data: 4-D array (batch, input_channels, height, width).
        filter_data: 4-D array
            (output_channels, input_channels // groups, kernel_h, kernel_w).
        groups: number of channel groups; must divide both channel counts.
        stride: step between output positions, applied to rows and columns.
        padding: zero rows/columns added on every side of the image.

    Returns:
        float64 array (batch, output_channels, out_h, out_w) with
        out_h = (height - kernel_h + 2 * padding) // stride + 1
        (and likewise for out_w).
    """
    batch_size, input_channels, input_height, input_width = input_data.shape
    (output_channels, input_group_channels, filter_height,
     filter_width) = filter_data.shape
    assert input_group_channels * groups == input_channels
    assert output_channels % groups == 0
    output_group_channels = output_channels // groups

    # Floor division keeps shapes integral on both Python 2 and Python 3.
    output_height = (input_height - filter_height + 2 * padding) // stride + 1
    output_width = (input_width - filter_width + 2 * padding) // stride + 1
    output = np.zeros(
        (batch_size, output_channels, output_height, output_width))

    # Padding once with zeros replaces a per-tap bounds check: taps that land
    # outside the original image read 0 and contribute nothing to the sum.
    padded = np.pad(
        input_data,
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode='constant')

    for group in range(groups):
        in_lo = group * input_group_channels
        out_lo = group * output_group_channels
        out_hi = out_lo + output_group_channels
        # Output channels of a group only see that group's input channels.
        group_input = padded[:, in_lo:in_lo + input_group_channels]
        group_filter = filter_data[out_lo:out_hi]
        for rowid in range(output_height):
            top = rowid * stride
            for colid in range(output_width):
                left = colid * stride
                window = group_input[:, :, top:top + filter_height,
                                     left:left + filter_width]
                # Contract channel/row/col axes:
                # (batch, Cg, KH, KW) x (Og, Cg, KH, KW) -> (batch, Og).
                output[:, out_lo:out_hi, rowid, colid] = np.tensordot(
                    window, group_filter, axes=([1, 2, 3], [1, 2, 3]))
    return output


class TestConv2dOp(OpTest):
    """Check the conv2d operator's output and gradients against NumPy.

    Subclasses change the number of channel groups by overriding
    ``init_groups``.
    """

    def setUp(self):
        """Build random inputs, operator attributes and the expected output."""
        self.init_groups()
        self.op_type = "conv2d"
        batch_size = 2
        input_channels = 3
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0

        input_data = np.random.random((batch_size, input_channels,
                                       input_height,
                                       input_width)).astype("float32")
        filter_data = np.random.random(
            (output_channels, input_channels // self.groups, filter_height,
             filter_width)).astype("float32")
        output = _conv2d_forward_naive(input_data, filter_data, self.groups,
                                       stride, padding)

        self.inputs = {'Input': input_data, 'Filter': filter_data}
        # Derive the attributes from the same locals used for the reference
        # output so the two cannot drift apart when the test is edited.
        self.attrs = {
            'strides': [stride, stride],
            'paddings': [padding, padding],
            'groups': self.groups
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        """Compare the operator's forward output with the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients with respect to both Input and Filter."""
        self.check_grad(
            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        """Check the Input gradient while Filter is excluded."""
        self.check_grad(
            ['Input'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Filter']))

    def test_check_grad_no_input(self):
        """Check the Filter gradient while Input is excluded."""
        self.check_grad(
            ['Filter'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Input']))

    def init_groups(self):
        """Set the channel-group count; the base case is ungrouped."""
        self.groups = 1


class TestWithGroup(TestConv2dOp):
    """Repeat every TestConv2dOp check as a grouped convolution."""

    def init_groups(self):
        # Three groups: the base test's 3 input and 6 output channels both
        # divide evenly by this count.
        group_count = 3
        self.groups = group_count

H
hedaoyuan 已提交
101

102 103
if __name__ == '__main__':
    # Run the TestCase classes defined in this module when executed directly.
    unittest.main()