/**
 * \file dnn/src/cuda/remap/backward_mat.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include <cuda_runtime.h>
#include "src/common/rounding_converter.cuh"
#include "src/cuda/cv/kernel_common.cuh"
#include "src/cuda/remap/common.h"
#include "src/cuda/utils.cuh"

using namespace megdnn;
using namespace cuda;
using namespace remap;
using namespace rounding;

namespace {

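// get_offset(): linear offset of the element at (height, width, channel)
// inside one image whose extents are h x w with c channels, specialized per
// tensor format (only NCHW is provided in this file).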
template <const uint32_t format>
__device__ inline int get_offset(
        int height, int width, int channel, int h, int w, int c);

template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
        int height, int width, int channel, int h, int w, int c) {
    return channel * h * w + height * w + width;
}

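// GetSrcData resolves a (possibly out-of-range) sampling coordinate to a
// source offset. The generic version remaps the coordinate with the selected
// border mode before computing the offset.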
template <typename ctype, const uint32_t format, ::BorderMode bmode>
struct GetSrcData {
    __device__ static inline int get_index(
            int height, int width, int channel, int h, int w, int c) {
        height = megcv::border_interpolate<bmode>(height, h);
        width = megcv::border_interpolate<bmode>(width, w);
        return get_offset<format>(height, width, channel, h, w, c);
    }
};

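// BORDER_CONSTANT specialization: out-of-range coordinates return -1 so the
// caller can substitute the constant `scalar` value instead.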
template <typename ctype, const uint32_t format>
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
    __device__ static inline int get_index(
            int height, int width, int channel, int h, int w, int c) {
        return (height >= 0 && height < h && width >= 0 && width < w)
                     ? get_offset<format>(height, width, channel, h, w, c)
                     : -1;
    }
};

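// One thread per output pixel (ow, oh); blockIdx.z walks the batch. For every
// channel the kernel accumulates the gradient of the bilinearly sampled output
// with respect to the map_xy coordinates into grad (layout N x OH x OW x 2).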
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general(
        const ctype* src, const float* map_xy, const ctype* diff,
        float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) {
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    src += blockIdx.z * C * IH * IW;
    diff += blockIdx.z * C * OH * OW;
    map_xy += blockIdx.z * 2 * OH * OW;
    grad += blockIdx.z * 2 * OH * OW;
    RoundingConverter<ctype> round_converter;

    if (ow < OW && oh < OH) {
        float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
        float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
        int col = static_cast<int>(floor(index_col));
        int row = static_cast<int>(floor(index_row));
        float v = index_col - col;  // alphaw
        float u = index_row - row;  // alphah
        const float one = 1.f;
        for (int c = 0; c < C; ++c) {
            float hidden =
                    static_cast<float>(diff[get_offset<format>(oh, ow, c, OH, OW, C)]);
            float du = 0.f, dv = 0.f;

            int a00 = GetSrcData<ctype, format, bmode>::get_index(
                    row + 0, col + 0, c, IH, IW, C);
            int a01 = GetSrcData<ctype, format, bmode>::get_index(
                    row + 0, col + 1, c, IH, IW, C);
            int a10 = GetSrcData<ctype, format, bmode>::get_index(
                    row + 1, col + 0, c, IH, IW, C);
            int a11 = GetSrcData<ctype, format, bmode>::get_index(
                    row + 1, col + 1, c, IH, IW, C);

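            // Partial derivatives of bilinear interpolation w.r.t. the
            // fractional coordinates: dv = d(out)/d(index_col),
            // du = d(out)/d(index_row). Taps outside the image (index -1 under
            // BORDER_CONSTANT) contribute the constant scalar.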
            dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
            dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
            dv -= ((a10 != -1) ? src[a10] : scalar) * u;
            dv += ((a11 != -1) ? src[a11] : scalar) * u;

            du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
            du -= ((a01 != -1) ? src[a01] : scalar) * v;
            du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
            du += ((a11 != -1) ? src[a11] : scalar) * v;

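            // Chain rule: scale by the incoming gradient and accumulate over
            // channels; slot 0 holds d/d(col), slot 1 holds d/d(row).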
            grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv);
            grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du);
        }
    }
}

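// Launch kern_general in batches: gridDim.z is capped at 65535, so larger
// batches are processed in chunks, advancing all pointers between launches.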
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void dispatch_backwardmat(
        const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
        int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
    const int BX = 32, BY = 16;
    const int max_batch_size = 65535;
    while (N) {
        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
        dim3 threads(BX, BY);
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);

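        // The kernel accumulates into grad, so zero it first.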
        cuda_check(cudaMemsetAsync(
                grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream));
        kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
                src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);

        N -= curr_batch_size;
        src += curr_batch_size * C * IH * IW;
        diff += curr_batch_size * C * OH * OW;
        map_xy += curr_batch_size * 2 * OH * OW;
        grad += curr_batch_size * 2 * OH * OW;
    }
}

}  // anonymous namespace

namespace megdnn {
namespace cuda {
namespace remap {

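// Public entry point: forwards to the batched dispatcher above, then checks
// the launch via after_kernel_launch().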
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void backwardmat_proxy(
        const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
        int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
    dispatch_backwardmat<ctype, format, bmode>(
            src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream);
    after_kernel_launch();
}

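// A hypothetical call for a contiguous NCHW float tensor might look like
//     backwardmat_proxy<float, param_enumv::Remap::Format::NCHW,
//                       ::BorderMode::BORDER_CONSTANT>(
//             src, map_xy, diff, grad, N, C, IH, IW, OH, OW, 0.f, stream);
// (shapes and the scalar value are illustrative only). The macros below
// instantiate the template for every supported dtype / border-mode pair.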
#define INST(ctype, format, bmode)                                                     \
    template void                                                                      \
    backwardmat_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
            const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \
            int, float, cudaStream_t);

#define FOR_FORMAT_BMODE(ctype)           \
    INST(ctype, NCHW, BORDER_CONSTANT)    \
    INST(ctype, NCHW, BORDER_REPLICATE)   \
    INST(ctype, NCHW, BORDER_REFLECT)     \
    INST(ctype, NCHW, BORDER_REFLECT_101) \
    INST(ctype, NCHW, BORDER_WRAP)

FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))

#undef FOR_FORMAT_BMODE
#undef INST

}  // namespace remap
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen