# test_conv2d_op.py: unit tests for the conv2d and conv_cudnn operators.
import unittest
import numpy as np
from op_test import OpTest


def conv2d_forward_naive(input_data, filter_data, groups, stride, padding):
    """Compute a reference grouped 2-D convolution with NumPy.

    The window is applied without flipping (cross-correlation), matching the
    scalar reference loop this test originally used.

    Args:
        input_data: array of shape (batch, in_channels, in_height, in_width).
        filter_data: array of shape
            (out_channels, in_channels // groups, filter_height, filter_width).
        groups: number of channel groups; in_channels and out_channels are
            expected to be divisible by it.
        stride: step between neighbouring output positions, applied to both
            the height and the width axis.
        padding: number of implicit zero rows/columns on every spatial border.

    Returns:
        float64 array of shape (batch, out_channels, out_height, out_width),
        where out_size = (in_size - filter_size + 2 * padding) // stride + 1.
    """
    batch_size, _, in_height, in_width = input_data.shape
    out_channels, in_group_channels, filter_height, filter_width = (
        filter_data.shape)
    out_group_channels = out_channels // groups
    # Floor division keeps the sizes integral on both Python 2 and Python 3.
    out_height = (in_height - filter_height + 2 * padding) // stride + 1
    out_width = (in_width - filter_width + 2 * padding) // stride + 1

    # Zero-padding the spatial axes is equivalent to treating out-of-range
    # input pixels as 0. Work in float64 so the expected values are
    # accumulated at higher precision than the float32 inputs.
    padded = np.pad(
        input_data.astype(np.float64),
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode='constant')
    weights = filter_data.astype(np.float64)

    output = np.zeros((batch_size, out_channels, out_height, out_width))
    for group in range(groups):
        in_begin = group * in_group_channels
        in_end = in_begin + in_group_channels
        out_begin = group * out_group_channels
        out_end = out_begin + out_group_channels
        group_input = padded[:, in_begin:in_end]
        group_filter = weights[out_begin:out_end]
        for row in range(out_height):
            top = row * stride
            for col in range(out_width):
                left = col * stride
                window = group_input[:, :, top:top + filter_height,
                                     left:left + filter_width]
                # Contract over (channel, row, col):
                # (batch, c, fh, fw) x (oc, c, fh, fw) -> (batch, oc).
                output[:, out_begin:out_end, row, col] = np.tensordot(
                    window, group_filter, axes=([1, 2, 3], [1, 2, 3]))
    return output


class TestConv2dOp(OpTest):
    """Check a 2-D convolution operator against a NumPy reference.

    Subclasses customise the run by overriding ``init_groups`` (number of
    channel groups) and ``init_optype`` (which operator is exercised).
    """

    def setUp(self):
        """Build random inputs, the op attributes and the expected output."""
        self.init_groups()
        self.init_optype()
        batch_size = 2
        input_channels = 3
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0

        input_data = np.random.random(
            (batch_size, input_channels, input_height,
             input_width)).astype("float32")
        # Each filter only sees the input channels of its own group.
        filter_data = np.random.random(
            (output_channels, input_channels // self.groups, filter_height,
             filter_width)).astype("float32")

        self.inputs = {'Input': input_data, 'Filter': filter_data}
        # Derive strides/paddings from the same locals used for the reference
        # output so the two cannot drift apart.
        self.attrs = {
            'strides': [stride, stride],
            'paddings': [padding, padding],
            'dilations': [1, 1],
            'groups': self.groups
        }
        self.outputs = {
            'Output': conv2d_forward_naive(input_data, filter_data,
                                           self.groups, stride, padding)
        }

    def test_check_output(self):
        """Compare the operator's forward output with the reference."""
        self.check_output()

    def test_check_grad(self):
        """Check the gradients w.r.t. both Input and Filter."""
        # A list (rather than a set) keeps the checking order deterministic
        # and matches the other gradient tests below.
        self.check_grad(
            ['Input', 'Filter'], 'Output', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        """Check the Input gradient while Filter is excluded."""
        self.check_grad(
            ['Input'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Filter']))

    def test_check_grad_no_input(self):
        """Check the Filter gradient while Input is excluded."""
        self.check_grad(
            ['Filter'],
            'Output',
            max_relative_error=0.05,
            no_grad_set=set(['Input']))

    def init_groups(self):
        """Set the number of channel groups (default: 1)."""
        self.groups = 1

    def init_optype(self):
        """Set the operator under test (default: ``conv2d``)."""
        self.op_type = "conv2d"

class TestWithGroup(TestConv2dOp):
    """conv2d test with the channels split into three groups."""

    def init_groups(self):
        """Replace the default single group with three groups."""
        self.groups = 3

class TestCudnn2d(TestConv2dOp):
    """Run the base conv2d checks against the ``conv_cudnn`` operator."""

    def init_optype(self):
        """Select ``conv_cudnn`` (presumably the cuDNN kernel) as the op."""
        self.op_type = "conv_cudnn"


class TestCudnn2dWithGroup(TestConv2dOp):
    """``conv_cudnn`` operator test with three channel groups."""

    def init_groups(self):
        """Use three channel groups instead of one."""
        self.groups = 3

    def init_optype(self):
        """Exercise the ``conv_cudnn`` operator."""
        self.op_type = "conv_cudnn"


if __name__ == '__main__':
    # Executed as a script: run every TestCase defined in this module.
    unittest.main(module='__main__')