# test_bilinear_interp_op.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


def _bilinear_axis_coords(in_len, out_len, align_corners, align_flag):
    """Compute per-output-position sampling data along one spatial axis.

    Args:
        in_len: input size along the axis.
        out_len: output size along the axis.
        align_corners: if true, the scale ratio is ``(in-1)/(out-1)``;
            otherwise it is ``in/out``. The ratio is 0 when ``out_len <= 1``.
        align_flag: if true, the source coordinate of output position ``k``
            is ``ratio * (k + 0.5) - 0.5`` (half-pixel mapping); otherwise
            it is ``ratio * k``.

    Returns:
        A list with one ``(idx, offset, lam)`` tuple per output position:
        ``idx`` is the lower input index (clamped to >= 0), ``offset`` is 1
        when a neighbour ``idx + 1`` exists inside the input and 0 otherwise,
        and ``lam`` is the weight given to that neighbour.
    """
    ratio = 0.0
    if out_len > 1:
        if align_corners:
            ratio = (in_len - 1.0) / (out_len - 1.0)
        else:
            ratio = 1.0 * in_len / out_len

    coords = []
    for k in range(out_len):
        if align_flag:
            src = ratio * (k + 0.5) - 0.5
        else:
            src = ratio * k
        idx = max(0, int(src))
        offset = 1 if idx < in_len - 1 else 0
        # NOTE(review): in the half-pixel mapping ``src`` can be negative
        # near the border; ``idx`` is clamped to 0 but ``src`` is not, so
        # ``lam`` becomes negative and the sample extrapolates. This is kept
        # as-is because it presumably mirrors the operator kernel under
        # test -- confirm against the kernel before changing it.
        coords.append((idx, offset, src - idx))
    return coords


def bilinear_interp_np(input,
                       out_h,
                       out_w,
                       out_size=None,
                       actual_shape=None,
                       align_corners=True,
                       align_mode=0):
    """Reference bilinear interpolation of an array laid out as [N, C, H, W].

    Args:
        input: 4-D array of shape [N, C, H, W].
        out_h: target output height, unless overridden below.
        out_w: target output width, unless overridden below.
        out_size: optional ``(h, w)`` pair that overrides out_h/out_w.
        actual_shape: optional ``(h, w)`` pair that overrides both
            out_h/out_w and out_size.
        align_corners: align the corner samples of input and output.
        align_mode: when ``align_corners`` is false, 0 selects the
            half-pixel source mapping ``ratio * (k + 0.5) - 0.5``; any other
            value selects ``ratio * k``.

    Returns:
        Array of shape [N, C, out_h, out_w] cast back to ``input.dtype``
        with a plain ``astype`` (integer outputs are truncated, not rounded).
    """
    if out_size is not None:
        out_h = out_size[0]
        out_w = out_size[1]
    if actual_shape is not None:
        out_h = actual_shape[0]
        out_w = actual_shape[1]

    batch_size, channel, in_h, in_w = input.shape

    align_flag = align_mode == 0 and not align_corners
    rows = _bilinear_axis_coords(in_h, out_h, align_corners, align_flag)
    # Column sampling data depends only on j, so compute it once rather
    # than once per output row.
    cols = _bilinear_axis_coords(in_w, out_w, align_corners, align_flag)

    out = np.zeros((batch_size, channel, out_h, out_w))
    for i, (h, hid, h1lambda) in enumerate(rows):
        h2lambda = 1.0 - h1lambda
        for j, (w, wid, w1lambda) in enumerate(cols):
            w2lambda = 1.0 - w1lambda
            out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
                                        w1lambda*input[:, :, h, w+wid]) + \
                h1lambda*(w2lambda*input[:, :, h+hid, w] +
                          w1lambda*input[:, :, h+hid, w+wid])
    return out.astype(input.dtype)


class TestBilinearInterpOp(OpTest):
    """Compare the ``bilinear_interp`` op with ``bilinear_interp_np``.

    Each case is configured by ``init_test_case``. It sets ``interp_method``,
    ``input_shape``, ``out_h``, ``out_w``, ``scale``, ``align_corners`` and
    ``align_mode``, and may also set ``out_size`` or ``actual_shape``.
    """

    def setUp(self):
        # Defaults that init_test_case may override.
        self.test_gc = True
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "bilinear_interp"
        x = np.random.random(self.input_shape).astype("float32")

        # A positive scale derives the target size from the input's spatial
        # dims; otherwise the explicit out_h / out_w attributes are used.
        if self.scale > 0:
            target_h, target_w = (int(self.input_shape[2] * self.scale),
                                  int(self.input_shape[3] * self.scale))
        else:
            target_h, target_w = self.out_h, self.out_w

        expected = bilinear_interp_np(x, target_h, target_w, self.out_size,
                                      self.actual_shape, self.align_corners,
                                      self.align_mode)

        # out_size and actual_shape are both fed through the OutSize input;
        # when both are set, actual_shape is written last and wins.
        inputs = {'X': x}
        for size in (self.out_size, self.actual_shape):
            if size is not None:
                inputs['OutSize'] = size
        self.inputs = inputs

        self.attrs = {
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners,
            'align_mode': self.align_mode
        }
        self.outputs = {'Out': expected}

    def test_check_output(self):
        """Forward output must match the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Gradient check of Out with respect to X."""
        self.check_grad(['X'], 'Out', in_place=True)

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 4, 4]
        self.out_h, self.out_w = 2, 2
        self.scale = 0.
        # out_size overrides out_h/out_w, so the expected output is 3x3.
        self.out_size = np.array([3, 3]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase1(TestBilinearInterpOp):
    """Shrink a 7x8 map down to a single output pixel."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [4, 1, 7, 8]
        self.out_h, self.out_w = 1, 1
        self.scale = 0.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase2(TestBilinearInterpOp):
    """Upsample a 9x6 map to 12x12."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [3, 3, 9, 6]
        self.out_h, self.out_w = 12, 12
        self.scale = 0.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase3(TestBilinearInterpOp):
    """Halve the height and double the width of a 128x64 map."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [1, 1, 128, 64]
        self.out_h, self.out_w = 64, 128
        self.scale = 0.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase4(TestBilinearInterpOp):
    """out_size = 2x2 overrides the 1x1 out_h/out_w attributes."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [4, 1, 7, 8]
        self.out_h, self.out_w = 1, 1
        self.scale = 0.
        self.out_size = np.array([2, 2]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase5(TestBilinearInterpOp):
    """out_size = 11x11 overrides the 12x12 out_h/out_w attributes."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [3, 3, 9, 6]
        self.out_h, self.out_w = 12, 12
        self.scale = 0.
        self.out_size = np.array([11, 11]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase6(TestBilinearInterpOp):
    """out_size = 65x129 overrides the 64x128 out_h/out_w attributes."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [1, 1, 128, 64]
        self.out_h, self.out_w = 64, 128
        self.scale = 0.
        self.out_size = np.array([65, 129]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpActualShape(TestBilinearInterpOp):
    """Upsample 32x16 with a runtime target size of 66x40.

    Despite the name, this case sets ``out_size`` rather than
    ``actual_shape``; the base ``setUp`` feeds either one to the op's
    ``OutSize`` input.
    """

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [3, 2, 32, 16]
        self.out_h, self.out_w = 64, 32
        self.scale = 0.
        self.out_size = np.array([66, 40]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpOpUint8(OpTest):
    """Compare ``bilinear_interp`` with the NumPy reference on uint8 input.

    Only the forward pass is checked, on CPU, with an absolute tolerance of 1.
    """

    def setUp(self):
        # Defaults that init_test_case may override.
        self.test_gc = True
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "bilinear_interp"
        input_np = np.random.randint(
            low=0, high=256, size=self.input_shape).astype("uint8")

        # A positive scale derives the target size from the input's spatial
        # dims; otherwise the explicit out_h / out_w attributes are used.
        if self.scale > 0:
            out_h = int(self.input_shape[2] * self.scale)
            out_w = int(self.input_shape[3] * self.scale)
        else:
            out_h = self.out_h
            out_w = self.out_w

        output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
                                       self.actual_shape, self.align_corners,
                                       self.align_mode)
        self.inputs = {'X': input_np}
        if self.out_size is not None:
            self.inputs['OutSize'] = self.out_size
        # The reference output above honours actual_shape, so the op must
        # receive it as well (same as TestBilinearInterpOp.setUp); otherwise
        # the op and the reference would produce different shapes.
        if self.actual_shape is not None:
            self.inputs['OutSize'] = self.actual_shape

        self.attrs = {
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners,
            'align_mode': self.align_mode
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        # NOTE(review): atol=1 presumably absorbs off-by-one differences from
        # the reference's truncating uint8 cast -- confirm against the kernel.
        self.check_output_with_place(place=core.CPUPlace(), atol=1)

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [1, 3, 9, 6]
        self.out_h = 10
        self.out_w = 9
        self.scale = 0.
        self.align_corners = True
        self.align_mode = 1


class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8):
    """Downsample a uint8 128x64 map to 120x50."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 128, 64]
        self.out_h, self.out_w = 120, 50
        self.scale = 0.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8):
    """uint8 case where out_size = 6x15 overrides the 5x13 attributes."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [4, 1, 7, 8]
        self.out_h, self.out_w = 5, 13
        self.scale = 0.
        self.out_size = np.array([6, 15]).astype("int32")
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpOtherMethod1(TestBilinearInterpOp):
    """Base case re-run with align_corners=False and align_mode=1."""

    def init_test_case(self):
        # The base setUp never calls set_align_mode, so apply it here after
        # the base configuration. Without this, the class silently re-runs
        # the base case with align_corners=True.
        super(TestBilinearInterpOtherMethod1, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        self.align_corners = False
        self.align_mode = 1


class TestBilinearInterpWithMethod2(TestBilinearInterpOp):
    """Base case re-run with align_corners=False and align_mode=0."""

    def init_test_case(self):
        # The base setUp never calls set_align_mode, so apply it here after
        # the base configuration. Without this, the half-pixel path
        # (align_mode=0, align_corners=False) is never exercised.
        super(TestBilinearInterpWithMethod2, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        self.align_corners = False
        self.align_mode = 0


class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
    """Base case re-run with align_corners=True and align_mode=0."""

    def init_test_case(self):
        # The base setUp never calls set_align_mode, so apply it here after
        # the base configuration. Without this, align_mode keeps the base
        # value of 1.
        super(TestBilinearInterpWithMethod3, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        self.align_corners = True
        self.align_mode = 0


class TestBilinearInterpScale1(TestBilinearInterpOp):
    """scale = 2 on a 16x32 map; since scale > 0, out_h/out_w are unused
    by the reference and the expected output is 32x64."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 16, 32]
        self.out_h, self.out_w = 60, 25
        self.scale = 2.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpScale2(TestBilinearInterpOp):
    """scale = 1 on a 16x32 map; the expected output keeps the input size."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 16, 32]
        self.out_h, self.out_w = 60, 25
        self.scale = 1.
        self.align_corners, self.align_mode = True, 1


class TestBilinearInterpScale3(TestBilinearInterpOp):
    """scale = 1.5 on a 16x32 map; the expected output is 24x48."""

    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 16, 32]
        self.out_h, self.out_w = 60, 25
        self.scale = 1.5
        self.align_corners, self.align_mode = True, 1


if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()