# test_nearest_interp_op.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


def nearest_neighbor_interp_np(X,
                               out_h,
                               out_w,
                               out_size=None,
                               actual_shape=None,
                               align_corners=True):
    """Nearest-neighbor interpolation reference for an [N, C, H, W] array.

    Args:
        X: input array of shape [N, C, H, W].
        out_h: output height, used unless ``out_size``/``actual_shape`` is set.
        out_w: output width, used unless ``out_size``/``actual_shape`` is set.
        out_size: optional (h, w) pair overriding ``out_h``/``out_w``.
        actual_shape: optional (h, w) pair overriding both ``out_size`` and
            ``out_h``/``out_w``.
        align_corners: when True, source index = int(ratio * k + 0.5) with
            ratio = (in - 1) / (out - 1); when False, source index =
            int(ratio * k) with ratio = in / out.

    Returns:
        Array of shape [N, C, out_h, out_w] with the same dtype as ``X``.
    """
    if out_size is not None:
        out_h = out_size[0]
        out_w = out_size[1]
    if actual_shape is not None:
        out_h = actual_shape[0]
        out_w = actual_shape[1]

    _, _, in_h, in_w = X.shape

    def _source_indices(in_len, out_len):
        # Map every output coordinate along one axis to its source coordinate.
        ratio = 0.0  # with a single output sample, source index 0 is read
        if out_len > 1:
            if align_corners:
                ratio = (in_len - 1.0) / (out_len - 1.0)
            else:
                ratio = 1.0 * in_len / out_len
        # align_corners rounds half up; otherwise the index is truncated.
        offset = 0.5 if align_corners else 0.0
        return np.array([int(ratio * k + offset) for k in range(out_len)],
                        dtype=np.int64)

    rows = _source_indices(in_h, out_h)
    cols = _source_indices(in_w, out_w)
    # Gather the selected rows, then the selected columns, in one shot
    # instead of copying pixel by pixel; astype() returns a fresh array.
    out = np.take(np.take(X, rows, axis=2), cols, axis=3)
    return out.astype(X.dtype)


class TestNearestInterpOp(OpTest):
    """Compare the ``nearest_interp`` op with ``nearest_neighbor_interp_np``.

    Subclasses override ``init_test_case`` to pick ``input_shape``, the
    target size (``out_h``/``out_w``, ``scale``, ``out_size`` or
    ``actual_shape``) and ``align_corners``.
    """

    def setUp(self):
        # Optional size overrides; init_test_case() may set either one.
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "nearest_interp"
        input_np = np.random.random(self.input_shape).astype("float32")

        # A positive scale derives the reference size from the input H/W;
        # otherwise the explicit out_h/out_w values are used.
        if self.scale > 0:
            out_h = int(self.input_shape[2] * self.scale)
            out_w = int(self.input_shape[3] * self.scale)
        else:
            out_h = self.out_h
            out_w = self.out_w

        output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
                                               self.out_size, self.actual_shape,
                                               self.align_corners)
        self.inputs = {'X': input_np}
        if self.out_size is not None:
            self.inputs['OutSize'] = self.out_size
        # actual_shape is also fed through OutSize and, as in the reference,
        # wins over out_size when both are set.
        if self.actual_shape is not None:
            self.inputs['OutSize'] = self.actual_shape
        self.attrs = {
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners,
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', in_place=True)

    def init_test_case(self):
        # Base case: out_size (3x3) overrides out_h/out_w (2x2) on a 4x4 input.
        self.interp_method = 'nearest'
        self.input_shape = [2, 3, 4, 4]
        self.out_h = 2
        self.out_w = 2
        self.scale = 0.
        self.out_size = np.array([3, 3]).astype("int32")
        self.align_corners = True


class TestNearestNeighborInterpCase1(TestNearestInterpOp):
    """Shrink a 7x8 input to a single output pixel (out_h = out_w = 1)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [4, 1, 7, 8]
        self.out_h = 1
        self.out_w = 1
        self.scale = 0.
        self.align_corners = True


class TestNearestNeighborInterpCase2(TestNearestInterpOp):
    """Upsample a 9x6 input to 12x12."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 3, 9, 6]
        self.out_h = 12
        self.out_w = 12
        self.scale = 0.
        self.align_corners = True


class TestNearestNeighborInterpCase3(TestNearestInterpOp):
    """Resize a 128x64 input to 64x128 (shrink height, grow width)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [1, 1, 128, 64]
        self.out_h = 64
        self.out_w = 128
        self.scale = 0.
        self.align_corners = True


class TestNearestNeighborInterpCase4(TestNearestInterpOp):
    """Like Case1, but ``out_size`` (2x2) overrides ``out_h``/``out_w`` (1x1)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [4, 1, 7, 8]
        self.out_h = 1
        self.out_w = 1
        self.scale = 0.
        self.out_size = np.array([2, 2]).astype("int32")
        self.align_corners = True


class TestNearestNeighborInterpCase5(TestNearestInterpOp):
    """9x6 input; ``out_size`` (11x11) overrides ``out_h``/``out_w`` (12x12)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 3, 9, 6]
        self.out_h = 12
        self.out_w = 12
        self.scale = 0.
        self.out_size = np.array([11, 11]).astype("int32")
        self.align_corners = True


class TestNearestNeighborInterpCase6(TestNearestInterpOp):
    """128x64 input; ``out_size`` (65x129) overrides ``out_h``/``out_w`` (64x128)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [1, 1, 128, 64]
        self.out_h = 64
        self.out_w = 128
        self.scale = 0.
        self.out_size = np.array([65, 129]).astype("int32")
        self.align_corners = True


class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
    """Target size (66x40) supplied via ``actual_shape``, overriding out_h/out_w."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 64
        self.out_w = 32
        self.scale = 0.
        # Use actual_shape (not out_size) so the actual_shape branches of
        # setUp and the NumPy reference are exercised; the op still receives
        # the same OutSize input.
        self.actual_shape = np.array([66, 40]).astype("int32")
        self.align_corners = True


class TestNearestInterpOpUint8(OpTest):
    """Forward-only check of ``nearest_interp`` on uint8 input.

    The output is compared on CPUPlace with ``atol=1``; no gradient test is
    defined for this dtype.
    """

    def setUp(self):
        # Optional size overrides; init_test_case() may set either one.
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "nearest_interp"
        input_np = np.random.randint(
            low=0, high=256, size=self.input_shape).astype("uint8")

        # A positive scale derives the reference size from the input H/W;
        # otherwise the explicit out_h/out_w values are used.
        if self.scale > 0:
            out_h = int(self.input_shape[2] * self.scale)
            out_w = int(self.input_shape[3] * self.scale)
        else:
            out_h = self.out_h
            out_w = self.out_w

        output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
                                               self.out_size, self.actual_shape,
                                               self.align_corners)
        self.inputs = {'X': input_np}
        if self.out_size is not None:
            self.inputs['OutSize'] = self.out_size
        # The reference honours actual_shape, so the op must receive it too
        # (via OutSize, taking precedence over out_size).
        if self.actual_shape is not None:
            self.inputs['OutSize'] = self.actual_shape
        self.attrs = {
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output_with_place(place=core.CPUPlace(), atol=1)

    def init_test_case(self):
        # Base uint8 case: upsample a 9x6 input to 10x9.
        self.interp_method = 'nearest'
        self.input_shape = [1, 3, 9, 6]
        self.out_h = 10
        self.out_w = 9
        self.scale = 0.
        self.align_corners = True


class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
    """uint8 input: shrink 128x64 to 120x50."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [2, 3, 128, 64]
        self.out_h = 120
        self.out_w = 50
        self.scale = 0.
        self.align_corners = True


class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
    """uint8 input; ``out_size`` (6x15) overrides ``out_h``/``out_w`` (5x13)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [4, 1, 7, 8]
        self.out_h = 5
        self.out_w = 13
        self.scale = 0.
        self.out_size = np.array([6, 15]).astype("int32")
        self.align_corners = True


class TestNearestInterpWithoutCorners(TestNearestInterpOp):
    """Re-run the base case with ``align_corners`` disabled."""

    def init_test_case(self):
        super(TestNearestInterpWithoutCorners, self).init_test_case()
        # setUp() only calls init_test_case(), so set_align_corners() has to
        # be invoked from here; otherwise align_corners stays True and the
        # align_corners=False path is never exercised.
        self.set_align_corners()

    def set_align_corners(self):
        self.align_corners = False


class TestNearestNeighborInterpScale1(TestNearestInterpOp):
    """Output size derived from ``scale`` = 2 (32x16 -> 64x32)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 64
        self.out_w = 32
        self.scale = 2.
        # No out_size: the NumPy reference gives out_size precedence over the
        # scale-derived out_h/out_w, so setting it would leave scale untested.
        self.align_corners = True


class TestNearestNeighborInterpScale2(TestNearestInterpOp):
    """Output size derived from ``scale`` = 1.5 (32x16 -> 48x24)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 64
        self.out_w = 32
        self.scale = 1.5
        # No out_size: the NumPy reference gives out_size precedence over the
        # scale-derived out_h/out_w, so setting it would leave scale untested.
        self.align_corners = True


class TestNearestNeighborInterpScale3(TestNearestInterpOp):
    """Identity resize via ``scale`` = 1 (32x16 -> 32x16)."""

    def init_test_case(self):
        self.interp_method = 'nearest'
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 64
        self.out_w = 32
        self.scale = 1.
        # No out_size: the NumPy reference gives out_size precedence over the
        # scale-derived out_h/out_w, so setting it would leave scale untested.
        self.align_corners = True


if __name__ == "__main__":
    # Run every TestCase defined in this module.
    unittest.main()