# test_conv2d_transpose_op.py: unit tests for the conv2d_transpose operator.
import unittest
import numpy as np
from op_test import OpTest


def conv2dtranspose_forward_naive(input_, filter_, attrs):
    """Naive numpy reference for the forward pass of a 2-D transposed conv.

    Every input pixel ``input_[n, :, i, j]`` is contracted with the filter
    over the input-channel axis, and the resulting ``(C_out, K_h, K_w)`` patch
    is scatter-added into the output starting at row ``i * stride_h`` and
    column ``j * stride_w``. Successive kernel taps are placed ``dilation``
    apart. Finally ``paddings`` rows/columns are cropped from every border.

    Args:
        input_: array of shape ``(N, C_in, H_in, W_in)``.
        filter_: array of shape ``(C_in, C_out, K_h, K_w)``.
        attrs: dict with keys ``'strides'``, ``'paddings'`` and
            ``'dilations'``, each an ``[h, w]`` pair.

    Returns:
        float64 array of shape ``(N, C_out, H_out, W_out)`` with
        ``H_out = (H_in - 1) * stride_h + dilation_h * (K_h - 1) + 1
        - 2 * pad_h`` (and the analogous formula for the width).

    Raises:
        AssertionError: if the input and filter channel counts differ.
    """
    in_n, in_c, in_h, in_w = input_.shape
    f_c, out_c, f_h, f_w = filter_.shape
    assert in_c == f_c

    stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
        'dilations']
    # Spatial extent covered by one dilated kernel.
    d_block_h = dilations[0] * (f_h - 1) + 1
    d_block_w = dilations[1] * (f_w - 1) + 1
    out_h = (in_h - 1) * stride[0] + d_block_h
    out_w = (in_w - 1) * stride[1] + d_block_w

    out = np.zeros((in_n, out_c, out_h, out_w))

    for n in range(in_n):
        for i in range(in_h):
            for j in range(in_w):
                # (C_in,) x (C_in, C_out, K_h, K_w) -> (C_out, K_h, K_w)
                patch = np.tensordot(input_[n, :, i, j], filter_, axes=(0, 0))
                i1, i2 = i * stride[0], i * stride[0] + d_block_h
                # The width axis must use the width stride and extent.
                j1, j2 = j * stride[1], j * stride[1] + d_block_w
                out[n, :, i1:i2:dilations[0], j1:j2:dilations[1]] += patch

    # Padding of a transposed convolution crops the full output.
    return out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]]


class TestConv2dTransposeOp(OpTest):
    """Checks the conv2d_transpose operator against the numpy reference.

    Subclasses override ``init_test_case`` to change the geometry (padding,
    stride, dilation, shapes) and ``init_op_type`` to select another
    ``op_type``.
    """

    def setUp(self):
        # init as conv transpose
        self.init_op_type()
        self.init_test_case()

        input_ = np.random.random(self.input_size).astype("float32")
        filter_ = np.random.random(self.filter_size).astype("float32")

        self.inputs = {'Input': input_, 'Filter': filter_}
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            'dilations': self.dilations
        }

        # Expected result, computed by the naive reference implementation.
        output = conv2dtranspose_forward_naive(input_, filter_,
                                               self.attrs).astype('float32')

        self.outputs = {'Output': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_no_input(self):
        # Gradient with respect to Filter only; Input is excluded.
        self.check_grad(
            ['Filter'],
            'Output',
            max_relative_error=0.02,
            no_grad_set={'Input'})

    def test_check_grad_no_filter(self):
        # Gradient with respect to Input only; Filter is excluded.
        self.check_grad(
            ['Input'],
            'Output',
            max_relative_error=0.02,
            no_grad_set={'Filter'})

    def test_check_grad(self):
        # Ordered list (like the sibling tests) rather than an unordered set.
        self.check_grad(
            ['Input', 'Filter'], 'Output', max_relative_error=0.02)

    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.dilations = [1, 1]
        self.input_size = [2, 3, 5, 5]  # NCHW
        f_c = self.input_size[1]
        self.filter_size = [f_c, 6, 3, 3]

    def init_op_type(self):
        self.op_type = "conv2d_transpose"


class TestWithPad(TestConv2dTransposeOp):
    """Base geometry with a padding of 1 on both spatial axes."""

    def init_test_case(self):
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.filter_size = [self.input_size[1], 6, 3, 3]
        self.stride = [1, 1]
        self.dilations = [1, 1]
        self.pad = [1, 1]


class TestWithStride(TestConv2dTransposeOp):
    """Padding of 1 combined with a stride of 2 on both spatial axes."""

    def init_test_case(self):
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.filter_size = [self.input_size[1], 6, 3, 3]
        self.dilations = [1, 1]
        self.stride = [2, 2]
        self.pad = [1, 1]


class TestWithDilation(TestConv2dTransposeOp):
    """Padding of 1 combined with a dilation of 2 on both spatial axes."""

    def init_test_case(self):
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.filter_size = [self.input_size[1], 6, 3, 3]
        self.stride = [1, 1]
        self.dilations = [2, 2]
        self.pad = [1, 1]


# ------------ test_cudnn ------------
class TestCudnn(TestConv2dTransposeOp):
    """Base test case run with op_type "conv2d_transpose_cudnn"."""

    def init_op_type(self):
        self.op_type = "conv2d_transpose_cudnn"


class TestCudnnWithPad(TestWithPad):
    """Padded geometry run with op_type "conv2d_transpose_cudnn"."""

    def init_op_type(self):
        self.op_type = "conv2d_transpose_cudnn"

    def init_test_case(self):
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.filter_size = [self.input_size[1], 6, 3, 3]
        self.stride = [1, 1]
        self.dilations = [1, 1]
        self.pad = [1, 1]


class TestCudnnWithStride(TestWithStride):
    """Strided, padded geometry run with op_type "conv2d_transpose_cudnn"."""

    def init_op_type(self):
        self.op_type = "conv2d_transpose_cudnn"

    def init_test_case(self):
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.filter_size = [self.input_size[1], 6, 3, 3]
        self.dilations = [1, 1]
        self.stride = [2, 2]
        self.pad = [1, 1]


# cuDNN v5 does not support dilated convolution, so this case is disabled.
# class TestCudnnWithDilation(TestWithDilation):
#     def init_test_case(self):
#         self.pad = [1, 1]
#         self.stride = [2, 2]
#         self.dilations = [2, 2]
#         self.input_size = [2, 3, 5, 5]  # NCHW
#         f_c = self.input_size[1]
#         self.filter_size = [f_c, 6, 3, 3]
#
#     def init_op_type(self):
#         self.op_type = "conv2d_transpose_cudnn"

if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script.
    unittest.main()