#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import unittest

import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci

paddle.enable_static()

from white_list import (
    op_accuracy_white_list,
    check_shape_white_list,
    compile_vs_runtime_white_list,
    no_check_set_white_list,
    op_threshold_white_list,
    no_grad_set_white_list,
)

D
dengkaipeng 已提交
32

33 34 35 36
def AffineGrid(theta, grid_shape):
    """Build the NumPy reference 2-D affine sampling grid.

    A base grid of homogeneous points ``(x, y, 1)`` is created with ``x``
    evenly spaced over ``[-1, 1]`` along the width axis and ``y`` evenly
    spaced over ``[-1, 1]`` along the height axis. Each batch item's points
    are then multiplied by the transpose of that item's ``theta`` matrix.

    Args:
        theta: ndarray of shape ``(n, 2, 3)``, one affine matrix per batch
            item.
        grid_shape: sequence whose first three entries are ``(n, h, w)``;
            any further entries are not read.

    Returns:
        ``float64`` ndarray of shape ``(n, h, w, 2)``; the last axis holds
        the transformed ``(x, y)`` pair.
    """
    n, h, w = grid_shape[0], grid_shape[1], grid_shape[2]

    # indexing="ij" keeps the (h, w) layout: h_idx varies along axis 0,
    # w_idx along axis 1.
    h_idx, w_idx = np.meshgrid(np.linspace(-1, 1, h),
                               np.linspace(-1, 1, w),
                               indexing="ij")
    base = np.stack([w_idx, h_idx, np.ones([h, w])], axis=-1)  # h * w * 3
    base = base.reshape([h * w, 3])

    ret = np.zeros([n, h * w, 2])
    for i, mat in enumerate(theta):
        # (h*w, 3) x (3, 2) -> (h*w, 2)
        ret[i] = np.dot(base, mat.T)

    return ret.reshape([n, h, w, 2]).astype("float64")
D
dengkaipeng 已提交
51

52

D
dengkaipeng 已提交
53 54 55
def getGridPointValue(data, x, y):
    """Gather ``data`` at integer pixel locations, zero-filling misses.

    Args:
        data: ndarray of shape ``(N, C, in_H, in_W)``.
        x: integer ndarray of shape ``(N, out_H, out_W)`` with column
            indices into the ``in_W`` axis.
        y: integer ndarray of the same shape with row indices into the
            ``in_H`` axis.

    Returns:
        ``float64`` ndarray of shape ``(N, C, out_H, out_W)``. Position
        ``[i, :, j, k]`` holds ``data[i, :, y[i, j, k], x[i, j, k]]`` when
        both indices lie inside the image and ``0`` otherwise.
    """
    N, _, in_H, in_W = data.shape

    inside = (x >= 0) & (x <= in_W - 1) & (y >= 0) & (y <= in_H - 1)

    # Out-of-range indices are redirected to pixel (0, 0) so the gather is
    # always legal; whatever is read there is discarded by the mask below.
    safe_x = np.where(inside, x, 0)
    safe_y = np.where(inside, y, 0)

    # Advanced indices separated by a slice put the broadcast index
    # dimensions first, so the gather yields (N, out_H, out_W, C).
    batch = np.arange(N).reshape(N, 1, 1)
    picked = data[batch, :, safe_y, safe_x]
    picked = np.transpose(picked, (0, 3, 1, 2))

    out = np.where(inside[:, np.newaxis], picked, 0)
    return out.astype('float64')

75

76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
def AffineGrid3D(theta, grid_shape):
    """Build the NumPy reference 3-D affine sampling grid.

    Homogeneous points ``(x, y, z, 1)`` are laid out with each coordinate
    evenly spaced over ``[-1, 1]`` along the width, height and depth axes
    respectively, then multiplied by the transpose of the matching batch
    item's ``theta`` matrix.

    Args:
        theta: ndarray of shape ``(n, 3, 4)``, one affine matrix per batch
            item.
        grid_shape: sequence whose first four entries are ``(n, d, h, w)``.

    Returns:
        ``float64`` ndarray of shape ``(n, d, h, w, 3)``; the last axis
        holds the transformed ``(x, y, z)`` triple.
    """
    n, d, h, w = grid_shape[0], grid_shape[1], grid_shape[2], grid_shape[3]

    # "ij" indexing keeps the (d, h, w) axis order of the output grid.
    d_idx, h_idx, w_idx = np.meshgrid(np.linspace(-1, 1, d),
                                      np.linspace(-1, 1, h),
                                      np.linspace(-1, 1, w),
                                      indexing="ij")
    homogeneous = np.stack([w_idx, h_idx, d_idx,
                            np.ones([d, h, w])], axis=-1)  # d * h * w * 4
    flat_points = homogeneous.reshape([d * h * w, 4])

    ret = np.zeros([n, d * h * w, 3])
    theta_t = theta.transpose([0, 2, 1])
    for i in range(len(theta_t)):
        # Every batch item shares the same base points.
        ret[i] = np.dot(flat_points, theta_t[i])

    return ret.reshape([n, d, h, w, 3]).astype("float64")


def getGridPointValue3D(data, x, y, z):
    """Gather ``data`` at integer voxel locations, zero-filling misses.

    Args:
        data: ndarray of shape ``(N, C, in_D, in_H, in_W)``.
        x: integer ndarray of shape ``(N, out_D, out_H, out_W)`` with
            indices into the ``in_W`` axis.
        y: integer ndarray of the same shape with indices into ``in_H``.
        z: integer ndarray of the same shape with indices into ``in_D``.

    Returns:
        ``float64`` ndarray of shape ``(N, C, out_D, out_H, out_W)``.
        Position ``[i, :, j, k, l]`` holds
        ``data[i, :, z[i, j, k, l], y[i, j, k, l], x[i, j, k, l]]`` when all
        three indices lie inside the volume and ``0`` otherwise.
    """
    N, _, in_D, in_H, in_W = data.shape

    inside = ((x >= 0) & (x <= in_W - 1) &
              (y >= 0) & (y <= in_H - 1) &
              (z >= 0) & (z <= in_D - 1))

    # Out-of-range indices are redirected to voxel (0, 0, 0) so the gather
    # is always legal; whatever is read there is discarded by the mask.
    safe_x = np.where(inside, x, 0)
    safe_y = np.where(inside, y, 0)
    safe_z = np.where(inside, z, 0)

    # Advanced indices separated by a slice put the broadcast index
    # dimensions first, so the gather yields (N, out_D, out_H, out_W, C).
    batch = np.arange(N).reshape(N, 1, 1, 1)
    picked = data[batch, :, safe_z, safe_y, safe_x]
    picked = np.transpose(picked, (0, 4, 1, 2, 3))

    out = np.where(inside[:, np.newaxis], picked, 0)
    return out.astype('float64')


138 139
def clip(x, min_n, max_n):
    """Clamp ``x`` element-wise: cap at ``max_n`` first, then floor at ``min_n``."""
    capped = np.minimum(x, max_n)
    return np.maximum(capped, min_n)
D
dengkaipeng 已提交
140 141


142 143 144 145
def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
    """Map normalized grid coordinates to pixel coordinates and apply padding.

    Args:
        grid_slice: ndarray of normalized coordinates for one axis.
        max_val: largest valid index on that axis (axis size minus one).
        align_corners: when true, ``-1`` and ``1`` map to indices ``0`` and
            ``max_val``; otherwise they map to ``-0.5`` and
            ``max_val + 0.5``.
        padding_mode: ``"border"`` clamps to ``[0, max_val]`` and
            ``"reflection"`` folds coordinates back into range. Any other
            value (e.g. ``"zeros"``) leaves the coordinates unchanged.

    Returns:
        ``float64`` ndarray of pixel coordinates with the same shape as
        ``grid_slice``.
    """
    grid_slice = grid_slice.astype('float64')
    if align_corners:
        grid_slice = 0.5 * ((grid_slice + 1.0) * max_val)
    else:
        grid_slice = 0.5 * ((grid_slice + 1.0) * (max_val + 1)) - 0.5

    if padding_mode == "border":
        grid_slice = clip(grid_slice, 0, max_val)
    elif padding_mode == "reflection":
        if align_corners:
            double_range = 2 * max_val
            if double_range == 0:
                # A single-element axis has nothing to reflect across, and
                # the modulo below would divide by zero and produce NaN.
                # Every coordinate maps to index 0.
                return np.zeros_like(grid_slice)
            grid_abs = np.abs(grid_slice)
        else:
            double_range = (max_val + 1) * 2
            # Shift by half a pixel so the fold is computed from the edge at
            # -0.5 rather than from the centre of pixel 0.
            grid_abs = np.abs(grid_slice + 0.5)

        # Fold into one period of length double_range, then mirror the
        # second half of the period back onto the first.
        extra = grid_abs - np.floor(grid_abs / double_range) * double_range
        grid_slice = np.minimum(extra, double_range - extra)

        if not align_corners:
            # Undo the half-pixel shift and clamp to valid indices.
            grid_slice = clip(grid_slice - 0.5, 0, max_val)
    return grid_slice
D
dengkaipeng 已提交
160 161


162 163 164 165 166 167 168 169 170 171
def GridSampler(data,
                grid,
                align_corners=True,
                mode="bilinear",
                padding_mode="zeros"):
    """NumPy reference implementation of 2-D grid sampling.

    Args:
        data: ndarray of shape ``(N, C, in_H, in_W)``.
        grid: ndarray of shape ``(N, out_H, out_W, 2)``; the last axis holds
            normalized ``(x, y)`` sampling coordinates.
        align_corners: forwarded to ``unnormalizeAndClip``.
        mode: ``"bilinear"`` or ``"nearest"``.
        padding_mode: forwarded to ``unnormalizeAndClip``. Samples that
            still fall outside the image afterwards read as zero.

    Returns:
        ``float64`` ndarray of shape ``(N, C, out_H, out_W)``.

    Raises:
        ValueError: if ``mode`` is neither ``"bilinear"`` nor ``"nearest"``.
    """
    # Validate first: an unknown mode used to fall through both branches and
    # fail with an UnboundLocalError on ``out``.
    if mode not in ("bilinear", "nearest"):
        raise ValueError("unsupported interpolation mode: %r" % (mode, ))

    in_H, in_W = data.shape[2], data.shape[3]

    x = unnormalizeAndClip(grid[:, :, :, 0], in_W - 1, align_corners,
                           padding_mode)
    y = unnormalizeAndClip(grid[:, :, :, 1], in_H - 1, align_corners,
                           padding_mode)

    if mode == "nearest":
        return getGridPointValue(data,
                                 np.round(x).astype('int32'),
                                 np.round(y).astype('int32'))

    x0 = np.floor(x).astype('int32')
    x1 = x0 + 1
    y0 = np.floor(y).astype('int32')
    y1 = y0 + 1

    # Each weight has shape (N, out_H, out_W). Adding a channel axis lets it
    # broadcast over C, so no explicit tiling is needed.
    wa = ((x1 - x) * (y1 - y))[:, np.newaxis]
    wb = ((x1 - x) * (y - y0))[:, np.newaxis]
    wc = ((x - x0) * (y1 - y))[:, np.newaxis]
    wd = ((x - x0) * (y - y0))[:, np.newaxis]

    va = getGridPointValue(data, x0, y0)
    vb = getGridPointValue(data, x0, y1)
    vc = getGridPointValue(data, x1, y0)
    vd = getGridPointValue(data, x1, y1)

    return (wa * va + wb * vb + wc * vc + wd * vd).astype('float64')

211

212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
def GridSampler3D(data,
                  grid,
                  align_corners=True,
                  mode="bilinear",
                  padding_mode="zeros"):
    """NumPy reference implementation of 3-D grid sampling.

    Args:
        data: ndarray of shape ``(N, C, in_D, in_H, in_W)``.
        grid: ndarray of shape ``(N, out_D, out_H, out_W, 3)``; the last
            axis holds normalized ``(x, y, z)`` sampling coordinates.
        align_corners: forwarded to ``unnormalizeAndClip``.
        mode: ``"bilinear"`` (interpolates over the eight surrounding
            voxels) or ``"nearest"``.
        padding_mode: forwarded to ``unnormalizeAndClip``. Samples that
            still fall outside the volume afterwards read as zero.

    Returns:
        ``float64`` ndarray of shape ``(N, C, out_D, out_H, out_W)``.

    Raises:
        ValueError: if ``mode`` is neither ``"bilinear"`` nor ``"nearest"``.
    """
    # Validate first: an unknown mode used to fall through both branches and
    # fail with an UnboundLocalError on ``out``.
    if mode not in ("bilinear", "nearest"):
        raise ValueError("unsupported interpolation mode: %r" % (mode, ))

    in_D, in_H, in_W = data.shape[2], data.shape[3], data.shape[4]

    x = unnormalizeAndClip(grid[:, :, :, :, 0], in_W - 1, align_corners,
                           padding_mode)
    y = unnormalizeAndClip(grid[:, :, :, :, 1], in_H - 1, align_corners,
                           padding_mode)
    z = unnormalizeAndClip(grid[:, :, :, :, 2], in_D - 1, align_corners,
                           padding_mode)

    if mode == "nearest":
        return getGridPointValue3D(data,
                                   np.round(x).astype('int32'),
                                   np.round(y).astype('int32'),
                                   np.round(z).astype('int32'))

    x0 = np.floor(x).astype('int32')
    x1 = x0 + 1
    y0 = np.floor(y).astype('int32')
    y1 = y0 + 1
    z0 = np.floor(z).astype('int32')
    z1 = z0 + 1

    # (corner x, corner y, corner z, weight) for the eight neighbours. Each
    # corner is weighted by the distance to the opposite corner on every
    # axis. The order matches the original summation order.
    corners = [
        (x0, y0, z0, (x1 - x) * (y1 - y) * (z1 - z)),  # tnw
        (x1, y0, z0, (x - x0) * (y1 - y) * (z1 - z)),  # tne
        (x0, y1, z0, (x1 - x) * (y - y0) * (z1 - z)),  # tsw
        (x1, y1, z0, (x - x0) * (y - y0) * (z1 - z)),  # tse
        (x0, y0, z1, (x1 - x) * (y1 - y) * (z - z0)),  # bnw
        (x1, y0, z1, (x - x0) * (y1 - y) * (z - z0)),  # bne
        (x0, y1, z1, (x1 - x) * (y - y0) * (z - z0)),  # bsw
        (x1, y1, z1, (x - x0) * (y - y0) * (z - z0)),  # bse
    ]

    out = None
    for cx, cy, cz, weight in corners:
        # The added channel axis lets the (N, D, H, W) weight broadcast
        # over C, so no explicit tiling is needed.
        term = weight[:, np.newaxis] * getGridPointValue3D(data, cx, cy, cz)
        out = term if out is None else out + term

    return out.astype('float64')


D
dengkaipeng 已提交
286
class TestGridSamplerOp(OpTest):
    """Check the ``grid_sampler`` op against the NumPy reference samplers.

    Subclasses override ``initTestCase`` to choose shapes and sampling
    attributes. A four-entry ``grid_shape`` selects the 2-D reference path
    (``AffineGrid`` / ``GridSampler``); any other length selects the 3-D
    path (``AffineGrid3D`` / ``GridSampler3D``).
    """

    def setUp(self):
        # Defaults; initTestCase() may override any of them.
        self.use_cudnn = False
        self.numeric_grad_delta = 0.0001
        self.op_type = 'grid_sampler'
        self.python_api = paddle.nn.functional.grid_sample
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        self.initTestCase()

        x = np.random.randint(0, 255, self.x_shape).astype('float64')
        theta = np.zeros(self.theta_shape).astype('float64')

        if len(self.grid_shape) == 4:
            rows, cols = 2, 3
            make_grid, sample = AffineGrid, GridSampler
        else:
            rows, cols = 3, 4
            make_grid, sample = AffineGrid3D, GridSampler3D

        # Random affine matrices with entries in [0, 1).
        theta[:, :rows, :cols] = np.random.rand(self.theta_shape[0], rows,
                                                cols)
        grid = make_grid(theta, self.grid_shape)

        self.inputs = {'X': x, 'Grid': grid}
        self.attrs = {
            'use_cudnn': self.use_cudnn,
            "align_corners": self.align_corners,
            "padding_mode": self.padding_mode,
            "mode": self.mode
        }
        self.outputs = {
            'Output':
            sample(x, grid, self.align_corners, self.mode, self.padding_mode)
        }

    def get_places(self):
        """Return the places used by the 3-D checks.

        Only ``CUDAPlace(0)`` is ever returned, so the list is empty (and
        the 3-D checks iterate over nothing) when Paddle is built without
        CUDA.
        """
        places = []
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        return places

    def test_check_output(self):
        if len(self.grid_shape) == 4:
            self.check_output(check_eager=True)
            return

        check_eager_flag = True
        check_dygraph_flag = False
        for place in self.get_places():
            res = self.check_output_with_place(
                place,
                atol=1e-5,
                check_dygraph=check_dygraph_flag,
                check_eager=check_eager_flag)
            # The result tuple is unpacked according to which checks ran.
            if check_eager_flag:
                assert not check_dygraph_flag
                outs, eager_dygraph_outs, fetch_list = res
            elif check_dygraph_flag:
                outs, dygraph_outs, fetch_list = res
            else:
                outs, fetch_list = res
            if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST:
                self.check_compile_vs_runtime(fetch_list, outs)

    def test_check_grad_normal(self):
        if len(self.grid_shape) == 4:
            self.check_grad(['X', 'Grid'],
                            'Output',
                            max_relative_error=0.01,
                            numeric_grad_delta=self.numeric_grad_delta,
                            check_eager=True)
            return

        # 3-D path: the gradient is checked w.r.t. 'X' only, on each place
        # returned by get_places().
        self._check_grad_helper()
        for place in self.get_places():
            self.check_grad_with_place(
                place, ['X'],
                'Output',
                numeric_grad_delta=self.numeric_grad_delta,
                max_relative_error=0.01,
                check_eager=True,
                check_dygraph=False)

    def initTestCase(self):
        """Configure the base case: 2-D, bilinear, zeros padding, aligned corners."""
        self.x_shape = (2, 3, 8, 8)
        self.grid_shape = (2, 7, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        # cuDNN is requested everywhere except on ROCm builds.
        self.use_cudnn = not core.is_compiled_with_rocm()
391 392 393


class Case1(TestGridSamplerOp):
    """2-D bilinear sampling, zeros padding, ``align_corners=False``."""

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "zeros"
        self.mode = "bilinear"


J
Jiangxinz 已提交
404
class Case1_(TestGridSamplerOp):
    """2-D bilinear sampling, border padding, ``align_corners=False``."""

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "border"
        self.mode = "bilinear"


class Case2(TestGridSamplerOp):
    """2-D bilinear sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "bilinear"


class Case3(TestGridSamplerOp):
    """2-D bilinear sampling, reflection padding, ``align_corners=True``."""

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "reflection"
        self.mode = "bilinear"

D
dengkaipeng 已提交
436

437
class Case4(TestGridSamplerOp):
    """2-D nearest sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "nearest"
        self.numeric_grad_delta = 0.0001
D
dengkaipeng 已提交
447

448

449 450 451
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class LargeInputCase(TestGridSamplerOp):
    """Forward-only 2-D case on a 128x128 input.

    Bilinear sampling, reflection padding, ``align_corners=False``. The
    gradient test is overridden with a no-op.
    """

    def get_places(self):
        places = []
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        return places

    def initTestCase(self):
        self.no_need_check_grad = True
        self.x_shape = (2, 3, 128, 128)
        self.grid_shape = (2, 130, 130, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "bilinear"

    def test_check_grad_normal(self):
        # Intentionally empty: see the skip_check_grad_ci reason above.
        pass


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class Case5(LargeInputCase):
    """Forward-only 2-D case on a 128x128 input.

    Bilinear sampling, zeros padding, ``align_corners=True``. cuDNN is
    requested except on ROCm builds.
    """

    def initTestCase(self):
        self.no_need_check_grad = True
        self.x_shape = (2, 3, 128, 128)
        self.grid_shape = (2, 130, 130, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        self.use_cudnn = not core.is_compiled_with_rocm()


487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583
class Case6(TestGridSamplerOp):
    """3-D bilinear sampling, zeros padding, ``align_corners=False``."""

    def initTestCase(self):
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "zeros"
        self.align_corners = False
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 8, 9, 10, 3)
        self.x_shape = (2, 3, 5, 6, 7)


class Case6_(TestGridSamplerOp):
    """3-D bilinear sampling, border padding, ``align_corners=False``."""

    def get_places(self):
        # Only the first CUDA device is used; no places without CUDA.
        return [core.CUDAPlace(0)] if core.is_compiled_with_cuda() else []

    def initTestCase(self):
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "border"
        self.align_corners = False
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 8, 9, 10, 3)
        self.x_shape = (2, 3, 5, 6, 7)


class Case7(TestGridSamplerOp):
    """3-D bilinear sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = False
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 8, 9, 10, 3)
        self.x_shape = (2, 3, 5, 6, 7)


class Case8(TestGridSamplerOp):
    """3-D bilinear sampling, reflection padding, ``align_corners=True``."""

    def initTestCase(self):
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = True
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 8, 9, 10, 3)
        self.x_shape = (2, 3, 5, 6, 7)


class Case9(TestGridSamplerOp):
    """3-D nearest sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        # Step used for the numeric gradient.
        self.numeric_grad_delta = 0.0001
        # Sampling attributes.
        self.mode = "nearest"
        self.padding_mode = "reflection"
        self.align_corners = False
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 8, 9, 10, 3)
        self.x_shape = (2, 3, 5, 6, 7)


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class LargeInput3DCase(TestGridSamplerOp):
    """Forward-only 3-D case on a 24x24x12 input.

    Bilinear sampling, reflection padding, ``align_corners=False``, cuDNN
    off. The gradient test is overridden with a no-op.
    """

    def initTestCase(self):
        # The op type is also recorded on the class itself.
        self.__class__.op_type = 'grid_sampler'
        self.no_need_check_grad = True
        self.use_cudnn = False
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = False
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 25, 25, 12, 3)
        self.x_shape = (2, 3, 24, 24, 12)

    def test_check_grad_normal(self):
        # Intentionally empty: see the skip_check_grad_ci reason above.
        return None


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class Case10(LargeInput3DCase):
    """Forward-only 3-D case on a 24x24x12 input.

    Bilinear sampling, zeros padding, ``align_corners=True``, cuDNN off.
    """

    def initTestCase(self):
        # The op type is also recorded on the class itself.
        self.__class__.op_type = 'grid_sampler'
        self.no_need_check_grad = True
        self.use_cudnn = False
        # Sampling attributes.
        self.mode = "bilinear"
        self.padding_mode = "zeros"
        self.align_corners = True
        # Tensor shapes; the five-entry grid_shape selects the 3-D path.
        self.theta_shape = (2, 3, 4)
        self.grid_shape = (2, 25, 25, 12, 3)
        self.x_shape = (2, 3, 24, 24, 12)


D
dengkaipeng 已提交
584 585
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()