/**
 * \file dnn/src/naive/remap/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/naive/remap/opr_impl.h"

#include <cmath>
#include <cstring>

#include "src/common/cv/helper.h"
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;
using namespace rounding;

namespace {
template <param::Remap::Format format>
inline int get_offset(int, int, int, int, int, int);

template <>
M
Megvii Engine Team 已提交
28 29
inline int get_offset<param::Remap::Format::NCHW>(
        int height, int width, int channel, int h, int w, int) {
30 31 32 33
    return channel * h * w + height * w + width;
}

template <>
M
Megvii Engine Team 已提交
34 35
inline int get_offset<param::Remap::Format::NHWC>(
        int height, int width, int channel, int, int w, int c) {
36 37 38
    return height * w * c + width * c + channel;
}

39 40 41 42 43 44
template <>
inline int get_offset<param::Remap::Format::NHWCD4>(
        int height, int width, int channel, int, int w, int c) {
    return ((height * c + channel) * w + width) * 4;
}

M
Megvii Engine Team 已提交
45 46 47
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
48
struct GetSrcData {
M
Megvii Engine Team 已提交
49 50 51
    static inline ctype get(
            const ctype* src, int height, int width, int channel, int h, int w, int c,
            float) {
52 53 54 55
        height = megcv::border_interpolate<bordertype>(height, h);
        width = megcv::border_interpolate<bordertype>(width, w);
        return src[get_offset<format>(height, width, channel, h, w, c)];
    }
M
Megvii Engine Team 已提交
56 57
    static inline int get_index(
            int height, int width, int channel, int h, int w, int c) {
58 59 60 61
        height = megcv::border_interpolate<bordertype>(height, h);
        width = megcv::border_interpolate<bordertype>(width, w);
        return get_offset<format>(height, width, channel, h, w, c);
    }
62 63
};

64 65
template <typename ctype, param::Remap::Format format>
struct GetSrcData<ctype, format, param::Remap::BorderMode::CONSTANT> {
M
Megvii Engine Team 已提交
66 67 68
    static inline ctype get(
            const ctype* src, int height, int width, int channel, int h, int w, int c,
            float scalar) {
69
        RoundingConverter<ctype> round;
70
        return (height >= 0 && height < h && width >= 0 && width < w)
M
Megvii Engine Team 已提交
71 72
                     ? src[get_offset<format>(height, width, channel, h, w, c)]
                     : round(scalar);
73
    }
M
Megvii Engine Team 已提交
74 75
    static inline int get_index(
            int height, int width, int channel, int h, int w, int c) {
76
        return (height >= 0 && height < h && width >= 0 && width < w)
M
Megvii Engine Team 已提交
77 78
                     ? get_offset<format>(height, width, channel, h, w, c)
                     : -1;
79 80 81
    }
};

M
Megvii Engine Team 已提交
82 83 84 85 86 87
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_LINEAR(
        const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
        int OH, int OW, float scalar) {
88
    RoundingConverter<ctype> round_converter;
89 90 91 92 93 94
    size_t c_scale = 1;
    if (format == param::Remap::Format::NHWCD4) {
        c_scale = 4;
    }
    for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
             dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
95 96 97 98 99 100
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(floor(index_col));
                int row = static_cast<int>(floor(index_row));
101 102 103
                float v = index_col - col;  // alphaw
                float u = index_row - row;  // alphah
                const float one = 1.f;
104
                for (int c = 0; c < C; ++c) {
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
                    if (format == param::Remap::Format::NHWCD4) {
                        int idx00 = GetSrcData<ctype, format, bordertype>::get_index(
                                row + 0, col + 0, c, IH, IW, C);
                        int idx01 = GetSrcData<ctype, format, bordertype>::get_index(
                                row + 0, col + 1, c, IH, IW, C);
                        int idx10 = GetSrcData<ctype, format, bordertype>::get_index(
                                row + 1, col + 0, c, IH, IW, C);
                        int idx11 = GetSrcData<ctype, format, bordertype>::get_index(
                                row + 1, col + 1, c, IH, IW, C);
                        for (int c_inner = 0; c_inner < 4; ++c_inner) {
                            ctype a00 = (idx00 != -1) ? src[idx00 + c_inner]
                                                      : round_converter(scalar);
                            ctype a01 = (idx01 != -1) ? src[idx01 + c_inner]
                                                      : round_converter(scalar);
                            ctype a10 = (idx10 != -1) ? src[idx10 + c_inner]
                                                      : round_converter(scalar);
                            ctype a11 = (idx11 != -1) ? src[idx11 + c_inner]
                                                      : round_converter(scalar);
                            dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] =
                                    round_converter(
                                            a00 * (one - v) * (one - u) +
                                            a01 * (one - u) * v + a10 * (one - v) * u +
                                            a11 * u * v);
                        }
                    } else {
                        ctype a00 = GetSrcData<ctype, format, bordertype>::get(
                                src, row + 0, col + 0, c, IH, IW, C, scalar);
                        ctype a01 = GetSrcData<ctype, format, bordertype>::get(
                                src, row + 0, col + 1, c, IH, IW, C, scalar);
                        ctype a10 = GetSrcData<ctype, format, bordertype>::get(
                                src, row + 1, col + 0, c, IH, IW, C, scalar);
                        ctype a11 = GetSrcData<ctype, format, bordertype>::get(
                                src, row + 1, col + 1, c, IH, IW, C, scalar);

                        dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter(
                                a00 * (one - v) * (one - u) + a01 * (one - u) * v +
                                a10 * (one - v) * u + a11 * u * v);
                    }
                }
            }
        }
    }
}

namespace {

//! Rounds \p f to the nearest integer. An exact .5 tie goes to the even
//! neighbour (banker's rounding). The result does not depend on the current
//! floating-point rounding mode.
inline float round_half_to_even(float f) {
    // std::round breaks ties away from zero, so detect a tie afterwards.
    const float nearest = std::round(f);
    const float delta = nearest - f;
    const bool is_tie = delta == 0.5f || delta == -0.5f;
    if (!is_tie || std::fmod(nearest, 2.0f) == 0.0f) {
        return nearest;
    }
    // The tie landed on an odd integer; step back to the even neighbour.
    return f - delta;
}

}  // anonymous namespace

template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_NEAREST(
        const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
        int OH, int OW, float scalar) {
    RoundingConverter<ctype> round_converter;
    size_t c_scale = 1;
    if (format == param::Remap::Format::NHWCD4) {
        c_scale = 4;
    }
    for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
             dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(round_half_to_even(index_col));
                int row = static_cast<int>(round_half_to_even(index_row));
                for (int c = 0; c < C; ++c) {
                    if (format == param::Remap::Format::NHWCD4) {
                        int idx = GetSrcData<ctype, format, bordertype>::get_index(
                                row, col, c, IH, IW, C);
                        for (int c_inner = 0; c_inner < 4; ++c_inner) {
                            dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] =
                                    (idx != -1) ? (src[idx + c_inner])
                                                : round_converter(scalar);
                        }
                    } else {
                        dst[get_offset<format>(h, w, c, OH, OW, C)] =
                                GetSrcData<ctype, format, bordertype>::get(
                                        src, row, col, c, IH, IW, C, scalar);
                    }
201 202 203 204 205 206
                }
            }
        }
    }
}

M
Megvii Engine Team 已提交
207 208 209 210 211 212
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_LINEAR_backwarddata(
        ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
        int IW, int OH, int OW) {
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
    RoundingConverter<ctype> round_converter;
    std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
    for (int n = 0; n < N;
         ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(floor(index_col));
                int row = static_cast<int>(floor(index_row));
                float v = index_col - col;  // alphaw
                float u = index_row - row;  // alphah
                const float one = 1.f;
                for (int c = 0; c < C; ++c) {
                    ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)];
228

229 230 231
                    int a00 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 0, col + 0, c, IH, IW, C);
                    if (a00 != -1) {
M
Megvii Engine Team 已提交
232
                        grad[a00] += round_converter((one - v) * (one - u) * hidden);
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
                    }

                    int a01 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 0, col + 1, c, IH, IW, C);
                    if (a01 != -1) {
                        grad[a01] += round_converter((one - u) * v * hidden);
                    }

                    int a10 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 1, col + 0, c, IH, IW, C);
                    if (a10 != -1) {
                        grad[a10] += round_converter(u * (one - v) * hidden);
                    }

                    int a11 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 1, col + 1, c, IH, IW, C);
                    if (a11 != -1) {
                        grad[a11] += round_converter(v * u * hidden);
                    }
                }
            }
        }
    }
}

258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_NEAREST_backwarddata(
        ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
        int IW, int OH, int OW) {
    std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
    for (int n = 0; n < N;
         ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(round_half_to_even(index_col));
                int row = static_cast<int>(round_half_to_even(index_row));
                for (int c = 0; c < C; ++c) {
                    ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)];
                    int idx = GetSrcData<ctype, format, bordertype>::get_index(
                            row, col, c, IH, IW, C);
                    if (idx != -1) {
                        grad[idx] += hidden;
                    }
                }
            }
        }
    }
}

M
Megvii Engine Team 已提交
286 287 288 289 290 291
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_LINEAR_backwardmat(
        const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
        int C, int IH, int IW, int OH, int OW, float scalar) {
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
    std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
    for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW,
             map_xy += OH * OW * 2, grad += OH * OW * 2) {
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(floor(index_col));
                int row = static_cast<int>(floor(index_row));
                float v = index_col - col;  // alphaw
                float u = index_row - row;  // alphah
                const float one = 1.f;
                for (int c = 0; c < C; ++c) {
                    float hidden = static_cast<float>(
                            diff[get_offset<format>(h, w, c, OH, OW, C)]);
                    float du = 0.f, dv = 0.f;

                    int a00 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 0, col + 0, c, IH, IW, C);
                    int a01 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 0, col + 1, c, IH, IW, C);
                    int a10 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 1, col + 0, c, IH, IW, C);
                    int a11 = GetSrcData<ctype, format, bordertype>::get_index(
                            row + 1, col + 1, c, IH, IW, C);

318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333
                    dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) *
                          (one - u);
                    dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) *
                          (one - u);
                    dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u;
                    dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u;

                    du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) *
                          (one - v);
                    du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v;
                    du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) *
                          (one - v);
                    du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v;

                    grad[h * OW * 2 + w * 2 + 0] += hidden * dv;
                    grad[h * OW * 2 + w * 2 + 1] += hidden * du;
334 335 336 337 338
                }
            }
        }
    }
}
339

340 341 342 343 344 345 346 347 348 349
template <
        typename ctype, param::Remap::Format format,
        param::Remap::BorderMode bordertype>
void remap_NEAREST_backwardmat(
        const ctype*, const float*, const ctype*, float* grad, int N, int, int, int,
        int OH, int OW, float) {
    std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
    return;
}

350 351
}  // namespace

M
Megvii Engine Team 已提交
352 353 354
void RemapImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
355 356 357 358 359 360 361
    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
    int N, C, IH, IW, OH, OW;
    if (param().format == param::Remap::Format::NCHW) {
        N = src.layout.shape[0];
        C = src.layout.shape[1];
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
362
    } else if (param().format == param::Remap::Format::NHWC) {
363 364 365 366
        N = src.layout.shape[0];
        C = src.layout.shape[3];
        IH = src.layout.shape[1];
        IW = src.layout.shape[2];
367 368 369 370 371
    } else if (param().format == param::Remap::Format::NHWCD4) {
        N = src.layout.shape[0];
        C = src.layout.shape[2];
        IH = src.layout.shape[1];
        IW = src.layout.shape[3];
372 373
    } else {
        megdnn_throw("unsupported format");
374 375 376 377
    }
    OH = map_xy.layout.shape[1];
    OW = map_xy.layout.shape[2];
    switch (src.layout.dtype.enumv()) {
M
Megvii Engine Team 已提交
378 379 380 381 382 383 384 385 386 387 388
#define cb(dt, fmt, border, interpolation)                                            \
    if (param().format == param::Remap::Format::fmt &&                                \
        param().border_type == param::Remap::BorderMode::border &&                    \
        param().imode == param::Remap::InterpolationMode::interpolation) {            \
        using ctype = DTypeTrait<dt>::ctype;                                          \
        MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation<                          \
                                      ctype, param::Remap::Format::fmt,               \
                                      param::Remap::BorderMode::border>(              \
                src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(),     \
                dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, param().scalar))); \
        break;                                                                        \
389 390 391 392 393 394 395 396 397
    }

#define support_dtype(dt)                                                   \
    case DTypeTrait<dt>::enumv: {                                           \
        cb(dt, NCHW, CONSTANT, LINEAR);                                     \
        cb(dt, NCHW, REPLICATE, LINEAR);                                    \
        cb(dt, NCHW, REFLECT, LINEAR);                                      \
        cb(dt, NCHW, REFLECT_101, LINEAR);                                  \
        cb(dt, NCHW, WRAP, LINEAR);                                         \
398 399 400 401 402
        cb(dt, NHWCD4, CONSTANT, LINEAR);                                   \
        cb(dt, NHWCD4, REPLICATE, LINEAR);                                  \
        cb(dt, NHWCD4, REFLECT, LINEAR);                                    \
        cb(dt, NHWCD4, REFLECT_101, LINEAR);                                \
        cb(dt, NHWCD4, WRAP, LINEAR);                                       \
403 404 405 406 407
        cb(dt, NHWC, CONSTANT, LINEAR);                                     \
        cb(dt, NHWC, REPLICATE, LINEAR);                                    \
        cb(dt, NHWC, REFLECT, LINEAR);                                      \
        cb(dt, NHWC, REFLECT_101, LINEAR);                                  \
        cb(dt, NHWC, WRAP, LINEAR);                                         \
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
        cb(dt, NCHW, CONSTANT, NEAREST);                                    \
        cb(dt, NCHW, REPLICATE, NEAREST);                                   \
        cb(dt, NCHW, REFLECT, NEAREST);                                     \
        cb(dt, NCHW, REFLECT_101, NEAREST);                                 \
        cb(dt, NCHW, WRAP, NEAREST);                                        \
        cb(dt, NHWCD4, CONSTANT, NEAREST);                                  \
        cb(dt, NHWCD4, REPLICATE, NEAREST);                                 \
        cb(dt, NHWCD4, REFLECT, NEAREST);                                   \
        cb(dt, NHWCD4, REFLECT_101, NEAREST);                               \
        cb(dt, NHWCD4, WRAP, NEAREST);                                      \
        cb(dt, NHWC, CONSTANT, NEAREST);                                    \
        cb(dt, NHWC, REPLICATE, NEAREST);                                   \
        cb(dt, NHWC, REFLECT, NEAREST);                                     \
        cb(dt, NHWC, REFLECT_101, NEAREST);                                 \
        cb(dt, NHWC, WRAP, NEAREST);                                        \
423 424 425 426 427 428
        megdnn_throw(                                                       \
                "format, border type or imode is incorrect in remap navie " \
                "with dtype = " #dt);                                       \
    }

        support_dtype(dtype::Float32);
M
Megvii Engine Team 已提交
429 430
        DNN_INC_FLOAT16(support_dtype(dtype::Float16));
        DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
431 432 433 434 435 436 437 438 439
        support_dtype(dtype::Int8);
        support_dtype(dtype::Uint8);
#undef cb
#undef support_dtype

        default:
            megdnn_throw("unsupported dtype in remap naive\n");
    }
}
440

M
Megvii Engine Team 已提交
441 442 443
void RemapBackwardDataImpl::exec(
        _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
444
    check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
M
Megvii Engine Team 已提交
445 446 447
    megdnn_assert(
            param().format == param::Remap::Format::NCHW,
            "only support NCHW format for remap backward");
448 449 450 451 452 453 454 455
    int N, C, IH, IW, OH, OW;
    N = grad.layout.shape[0];
    C = grad.layout.shape[1];
    IH = grad.layout.shape[2];
    IW = grad.layout.shape[3];
    OH = map_xy.layout.shape[1];
    OW = map_xy.layout.shape[2];
    switch (diff.layout.dtype.enumv()) {
M
Megvii Engine Team 已提交
456 457 458 459 460 461 462 463 464 465 466
#define cb(dt, fmt, border, interpolation)                                         \
    if (param().format == param::Remap::Format::fmt &&                             \
        param().border_type == param::Remap::BorderMode::border &&                 \
        param().imode == param::Remap::InterpolationMode::interpolation) {         \
        using ctype = DTypeTrait<dt>::ctype;                                       \
        MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwarddata<        \
                                      ctype, param::Remap::Format::fmt,            \
                                      param::Remap::BorderMode::border>(           \
                grad.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \
                diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW)));             \
        break;                                                                     \
467 468 469 470 471 472 473 474 475
    }

#define support_dtype(dt)                                                   \
    case DTypeTrait<dt>::enumv: {                                           \
        cb(dt, NCHW, CONSTANT, LINEAR);                                     \
        cb(dt, NCHW, REPLICATE, LINEAR);                                    \
        cb(dt, NCHW, REFLECT, LINEAR);                                      \
        cb(dt, NCHW, REFLECT_101, LINEAR);                                  \
        cb(dt, NCHW, WRAP, LINEAR);                                         \
476 477 478 479 480
        cb(dt, NCHW, CONSTANT, NEAREST);                                    \
        cb(dt, NCHW, REPLICATE, NEAREST);                                   \
        cb(dt, NCHW, REFLECT, NEAREST);                                     \
        cb(dt, NCHW, REFLECT_101, NEAREST);                                 \
        cb(dt, NCHW, WRAP, NEAREST);                                        \
481 482 483 484 485 486
        megdnn_throw(                                                       \
                "format, border type or imode is incorrect in remap navie " \
                "with dtype = " #dt);                                       \
    }

        support_dtype(dtype::Float32);
M
Megvii Engine Team 已提交
487
        DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
488
        DNN_INC_FLOAT16(support_dtype(dtype::Float16));
489 490 491 492 493 494 495 496
#undef cb
#undef support_dtype

        default:
            megdnn_throw("unsupported dtype in remap backward naive\n");
    }
}

M
Megvii Engine Team 已提交
497 498 499 500 501 502 503
void RemapBackwardMatImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
        _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size);
    megdnn_assert(
            param().format == param::Remap::Format::NCHW,
            "only support NCHW format for remap backward");
504 505 506 507 508 509 510 511
    int N, C, IH, IW, OH, OW;
    N = src.layout.shape[0];
    C = src.layout.shape[1];
    IH = src.layout.shape[2];
    IW = src.layout.shape[3];
    OH = map_xy.layout.shape[1];
    OW = map_xy.layout.shape[2];
    switch (src.layout.dtype.enumv()) {
M
Megvii Engine Team 已提交
512 513 514 515 516 517 518 519 520 521 522 523
#define cb(dt, fmt, border, interpolation)                                             \
    if (param().format == param::Remap::Format::fmt &&                                 \
        param().border_type == param::Remap::BorderMode::border &&                     \
        param().imode == param::Remap::InterpolationMode::interpolation) {             \
        using ctype = DTypeTrait<dt>::ctype;                                           \
        MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwardmat<             \
                                      ctype, param::Remap::Format::fmt,                \
                                      param::Remap::BorderMode::border>(               \
                src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(),      \
                diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \
                IH, IW, OH, OW, param().scalar)));                                     \
        break;                                                                         \
524 525 526 527 528 529 530 531 532
    }

#define support_dtype(dt)                                                   \
    case DTypeTrait<dt>::enumv: {                                           \
        cb(dt, NCHW, CONSTANT, LINEAR);                                     \
        cb(dt, NCHW, REPLICATE, LINEAR);                                    \
        cb(dt, NCHW, REFLECT, LINEAR);                                      \
        cb(dt, NCHW, REFLECT_101, LINEAR);                                  \
        cb(dt, NCHW, WRAP, LINEAR);                                         \
533 534 535 536 537
        cb(dt, NCHW, CONSTANT, NEAREST);                                    \
        cb(dt, NCHW, REPLICATE, NEAREST);                                   \
        cb(dt, NCHW, REFLECT, NEAREST);                                     \
        cb(dt, NCHW, REFLECT_101, NEAREST);                                 \
        cb(dt, NCHW, WRAP, NEAREST);                                        \
538 539 540 541 542 543
        megdnn_throw(                                                       \
                "format, border type or imode is incorrect in remap navie " \
                "with dtype = " #dt);                                       \
    }

        support_dtype(dtype::Float32);
M
Megvii Engine Team 已提交
544
        DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
545
        DNN_INC_FLOAT16(support_dtype(dtype::Float16));
546 547 548 549 550 551 552 553 554
#undef cb
#undef support_dtype

        default:
            megdnn_throw("unsupported dtype in remap backward naive\n");
    }
}

// vim: syntax=cpp.doxygen