/**
 * \file dnn/src/cuda/remap/backward_mat.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/cuda/remap/common.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

M
Megvii Engine Team 已提交
20 21 22 23 24
void RemapBackwardMatImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
        _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size);
    megdnn_assert(
25 26 27
            (param().imode == param::Remap::InterpolationMode::NEAREST) ||
                    (param().imode == param::Remap::InterpolationMode::LINEAR),
            "only support NEAREST and LINEAR interpolationMode");
M
Megvii Engine Team 已提交
28 29 30
    megdnn_assert(
            param().format == param::Remap::Format::NCHW,
            "only support NCHW format for remap backward");
31 32 33 34 35 36 37 38 39
    auto stream = cuda_stream(this->handle());
    int N, C, IH, IW, OH, OW;
    N = src.layout.shape[0];
    C = src.layout.shape[1];
    IH = src.layout.shape[2];
    IW = src.layout.shape[3];
    OH = map_xy.layout.shape[1];
    OW = map_xy.layout.shape[2];

40
#define cb(dt, _format, bmode, inter_mode)                                             \
M
Megvii Engine Team 已提交
41
    if (param().format == param::Remap::Format::_format &&                             \
42 43
        param().border_type == param::Remap::BorderMode::bmode &&                      \
        param().imode == param::Remap::InterpolationMode::inter_mode) {                \
M
Megvii Engine Team 已提交
44 45 46
        using ctype = DTypeTrait<dt>::ctype;                                           \
        remap::backwardmat_proxy<                                                      \
                ctype, param_enumv::Remap::Format::_format,                            \
47 48
                ::BorderMode::BORDER_##bmode,                                          \
                ::InterpolationMode::INTER_##inter_mode>(                              \
M
Megvii Engine Team 已提交
49 50 51 52
                src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(),      \
                diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \
                IH, IW, OH, OW, param().scalar, stream);                               \
        break;                                                                         \
53 54 55 56
    }

#define support_dtype(dt)                                      \
    case DTypeTrait<dt>::enumv: {                              \
57 58 59 60 61 62 63 64 65 66
        cb(dt, NCHW, CONSTANT, NEAREST);                       \
        cb(dt, NCHW, REPLICATE, NEAREST);                      \
        cb(dt, NCHW, REFLECT, NEAREST);                        \
        cb(dt, NCHW, REFLECT_101, NEAREST);                    \
        cb(dt, NCHW, WRAP, NEAREST);                           \
        cb(dt, NCHW, CONSTANT, LINEAR);                        \
        cb(dt, NCHW, REPLICATE, LINEAR);                       \
        cb(dt, NCHW, REFLECT, LINEAR);                         \
        cb(dt, NCHW, REFLECT_101, LINEAR);                     \
        cb(dt, NCHW, WRAP, LINEAR);                            \
67 68 69 70 71 72
        megdnn_throw("unsupported border type in remap cuda"); \
    }

    switch (src.layout.dtype.enumv()) {
        support_dtype(dtype::Float32);
        support_dtype(dtype::BFloat16);
73
        support_dtype(dtype::Float16);
74 75 76 77 78 79 80 81 82
        default:
            megdnn_throw("unsupported dtype in remap backward cuda\n");
    }

#undef support_dtype
#undef cb
}

// vim: syntax=cpp.doxygen