test_block_expand_op.py 5.3 KB
Newer Older
G
gongweibao 已提交
1 2 3 4 5
import unittest
import numpy as np
from op_test import OpTest


G
gongweibao 已提交
6 7 8 9 10 11 12 13 14 15 16 17
def get_output_shape(attrs, x):
    """Compute how many block positions fit along each spatial axis.

    Args:
        attrs: dict providing 'blockHeight', 'blockWidth', 'strideHeight',
            'strideWidth', 'paddingHeight' and 'paddingWidth'.
        x: array whose ``shape[1]`` is the image height and ``shape[2]`` the
            image width (the caller in this file passes a {C, H, W} image).

    Returns:
        Tuple ``(outputHeight, outputWidth)``.
    """
    img_height = x.shape[1]
    img_width = x.shape[2]

    def _num_blocks(size, padding, block, stride):
        # The "+ stride - 1" term rounds the count up. Floor division (//)
        # keeps the result an int on Python 3; plain "/" would produce a
        # float there, which is wrong as a count and invalid as a numpy
        # dimension.
        return 1 + (size + 2 * padding - block + stride - 1) // stride

    output_height = _num_blocks(img_height, attrs['paddingHeight'],
                                attrs['blockHeight'], attrs['strideHeight'])
    output_width = _num_blocks(img_width, attrs['paddingWidth'],
                               attrs['blockWidth'], attrs['strideWidth'])
    return output_height, output_width
G
gongweibao 已提交
28 29


G
gongweibao 已提交
30
def im2col(attrs, im, col):
    """Unfold image blocks into ``col`` in place.

    im: {CHW}
    col:
        {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}

    Args:
        attrs: dict providing 'strideHeight', 'strideWidth', 'paddingHeight'
            and 'paddingWidth'. The block (filter) size is read from
            ``col.shape[3:5]``, and the output grid size from ``col.shape[0:2]``.
        im: source image, indexed as ``im[channel][row][column]``.
        col: preallocated destination. Every element is overwritten: it
            receives the matching image pixel, or 0.0 when the pixel lies in
            the padding region outside the image.
    """
    input_channels = im.shape[0]
    input_height = im.shape[1]
    input_width = im.shape[2]

    output_height = col.shape[0]
    output_width = col.shape[1]
    filter_height = col.shape[3]
    filter_width = col.shape[4]

    stride_height = attrs['strideHeight']
    stride_width = attrs['strideWidth']
    padding_height = attrs['paddingHeight']
    padding_width = attrs['paddingWidth']

    for col_row_idx in range(output_height):
        for col_col_idx in range(output_width):
            for channel in range(input_channels):
                for filter_row_idx in range(filter_height):
                    for filter_col_idx in range(filter_width):
                        # Image coordinates covered by this block element;
                        # the padding offset can make them negative.
                        im_row = (col_row_idx * stride_height
                                  + filter_row_idx - padding_height)
                        im_col = (col_col_idx * stride_width
                                  + filter_col_idx - padding_width)

                        inside = (0 <= im_row < input_height
                                  and 0 <= im_col < input_width)
                        # Chained indexing works for both numpy arrays and
                        # nested lists. Only read `im` when in bounds, so
                        # negative indices never wrap around.
                        col[col_row_idx][col_col_idx][channel][
                            filter_row_idx][filter_col_idx] = (
                                im[channel][im_row][im_col] if inside else 0.0)


def col2img(attrs, col, img):
    """Fold blocks from ``col`` back into ``img`` in place.

    img: {CHW}
    col:
        {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}

    Args:
        attrs: dict providing 'strideHeight', 'strideWidth', 'paddingHeight'
            and 'paddingWidth'. The block (filter) size is read from
            ``col.shape[3:5]``, and the output grid size from ``col.shape[0:2]``.
        col: block values to scatter back into the image.
        img: destination image, indexed as ``img[channel][row][column]``.

    Block elements that map into the padding region are ignored. Image pixels
    not covered by any block are left untouched. Each in-bounds pixel is
    *assigned*, so where blocks overlap the last block visited wins.

    NOTE(review): a gradient-style col2im sums overlapping contributions
    (``+=``) instead. Confirm which semantics callers need before using this
    as the adjoint of ``im2col``.
    """
    # The original body read ``im.shape``, but the parameter is ``img``,
    # which made every call fail with NameError.
    input_channels = img.shape[0]
    input_height = img.shape[1]
    input_width = img.shape[2]

    output_height = col.shape[0]
    output_width = col.shape[1]
    filter_height = col.shape[3]
    filter_width = col.shape[4]

    stride_height = attrs['strideHeight']
    stride_width = attrs['strideWidth']
    padding_height = attrs['paddingHeight']
    padding_width = attrs['paddingWidth']

    for col_row_idx in range(output_height):
        for col_col_idx in range(output_width):
            for channel in range(input_channels):
                for filter_row_idx in range(filter_height):
                    for filter_col_idx in range(filter_width):
                        im_row = (col_row_idx * stride_height
                                  + filter_row_idx - padding_height)
                        im_col = (col_col_idx * stride_width
                                  + filter_col_idx - padding_width)
                        # Skip padding positions; negative indices would
                        # otherwise silently wrap to the far edge.
                        if (0 <= im_row < input_height
                                and 0 <= im_col < input_width):
                            img[channel][im_row][im_col] = \
                                col[col_row_idx][col_col_idx][channel][
                                    filter_row_idx][filter_col_idx]


G
gongweibao 已提交
112 113 114 115 116 117 118 119 120 121
class TestBlockExpandOp(OpTest):
    """OpTest case for the ``block_expand`` operator.

    Builds a deterministic {C, H, W} input and computes the expected output
    with the NumPy reference ``im2col``. It then checks both the forward
    output and the gradient with respect to ``X``.
    """

    def get_input_data(self, C, H, W):
        """Return a float32 array of shape (C, H, W) with distinct values.

        Element (c, h, w) equals ``0.2 + 0.01 * (c * H * W + h * W + w)``,
        i.e. 0.2 plus 0.01 times its row-major flat index.
        """
        flat_index = np.arange(C * H * W)
        return (0.2 + 0.01 * flat_index).reshape(C, H, W).astype("float32")

    def setUp(self):
        C = 3
        H = 4
        W = 4
        x = self.get_input_data(C, H, W)

        attrs = {
            'blockHeight': 2,
            'blockWidth': 2,
            'strideHeight': 1,
            'strideWidth': 1,
            'paddingHeight': 1,
            'paddingWidth': 1,
        }

        output_height, output_width = get_output_shape(attrs, x)
        # Coerce to int so the shape is valid for NumPy regardless of how
        # the shape helper divides.
        output_height, output_width = int(output_height), int(output_width)
        block_shape = (attrs['blockHeight'], attrs['blockWidth'])

        # im2col overwrites every element, so the buffer only needs the
        # right shape and dtype; its initial contents are irrelevant.
        out = np.zeros((output_height, output_width, C) + block_shape,
                       dtype="float32")
        im2col(attrs, x, out)

        self.op_type = "block_expand"
        self.inputs = {'X': x.reshape(1, C, H, W)}
        self.attrs = attrs
        self.outputs = {
            'Out': out.reshape((1, output_height, output_width, C) + block_shape)
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')
G
gongweibao 已提交
158 159 160 161


# Run this module's unittest test cases when executed directly as a script.
if __name__ == '__main__':
    unittest.main()