# test_im2sequence_op.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


def get_output_shape(attrs, in_shape):
    """Return the (output_height, output_width) grid of block positions.

    Args:
        attrs: dict with int entries 'padding_height', 'padding_width',
            'block_height', 'block_width', 'stride_height' and
            'stride_width'.
        in_shape: input shape laid out as (batch, channels, height, width);
            only in_shape[2] (height) and in_shape[3] (width) are read.

    Each output dimension is ``1 + ceil((img + 2 * padding - block) / stride)``
    for a positive stride, i.e. a trailing partial stride still yields one
    more block position. The ceiling is computed with integer floor division
    (``(a + stride - 1) // stride``): plain ``/`` would return a float under
    Python 3, which is neither the intended block count nor usable as an
    array dimension or ``range`` bound.

    Returns:
        Tuple ``(output_height, output_width)``.
    """
    img_height = in_shape[2]
    img_width = in_shape[3]

    padding_height = attrs['padding_height']
    padding_width = attrs['padding_width']
    block_height = attrs['block_height']
    block_width = attrs['block_width']
    stride_height = attrs['stride_height']
    stride_width = attrs['stride_width']

    output_height = 1 + (img_height + 2 * padding_height - block_height +
                         stride_height - 1) // stride_height
    output_width = 1 + (img_width + 2 * padding_width - block_width +
                        stride_width - 1) // stride_width

    return output_height, output_width


def im2col(attrs, im, col):
    """Unfold one image into its sliding blocks, writing into ``col`` in place.

    Args:
        attrs: dict with int entries 'stride_height', 'stride_width',
            'padding_height' and 'padding_width'.
        im: array indexed as im[channel, row, col], i.e. {CHW}.
        col: pre-allocated array indexed as
            col[out_row, out_col, channel, filter_row, filter_col], i.e.
            {outputHeight, outputWidth, inputChannels, filterHeight,
            filterWidth}. Its shape fixes the output grid and the block size;
            every element covered by that shape is overwritten.

    Block cells that fall outside the image (above/left because of padding,
    or below/right past the last row/column) are written as 0.0.
    """
    input_channels = im.shape[0]
    input_height = im.shape[1]
    input_width = im.shape[2]

    output_height = col.shape[0]
    output_width = col.shape[1]
    filter_height = col.shape[3]
    filter_width = col.shape[4]

    stride_height = attrs['stride_height']
    stride_width = attrs['stride_width']
    padding_height = attrs['padding_height']
    padding_width = attrs['padding_width']

    for col_row_idx in range(output_height):
        for col_col_idx in range(output_width):
            for channel in range(input_channels):
                for filter_row_idx in range(filter_height):
                    im_row_offset = (col_row_idx * stride_height +
                                     filter_row_idx - padding_height)
                    # The row test does not depend on the filter column, so
                    # evaluate it once per filter row.
                    row_inside = 0 <= im_row_offset < input_height
                    for filter_col_idx in range(filter_width):
                        im_col_offset = (col_col_idx * stride_width +
                                         filter_col_idx - padding_width)
                        if row_inside and 0 <= im_col_offset < input_width:
                            value = im[channel, im_row_offset, im_col_offset]
                        else:
                            value = 0.0
                        col[col_row_idx, col_col_idx, channel,
                            filter_row_idx, filter_col_idx] = value


def Im2Sequence(inputs, attrs):
    """NumPy reference for the im2sequence op over a batch of images.

    Args:
        inputs: array of shape (batch_size, channels, height, width).
        attrs: dict with 'block_height', 'block_width', 'stride_height',
            'stride_width', 'padding_height' and 'padding_width'.

    Returns:
        float32 array of shape
        (batch_size * output_height * output_width,
         channels * block_height * block_width): one row per block position,
        ordered by image, then output row, then output column; each row holds
        the block's values ordered by channel, block row, block column.
    """
    batch_size = inputs.shape[0]
    img_channels = inputs.shape[1]
    # int() keeps np.zeros happy even if get_output_shape hands back float
    # dims (as true division does under Python 3).
    output_height, output_width = (
        int(dim) for dim in get_output_shape(attrs, inputs.shape))
    block_height = attrs['block_height']
    block_width = attrs['block_width']

    out = np.zeros(
        [batch_size, output_height, output_width, img_channels,
         block_height, block_width],
        dtype="float32")

    # Iterating `out` yields views, so im2col fills `out` in place.
    for image, image_blocks in zip(inputs, out):
        im2col(attrs, image, image_blocks)

    return out.reshape([
        batch_size * output_height * output_width,
        img_channels * block_height * block_width
    ])


class TestBlockExpandOp(OpTest):
    """Runs the "im2sequence" op and compares it with the NumPy reference.

    Subclasses override ``config`` to vary the input geometry and attrs; the
    expected 'Out' is always produced by ``Im2Sequence``.
    """

    def config(self):
        # One 3-channel 4x4 image.
        self.batch_size, self.img_channels = 1, 3
        self.img_height, self.img_width = 4, 4
        # 2x2 blocks, stride 1, padding 1 in both directions.
        self.attrs = dict(
            block_height=2,
            block_width=2,
            stride_height=1,
            stride_width=1,
            padding_height=1,
            padding_width=1,
        )

    def setUp(self):
        self.config()
        self.op_type = "im2sequence"
        input_shape = [
            self.batch_size, self.img_channels, self.img_height,
            self.img_width
        ]
        images = np.random.uniform(0.1, 1, input_shape).astype("float32")
        expected = Im2Sequence(images, self.attrs)
        self.inputs = {'X': images}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')


class TestBlockExpandOpCase2(TestBlockExpandOp):
    """Two 3-channel 4x5 images, 2x1 blocks, stride 2 along the height."""

    def config(self):
        self.batch_size, self.img_channels = 2, 3
        self.img_height, self.img_width = 4, 5
        self.attrs = dict(
            block_height=2,
            block_width=1,
            stride_height=2,
            stride_width=1,
            padding_height=2,
            padding_width=1,
        )


class TestBlockExpandOpCase3(TestBlockExpandOp):
    """Three single-channel 4x5 images; padding only along the height."""

    def config(self):
        self.batch_size, self.img_channels = 3, 1
        self.img_height, self.img_width = 4, 5
        self.attrs = dict(
            block_height=2,
            block_width=1,
            stride_height=2,
            stride_width=1,
            padding_height=2,
            padding_width=0,
        )


class TestBlockExpandOpCase4(TestBlockExpandOp):
    """Two 2-channel 3x3 images, 2x2 blocks, stride 1, no padding."""

    def config(self):
        self.batch_size, self.img_channels = 2, 2
        self.img_height, self.img_width = 3, 3
        self.attrs = dict(
            block_height=2,
            block_width=2,
            stride_height=1,
            stride_width=1,
            padding_height=0,
            padding_width=0,
        )


if __name__ == '__main__':
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()