test_grid_sampler_op.py 8.9 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import paddle
D
dengkaipeng 已提交
16 17
import unittest
import numpy as np
18
import paddle.fluid.core as core
19
from op_test import OpTest, skip_check_grad_ci
20

21
# Switch Paddle to static-graph mode at import time, before any test case
# below is constructed (presumably required by the OpTest harness — confirm).
paddle.enable_static()
D
dengkaipeng 已提交
22 23


24 25 26 27
def AffineGrid(theta, grid_shape):
    """Build a sampling grid by applying per-sample affine matrices.

    A base grid of homogeneous coordinates ``(x, y, 1)`` is created with ``x``
    evenly spaced over ``[-1, 1]`` along the width axis and ``y`` evenly
    spaced over ``[-1, 1]`` along the height axis. Each sample's matrix is
    then applied to every base coordinate.

    Args:
        theta: Array of shape ``(n, 2, 3)``; one affine matrix per sample.
        grid_shape: Sequence whose first three entries are ``(n, h, w)``.
            Any further entries are ignored.

    Returns:
        A float64 array of shape ``(n, h, w, 2)`` whose last axis holds the
        transformed ``(x, y)`` coordinate for each output location.
    """
    n, h, w = grid_shape[0], grid_shape[1], grid_shape[2]

    # With the default 'xy' indexing, meshgrid returns (h, w) arrays where
    # w_idx varies along columns and h_idx varies along rows.
    w_idx, h_idx = np.meshgrid(np.linspace(-1, 1, w), np.linspace(-1, 1, h))
    base = np.stack([w_idx, h_idx, np.ones([h, w])], axis=2)  # h * w * 3

    # The base grid is identical for every sample, so flatten it once rather
    # than materialising n copies.
    flat_base = base.reshape([h * w, 3])

    ret = np.zeros([n, h * w, 2])
    for i, theta_i in enumerate(theta):
        # (h*w, 3) x (3, 2) -> (h*w, 2)
        ret[i] = np.dot(flat_base, theta_i.T)

    return ret.reshape([n, h, w, 2]).astype("float64")
D
dengkaipeng 已提交
42

43

D
dengkaipeng 已提交
44 45 46
def getGridPointValue(data, x, y):
    """Gather ``data`` at integer pixel coordinates, zero-filling misses.

    Args:
        data: Array of shape ``(N, C, in_H, in_W)``.
        x: Integer array of shape ``(N, out_H, out_W)`` holding column
            indices into the last axis of ``data``.
        y: Integer array with the same shape as ``x`` holding row indices
            into the third axis of ``data``.

    Returns:
        A float64 array of shape ``(N, C, out_H, out_W)``. Entry
        ``[n, :, j, k]`` equals ``data[n, :, y[n, j, k], x[n, j, k]]`` when
        that coordinate lies inside the image, and 0 otherwise.
    """
    data_shape = data.shape
    N, C, in_H, in_W = (data_shape[0], data_shape[1], data_shape[2],
                        data_shape[3])
    out_H, out_W = x.shape[1], x.shape[2]

    out = np.zeros([N, C, out_H, out_W], dtype='float64')

    # Only coordinates inside the image are gathered; everything else keeps
    # the zero it was initialised with.
    inside = (y >= 0) & (y <= in_H - 1) & (x >= 0) & (x <= in_W - 1)
    n_idx, j_idx, k_idx = np.nonzero(inside)

    # The advanced indices are separated by a slice, so both sides have
    # shape (num_inside, C) and line up element for element. Boolean-mask
    # selection and np.nonzero enumerate positions in the same C order.
    out[n_idx, :, j_idx, k_idx] = data[n_idx, :, y[inside], x[inside]]

    return out

66

67 68
def clip(x, min_n, max_n):
    """Limit ``x`` element-wise to the closed interval ``[min_n, max_n]``."""
    # The upper bound is applied first and the lower bound second, so the
    # lower bound wins if the two bounds ever cross.
    capped = np.minimum(x, max_n)
    return np.maximum(capped, min_n)
D
dengkaipeng 已提交
69 70


71 72 73 74
def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
    """Map normalised grid coordinates to pixel space and apply padding.

    Args:
        grid_slice: Array of normalised coordinates along one axis.
        max_val: Largest valid pixel index along that axis (size - 1).
        align_corners: If true, -1 and 1 map to pixel indices 0 and
            ``max_val``; otherwise they map to -0.5 and ``max_val + 0.5``.
        padding_mode: ``"border"`` clamps coordinates to ``[0, max_val]``;
            ``"reflection"`` folds them back into range; any other value
            leaves the unnormalised coordinates untouched.

    Returns:
        A float64 array of pixel-space coordinates with the same shape as
        ``grid_slice``.
    """
    grid_slice = grid_slice.astype('float64')
    if align_corners:
        grid_slice = 0.5 * ((grid_slice + 1.0) * max_val)
    else:
        grid_slice = 0.5 * ((grid_slice + 1.0) * (max_val + 1)) - 0.5

    if padding_mode == "border":
        grid_slice = np.clip(grid_slice, 0, max_val)
    elif padding_mode == "reflection":
        if align_corners:
            # Reflect about pixel centres 0 and max_val.
            double_range = 2 * max_val
            if double_range == 0:
                # A one-pixel axis has nothing to reflect across, and the
                # modulo below would divide by zero and produce NaN.
                return np.zeros_like(grid_slice)
            grid_abs = np.abs(grid_slice)
            extra = grid_abs - np.floor(grid_abs / double_range) * double_range
            grid_slice = np.minimum(extra, double_range - extra)
        else:
            # Reflect about the outer pixel edges -0.5 and max_val + 0.5:
            # shift by 0.5 so the edges sit at 0 and max_val + 1, fold,
            # shift back, then clamp.
            double_range = (max_val + 1) * 2
            grid_abs = np.abs(grid_slice + 0.5)
            extra = grid_abs - np.floor(grid_abs / double_range) * double_range
            folded = np.minimum(extra, double_range - extra)
            grid_slice = np.clip(folded - 0.5, 0, max_val)
    return grid_slice
D
dengkaipeng 已提交
89 90


91 92 93 94 95 96 97 98 99 100
def GridSampler(data,
                grid,
                align_corners=True,
                mode="bilinear",
                padding_mode="zeros"):
    """NumPy reference implementation of grid sampling.

    Args:
        data: Input array of shape ``(N, C, in_H, in_W)``.
        grid: Array of shape ``(N, out_H, out_W, 2)``; the last axis holds
            normalised ``(x, y)`` sampling coordinates.
        align_corners: Forwarded to ``unnormalizeAndClip`` to select how
            normalised coordinates map to pixel indices.
        mode: ``"bilinear"`` or ``"nearest"``.
        padding_mode: Forwarded to ``unnormalizeAndClip``. Coordinates that
            still fall outside the image afterwards read as zero.

    Returns:
        A float64 array of shape ``(N, C, out_H, out_W)``.

    Raises:
        ValueError: If ``mode`` is neither ``"bilinear"`` nor ``"nearest"``.
    """
    if mode not in ("bilinear", "nearest"):
        raise ValueError(
            "mode must be 'bilinear' or 'nearest', got {!r}".format(mode))

    dims = data.shape
    N, in_H, in_W = dims[0], dims[2], dims[3]
    out_H, out_W = grid.shape[1], grid.shape[2]

    x = unnormalizeAndClip(grid[:, :, :, 0], in_W - 1, align_corners,
                           padding_mode)
    y = unnormalizeAndClip(grid[:, :, :, 1], in_H - 1, align_corners,
                           padding_mode)

    if mode == "nearest":
        x = np.round(x).astype('int32')
        y = np.round(y).astype('int32')
        return getGridPointValue(data, x, y)

    # Bilinear: blend the four integer neighbours of every sample point.
    x0 = np.floor(x).astype('int32')
    x1 = x0 + 1
    y0 = np.floor(y).astype('int32')
    y1 = y0 + 1

    # Each weight is reshaped to (N, 1, out_H, out_W) so that it broadcasts
    # over the channel axis of the gathered values.
    weight_shape = (N, 1, out_H, out_W)
    wa = ((x1 - x) * (y1 - y)).reshape(weight_shape)  # pairs with (x0, y0)
    wb = ((x1 - x) * (y - y0)).reshape(weight_shape)  # pairs with (x0, y1)
    wc = ((x - x0) * (y1 - y)).reshape(weight_shape)  # pairs with (x1, y0)
    wd = ((x - x0) * (y - y0)).reshape(weight_shape)  # pairs with (x1, y1)

    va = getGridPointValue(data, x0, y0)
    vb = getGridPointValue(data, x0, y1)
    vc = getGridPointValue(data, x1, y0)
    vd = getGridPointValue(data, x1, y1)

    return (wa * va + wb * vb + wc * vc + wd * vd).astype('float64')

140

D
dengkaipeng 已提交
141
class TestGridSamplerOp(OpTest):
    """Checks the ``grid_sampler`` op against the NumPy ``GridSampler``.

    Subclasses override ``initTestCase`` to choose the input shapes and the
    ``align_corners`` / ``padding_mode`` / ``mode`` attributes under test.
    """

    def setUp(self):
        """Build random inputs, op attributes and the reference output."""
        # Defaults; initTestCase may override any of them.
        self.use_cudnn = False
        self.numeric_grad_delta = 0.0001
        self.op_type = 'grid_sampler'
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        self.initTestCase()

        x = np.random.randint(0, 255, self.x_shape).astype('float64')

        # Random affine matrices: only the leading 2 x 3 block of every
        # sample is filled, the rest (if any) stays zero.
        theta = np.zeros(self.theta_shape).astype('float64')
        theta[:, :2, :3] = np.random.rand(self.theta_shape[0], 2, 3)
        grid = AffineGrid(theta, self.grid_shape)

        self.inputs = {'X': x, 'Grid': grid}
        self.attrs = {
            'use_cudnn': self.use_cudnn,
            "align_corners": self.align_corners,
            "padding_mode": self.padding_mode,
            "mode": self.mode
        }
        self.outputs = {
            'Output':
            GridSampler(x, grid, self.align_corners, self.mode,
                        self.padding_mode)
        }

    def test_check_output(self):
        """Compare the op's forward result with the reference output."""
        self.check_output()

    def test_check_grad_normal(self):
        """Check the gradients with respect to both ``X`` and ``Grid``."""
        self.check_grad(['X', 'Grid'],
                        'Output',
                        max_relative_error=0.01,
                        numeric_grad_delta=self.numeric_grad_delta)

    def initTestCase(self):
        """Base configuration: bilinear, zero padding, aligned corners."""
        self.x_shape = (2, 3, 8, 8)
        self.grid_shape = (2, 7, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        # cuDNN is requested on every build except ROCm.
        self.use_cudnn = not core.is_compiled_with_rocm()
190 191 192


class Case1(TestGridSamplerOp):
    """Bilinear sampling, zero padding, ``align_corners=False``."""

    def initTestCase(self):
        """Use a 5 x 6 input sampled onto an 8 x 9 grid."""
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "zeros"
        self.mode = "bilinear"


J
Jiangxinz 已提交
203
class Case1_(TestGridSamplerOp):
    """Bilinear sampling, border padding, ``align_corners=False``."""

    def initTestCase(self):
        """Use a 5 x 6 input sampled onto an 8 x 9 grid."""
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "border"
        self.mode = "bilinear"


class Case2(TestGridSamplerOp):
    """Bilinear sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        """Use a 5 x 6 input sampled onto an 8 x 9 grid."""
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "bilinear"


class Case3(TestGridSamplerOp):
    """Bilinear sampling, reflection padding, ``align_corners=True``."""

    def initTestCase(self):
        """Use a 5 x 6 input sampled onto an 8 x 9 grid."""
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "reflection"
        self.mode = "bilinear"

D
dengkaipeng 已提交
235

236
class Case4(TestGridSamplerOp):
    """Nearest sampling, reflection padding, ``align_corners=False``."""

    def initTestCase(self):
        """Use a 5 x 6 input sampled onto an 8 x 9 grid."""
        self.x_shape = (2, 3, 5, 6)
        self.grid_shape = (2, 8, 9, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "nearest"
        # Step size passed to check_grad for the numeric gradient.
        self.numeric_grad_delta = 0.0001
D
dengkaipeng 已提交
246

247

248 249 250
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class LargeInputCase(TestGridSamplerOp):
    """Forward-only check on a 128 x 128 input with reflection padding."""

    def get_places(self):
        """Return the CUDA device when available, otherwise no places."""
        if core.is_compiled_with_cuda():
            return [core.CUDAPlace(0)]
        return []

    def initTestCase(self):
        """Bilinear, reflection padding, ``align_corners=False``."""
        self.no_need_check_grad = True
        self.x_shape = (2, 3, 128, 128)
        self.grid_shape = (2, 130, 130, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = False
        self.padding_mode = "reflection"
        self.mode = "bilinear"

    def test_check_grad_normal(self):
        """Deliberately a no-op: the gradient check is skipped here."""
        pass


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
                    "however it is desirable to cover the forward pass")
class Case5(LargeInputCase):
    """Forward-only check on a 128 x 128 input with zero padding."""

    def initTestCase(self):
        """Bilinear, zero padding, ``align_corners=True``."""
        self.no_need_check_grad = True
        self.x_shape = (2, 3, 128, 128)
        self.grid_shape = (2, 130, 130, 2)
        self.theta_shape = (2, 2, 3)
        self.align_corners = True
        self.padding_mode = "zeros"
        self.mode = "bilinear"
        # cuDNN is requested on every build except ROCm.
        self.use_cudnn = not core.is_compiled_with_rocm()


D
dengkaipeng 已提交
286 287
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()