#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
paddle.enable_static()


def AffineGrid(theta, grid_shape):
    """Build a reference affine sampling grid.

    Args:
        theta: iterable of per-batch affine matrices; each entry is used as
            ``mat.T`` so it is expected to be (2, 3).
        grid_shape: sequence whose first three entries are (n, h, w).

    Returns:
        float64 array of shape (n, h, w, 2). Entry [b, r, c] is theta[b]
        applied to the homogeneous point (x_c, y_r, 1), where x spans
        linspace(-1, 1, w) and y spans linspace(-1, 1, h).
    """
    n, h, w = grid_shape[0], grid_shape[1], grid_shape[2]

    xs = np.linspace(-1, 1, w)
    ys = np.linspace(-1, 1, h)
    # 'xy' indexing: xx[r, c] == xs[c], yy[r, c] == ys[r]; both are (h, w).
    xx, yy = np.meshgrid(xs, ys)
    homogeneous = np.stack([xx, yy, np.ones((h, w))], axis=-1)
    points = homogeneous.reshape(h * w, 3)

    flat = np.zeros([n, h * w, 2])
    for b, mat in enumerate(theta):
        flat[b] = points.dot(mat.T)
    return flat.reshape([n, h, w, 2]).astype("float64")


def getGridPointValue(data, x, y):
    """Gather ``data[n, :, y, x]`` at integer positions, zero outside.

    Args:
        data: 4-D array indexed as data[n, c, row, col]; axes 2 and 3 give
            the input height and width.
        x: integer column indices of shape (N, out_H, out_W).
        y: integer row indices with the same shape as ``x``.

    Returns:
        float64 array of shape (N, C, out_H, out_W). A position whose
        (y, x) lies outside [0, in_H - 1] x [0, in_W - 1] is 0; every other
        position holds the channel vector data[n, :, y, x].
    """
    data_shape = data.shape
    N, C, in_H, in_W = (data_shape[0], data_shape[1], data_shape[2],
                        data_shape[3])
    out_H = x.shape[1]
    out_W = x.shape[2]
    out = np.zeros([N, C, out_H, out_W], dtype='float64')

    # Vectorized instead of a Python loop over every (n, row, col) output
    # position, which dominated runtime for the large-input cases.
    in_bounds = (x >= 0) & (x <= in_W - 1) & (y >= 0) & (y <= in_H - 1)
    n_idx, h_idx, w_idx = np.nonzero(in_bounds)
    # Advanced indices separated by a slice move the gathered axis to the
    # front, so both sides below have shape (num_in_bounds, C).
    out[n_idx, :, h_idx, w_idx] = data[n_idx, :, y[in_bounds], x[in_bounds]]
    return out

def clip(x, min_n, max_n):
    """Clamp ``x`` element-wise into [min_n, max_n].

    The upper bound is applied first, so when min_n > max_n every element
    ends up equal to min_n.
    """
    capped = np.minimum(x, max_n)
    return np.maximum(capped, min_n)


def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
    """Map normalized coordinates in [-1, 1] to pixel coordinates on one axis.

    Args:
        grid_slice: array of normalized coordinates for one axis.
        max_val: largest valid pixel index on that axis (size - 1).
        align_corners: if True, -1 and 1 map to pixel indices 0 and max_val;
            otherwise they map to -0.5 and max_val + 0.5.
        padding_mode: "border" clamps into [0, max_val]; "reflection"
            reflects out-of-range coordinates back into range; any other
            value (e.g. "zeros") leaves the coordinates unchanged.

    Returns:
        float64 array of pixel coordinates with the same shape as grid_slice.
    """
    if align_corners:
        grid_slice = 0.5 * ((grid_slice.astype('float64') + 1.0) * max_val)
    else:
        grid_slice = 0.5 * (
            (grid_slice.astype('float64') + 1.0) * (max_val + 1)) - 0.5

    if padding_mode == "border":
        grid_slice = clip(grid_slice, 0, max_val)
    elif padding_mode == "reflection":
        # One reflection period: twice the span [0, max_val] with
        # align_corners, or twice [-0.5, max_val + 0.5] without it.
        double_range = 2 * max_val if align_corners else (max_val + 1) * 2
        if double_range == 0:
            # Single-pixel axis with align_corners: reflection collapses
            # every coordinate onto 0, and the modulo below would otherwise
            # evaluate 0 / 0 and yield NaN.
            return np.zeros_like(grid_slice)
        if align_corners:
            grid_abs = np.abs(grid_slice)
        else:
            # Shift so the reflection boundary at -0.5 sits at 0.
            grid_abs = np.abs(grid_slice + 0.5)
        extra = grid_abs - np.floor(grid_abs / double_range) * double_range
        grid_slice = np.minimum(extra, double_range - extra)
        if not align_corners:
            grid_slice = clip(grid_slice - 0.5, 0, max_val)
    return grid_slice


def GridSampler(data,
                grid,
                align_corners=True,
                mode="bilinear",
                padding_mode="zeros"):
    """NumPy reference for the grid_sampler forward pass.

    Args:
        data: input indexed as data[n, c, row, col], i.e. (N, C, in_H, in_W).
        grid: sampling grid indexed as grid[n, h, w, k]; k == 0 is the
            normalized x (column) coordinate and k == 1 the normalized y
            (row) coordinate. Presumably (N, out_H, out_W, 2).
        align_corners: forwarded to unnormalizeAndClip.
        mode: "bilinear" or "nearest".
        padding_mode: forwarded to unnormalizeAndClip; out-of-range samples
            that remain after it read as 0.

    Returns:
        float64 array of shape (N, C, out_H, out_W).

    Raises:
        ValueError: if ``mode`` is neither "bilinear" nor "nearest".
    """
    # Validate before doing any work; an unknown mode matches neither
    # branch below and would otherwise fail with an UnboundLocalError.
    if mode not in ("bilinear", "nearest"):
        raise ValueError("Unsupported GridSampler mode: %r" % (mode, ))

    in_H = data.shape[2]
    in_W = data.shape[3]

    x = unnormalizeAndClip(grid[:, :, :, 0], in_W - 1, align_corners,
                           padding_mode)
    y = unnormalizeAndClip(grid[:, :, :, 1], in_H - 1, align_corners,
                           padding_mode)

    if mode == "nearest":
        # np.round rounds exact halves to the nearest even integer.
        x = np.round(x).astype('int32')
        y = np.round(y).astype('int32')
        return getGridPointValue(data, x, y)

    # Bilinear: blend the four integer neighbours surrounding (x, y).
    x0 = np.floor(x).astype('int32')
    x1 = x0 + 1
    y0 = np.floor(y).astype('int32')
    y1 = y0 + 1

    # Add a channel axis so the (N, out_H, out_W) weights broadcast against
    # the (N, C, out_H, out_W) gathered values.
    wa = ((x1 - x) * (y1 - y))[:, np.newaxis]
    wb = ((x1 - x) * (y - y0))[:, np.newaxis]
    wc = ((x - x0) * (y1 - y))[:, np.newaxis]
    wd = ((x - x0) * (y - y0))[:, np.newaxis]

    va = getGridPointValue(data, x0, y0)
    vb = getGridPointValue(data, x0, y1)
    vc = getGridPointValue(data, x1, y0)
    vd = getGridPointValue(data, x1, y1)

    return (wa * va + wb * vb + wc * vc + wd * vd).astype('float64')

class TestGridSamplerOp(OpTest):
    """Checks the grid_sampler op against the NumPy ``GridSampler`` reference.

    setUp() fills in default attributes, lets initTestCase() override them,
    then builds a random input X and an affine sampling grid from a random
    theta. Subclasses change the configuration by overriding initTestCase().
    """

    def setUp(self):
        # Defaults. A subclass whose initTestCase() does not assign
        # use_cudnn keeps use_cudnn = False.
        self.op_type = 'grid_sampler'
        self.use_cudnn = False
        self.numeric_grad_delta = 0.0001
        self.mode = "bilinear"
        self.padding_mode = "zeros"
        self.align_corners = True
        self.initTestCase()

        x = np.random.randint(0, 255, self.x_shape).astype('float64')

        # Fill theta[:, :2, :3] with one random scalar per element, drawn in
        # (i, j, k) C order; any entries outside that slice stay zero.
        theta = np.zeros(self.theta_shape).astype('float64')
        for i, j, k in np.ndindex(self.theta_shape[0], 2, 3):
            theta[i, j, k] = np.random.rand(1)[0]
        grid = AffineGrid(theta, self.grid_shape)

        self.inputs = {'X': x, 'Grid': grid}
        self.attrs = {
            'use_cudnn': self.use_cudnn,
            "align_corners": self.align_corners,
            "padding_mode": self.padding_mode,
            "mode": self.mode
        }
        expected = GridSampler(x, grid, self.align_corners, self.mode,
                               self.padding_mode)
        self.outputs = {'Output': expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        inputs_to_check = ['X', 'Grid']
        self.check_grad(
            inputs_to_check,
            'Output',
            max_relative_error=0.01,
            numeric_grad_delta=self.numeric_grad_delta)

    def initTestCase(self):
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 8, 8), (2, 7, 9, 2), (2, 2, 3))
        self.align_corners, self.padding_mode, self.mode = (
            True, "zeros", "bilinear")
        # The cuDNN kernel is requested only on non-ROCm builds.
        self.use_cudnn = not core.is_compiled_with_rocm()


class Case1Zeros(TestGridSamplerOp):
    """Bilinear sampling with zeros padding and align_corners=False.

    This class must not share its name with ``Case1`` (the border-padding
    case defined below). A second class statement with the same name
    rebinds it, so only the later class is collected by unittest and this
    configuration would never run.
    """

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "zeros"
        self.mode = "bilinear"


class Case1(TestGridSamplerOp):
    """Bilinear sampling with border padding and align_corners=False."""

    def initTestCase(self):
        self.mode = "bilinear"
        self.padding_mode = "border"
        self.align_corners = False
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 5, 6), (2, 8, 9, 2), (2, 2, 3))


class Case2(TestGridSamplerOp):
    """Bilinear sampling with reflection padding and align_corners=False."""

    def initTestCase(self):
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = False
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 5, 6), (2, 8, 9, 2), (2, 2, 3))


class Case3(TestGridSamplerOp):
    """Bilinear sampling with reflection padding and align_corners=True."""

    def initTestCase(self):
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = True
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 5, 6), (2, 8, 9, 2), (2, 2, 3))

class Case4(TestGridSamplerOp):
    """Nearest-neighbour sampling, reflection padding, align_corners=False."""

    def initTestCase(self):
        self.mode = "nearest"
        self.padding_mode = "reflection"
        self.align_corners = False
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 5, 6), (2, 8, 9, 2), (2, 2, 3))
        # Same value as the setUp() default; assigned explicitly here.
        self.numeric_grad_delta = 0.0001


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class LargeInputCase(TestGridSamplerOp):
    """Forward-only check on 128x128 inputs with reflection padding."""

    def get_places(self):
        # Only the first CUDA device; an empty list on non-CUDA builds.
        return [core.CUDAPlace(0)] if core.is_compiled_with_cuda() else []

    def initTestCase(self):
        self.mode = "bilinear"
        self.padding_mode = "reflection"
        self.align_corners = False
        self.theta_shape = (2, 2, 3)
        self.grid_shape = (2, 130, 130, 2)
        self.x_shape = (2, 3, 128, 128)
        self.no_need_check_grad = True

    def test_check_grad_normal(self):
        """Intentionally a no-op: gradients are not checked at this size."""


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class Case5(LargeInputCase):
    """Large-input forward check with zeros padding and align_corners=True.

    Inherits LargeInputCase's CUDA-only get_places() and its no-op
    gradient check.
    """

    def initTestCase(self):
        self.no_need_check_grad = True
        self.x_shape, self.grid_shape, self.theta_shape = (
            (2, 3, 128, 128), (2, 130, 130, 2), (2, 2, 3))
        self.align_corners, self.padding_mode, self.mode = (
            True, "zeros", "bilinear")
        # The cuDNN kernel is requested only on non-ROCm builds.
        self.use_cudnn = not core.is_compiled_with_rocm()


if __name__ == "__main__":
    # Run the TestCase classes defined in this module via unittest's CLI.
    unittest.main()