test_conv2d_op.py 3.8 KB
Newer Older
1 2
import unittest
import numpy as np
H
hedaoyuan 已提交
3
from op_test import OpTest
4 5


H
hedaoyuan 已提交
6
class TestConv2dOp(OpTest):
    """Forward test of the ``conv2d`` operator against a naive reference.

    ``setUp`` draws random float32 input and filter arrays, computes the
    expected output with explicit nested loops, and stores the inputs,
    expected outputs and operator attributes on the test case.
    """

    def setUp(self):
        """Build random inputs and the loop-computed expected output."""
        self.op_type = "conv2d"
        batch_size = 2
        input_channels = 3
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0
        # Floor division keeps the output dimensions integral on both
        # Python 2 and Python 3; true division would yield floats, which
        # are not valid array shapes or loop bounds.
        output_height = (
            input_height - filter_height + 2 * padding) // stride + 1
        output_width = (
            input_width - filter_width + 2 * padding) // stride + 1

        # Local names avoid shadowing the ``input`` and ``filter`` builtins.
        input_data = np.random.random(
            (batch_size, input_channels, input_height,
             input_width)).astype("float32")
        filter_data = np.random.random(
            (output_channels, input_channels, filter_height,
             filter_width)).astype("float32")
        # Every element is overwritten by the reference loop below.
        output = np.zeros(
            (batch_size, output_channels, output_height, output_width))

        # Naive direct convolution: each output element is the sum over all
        # input channels and filter taps of input * filter, where positions
        # that fall outside the input contribute zero.
        for batch in range(batch_size):
            for out_channel in range(output_channels):
                for row in range(output_height):
                    for col in range(output_width):
                        start_h = row * stride - padding
                        start_w = col * stride - padding
                        output_value = 0.0
                        for in_channel in range(input_channels):
                            for f_row in range(filter_height):
                                for f_col in range(filter_width):
                                    in_row = start_h + f_row
                                    in_col = start_w + f_col
                                    in_bounds = (
                                        0 <= in_row < input_height and
                                        0 <= in_col < input_width)
                                    input_value = 0.0
                                    if in_bounds:
                                        input_value = input_data[
                                            batch, in_channel, in_row, in_col]
                                    filter_value = filter_data[
                                        out_channel, in_channel, f_row, f_col]
                                    output_value += input_value * filter_value
                        output[batch, out_channel, row, col] = output_value

        self.inputs = {'Input': input_data, 'Filter': filter_data}
        self.outputs = {'Output': output}
        # Derived from the same values used for the reference computation so
        # the attributes and the expected output cannot drift apart.
        self.attrs = {
            'strides': [stride, stride],
            'paddings': [padding, padding],
        }

    def test_check_output(self):
        """Run the ``check_output`` helper inherited from ``OpTest``."""
        # NOTE(review): presumably compares the operator result with
        # self.outputs; OpTest is defined outside this file -- confirm there.
        self.check_output()

60

H
hedaoyuan 已提交
61
class TestConv2dGradOp(OpTest):
    """Gradient test of the ``conv2d`` operator.

    ``setUp`` draws random float32 input and filter arrays and registers a
    zero-filled output placeholder of the matching shape; the gradient check
    itself is performed by ``check_grad`` inherited from ``OpTest``.
    """

    def setUp(self):
        """Build random inputs and an output placeholder of the right shape."""
        batch_size = 2
        input_channels = 3
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0
        # Floor division keeps the output dimensions integral on both
        # Python 2 and Python 3; true division would yield floats, which
        # are not valid array shapes.
        output_height = (
            input_height - filter_height + 2 * padding) // stride + 1
        output_width = (
            input_width - filter_width + 2 * padding) // stride + 1

        # Local names avoid shadowing the ``input`` and ``filter`` builtins.
        input_data = np.random.random(
            (batch_size, input_channels, input_height,
             input_width)).astype("float32")
        filter_data = np.random.random(
            (output_channels, input_channels, filter_height,
             filter_width)).astype("float32")

        self.op_type = 'conv2d'
        self.inputs = {'Input': input_data, 'Filter': filter_data}
        # Zero-filled rather than uninitialized memory so the placeholder is
        # deterministic across runs.
        # NOTE(review): presumably only its shape matters to check_grad --
        # confirm against OpTest.
        output = np.zeros(
            (batch_size, output_channels, output_height, output_width))
        self.outputs = {'Output': output}
        # Derived from the same values used for the output shape so the
        # attributes and the shape cannot drift apart.
        self.attrs = {
            'strides': [stride, stride],
            'paddings': [padding, padding],
        }

    # TODO: a test_compare_grad case calling
    # self.compare_grad(self.op, self.inputs) was disabled here; the reason
    # is not recorded in this file.

    def test_check_grad(self):
        """Run ``check_grad`` for ``Input`` and ``Filter`` w.r.t. ``Output``."""
        self.check_grad({'Input', 'Filter'}, 'Output')
H
hedaoyuan 已提交
93 94


95 96
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()