forward.cu 12.8 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/resize/forward.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
11 12
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
13 14
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"
15
#include "src/cuda/resize/resize_cv.cuh"
16

17 18
#include "src/cuda/cv/kernel_common.cuh"
#include "src/common/resize.cuh"
19 20 21

using namespace megdnn;
using namespace cuda;
22 23 24
using namespace megdnn::cuda::resize;
using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

namespace {

/*!
 * \brief source accessor over a plain device pointer.
 *
 * Consecutive images are assumed to be \p im_size elements apart; the
 * dispatchers advance \c ptr with move_batch() after each launched chunk.
 */
template <typename ctype>
struct DirectSrcVisitor {
    //! base pointer of the images not yet processed
    const ctype* ptr;

    //! first element of image \p batch relative to the current base pointer
    __device__ __forceinline__ const ctype* get(int batch, int im_size) {
        const int offset = batch * im_size;
        return ptr + offset;
    }

    //! skip \p batch images of \p im_size elements each (host side)
    void move_batch(size_t batch, size_t im_size) {
        const size_t offset = batch * im_size;
        ptr = ptr + offset;
    }
};

/*!
 * \brief bilinear resize kernel for (possibly strided) NCHW input.
 *
 * Launch layout (as set up by dispatch_with_visitor): one thread per output
 * pixel, x -> ow, y -> oh, blockIdx.z -> image index within the current
 * batch chunk. Each thread loops over all C channels.
 *
 * The input is addressed through element strides S_IN (image), S_IC
 * (channel), S_IH (row) and S_IW (column). The output is written densely,
 * C * OH * OW elements per image.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_linear(SrcVisitor src, ctype* __restrict dst,
                                    int C, int IH, int IW, int OH, int OW,
                                    int S_IN, int S_IC, int S_IH, int S_IW,
                                    float scale_h, float scale_w) {
    OutputConverter output_converter;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;

    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        // ih0/iw0 is the first of the two neighbouring source rows/columns;
        // alpha is the weight given to the second one (ih0 + 1 / iw0 + 1).
        get_origin_coord(scale_h, IH, oh, alphah, ih0);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0);

        // NOTE(review): not clamped here; presumably get_origin_coord keeps
        // ih0 + 1 < IH and iw0 + 1 < IW -- confirm for IH == 1 / IW == 1.
        int ih1 = ih0 + 1;
        int iw1 = iw0 + 1;

        for (int c = 0; c < C; ++c) {
            dst[oh * OW + ow] = output_converter(
                    sptr[ih0 * S_IH + iw0 * S_IW] * (1.0f - alphaw) *
                            (1.0f - alphah) +
                    sptr[ih0 * S_IH + iw1 * S_IW] * alphaw * (1.0f - alphah) +
                    sptr[ih1 * S_IH + iw0 * S_IW] * (1.0f - alphaw) * alphah +
                    sptr[ih1 * S_IH + iw1 * S_IW] * alphaw * alphah);

            // next input channel (strided) / next output plane (dense)
            sptr += S_IC;
            dst += OH * OW;
        }
    }
}

/*!
 * \brief nearest-neighbour resize kernel for (possibly strided) NCHW input.
 *
 * Same launch layout and addressing as kern_general_linear:
 * - one thread per output pixel, with blockIdx.z as the image index in the
 *   current chunk;
 * - strided input (S_IN / S_IC / S_IH / S_IW);
 * - dense output.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
                                     int C, int IH, int IW, int OH, int OW,
                                     int S_IN, int S_IC, int S_IH, int S_IW,
                                     float scale_h, float scale_w) {
    OutputConverter output_converter;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;

    if (ow < OW && oh < OH) {
        // source pixel sampled for this output pixel (same for all channels)
        int ih = get_nearest_src(scale_h, IH, oh);
        int iw = get_nearest_src(scale_w, IW, ow);

        for (int c = 0; c < C; ++c) {
            dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]);

            // next input channel (strided) / next output plane (dense)
            sptr += S_IC;
            dst += OH * OW;
        }
    }
}

97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
/*!
 * \brief bicubic resize kernel for (possibly strided) NCHW input.
 *
 * One thread per output pixel (x -> ow, y -> oh, blockIdx.z -> image in the
 * current chunk). For every channel the thread accumulates a 4x4 source
 * window weighted by interpolate_cubic() coefficients.
 *
 * Window rows and columns are clamped into [0, IH - 1] and [0, IW - 1]
 * with saturate(). The output is dense, C * OH * OW elements per image.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C,
                                   int IH, int IW, int OH, int OW, int S_IN,
                                   int S_IC, int S_IH, int S_IW, float scale_h,
                                   float scale_w) {
    constexpr int ksize = 4;
    OutputConverter output_converter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;

    if (ow >= OW || oh >= OH)
        return;

    float alphah, alphaw;
    int ih0, iw0;
    get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
    get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
    // the 4-tap window starts one pixel before the origin
    --ih0;
    --iw0;

    float h_coeff[ksize], w_coeff[ksize];
    interpolate_cubic(alphah, h_coeff);
    interpolate_cubic(alphaw, w_coeff);

    for (int c = 0; c < C; ++c, sptr += S_IC, dst += OH * OW) {
        float acc = 0;
        for (int dy = 0; dy < ksize; ++dy) {
            const int ih = saturate(ih0 + dy, 0, IH - 1);
            const ctype* row = sptr + ih * S_IH;
            for (int dx = 0; dx < ksize; ++dx) {
                const int iw = saturate(iw0 + dx, 0, IW - 1);
                acc += row[iw * S_IW] * h_coeff[dy] * w_coeff[dx];
            }
        }
        dst[oh * OW + ow] = output_converter(acc);
    }
}
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
/*!
 * \brief bilinear resize kernel for dense NHWC input and output.
 *
 * One thread per output pixel (x -> ow, y -> oh, blockIdx.z -> image in the
 * current chunk); the thread writes all C interleaved channels of its pixel.
 * Each image occupies C * IH * IW (input) / C * OH * OW (output) elements.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
                                  int IH, int IW, int OH, int OW, float scale_h,
                                  float scale_w) {
    OutputConverter output_converter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
    dst += blockIdx.z * C * OH * OW;
    if (ow >= OW || oh >= OH)
        return;

    float alphah, alphaw;
    int ih0, iw0;
    get_origin_coord(scale_h, IH, oh, alphah, ih0);
    get_origin_coord(scale_w, IW, ow, alphaw, iw0);
    const int ih1 = ih0 + 1;
    const int iw1 = iw0 + 1;

    // element offsets (channel 0) of the four neighbouring source pixels
    const int p00 = (ih0 * IW + iw0) * C;
    const int p01 = (ih0 * IW + iw1) * C;
    const int p10 = (ih1 * IW + iw0) * C;
    const int p11 = (ih1 * IW + iw1) * C;
    ctype* out = dst + (oh * OW + ow) * C;

    for (int c = 0; c < C; ++c) {
        out[c] = output_converter(
                sptr[p00 + c] * (1.0f - alphaw) * (1.0f - alphah) +
                sptr[p01 + c] * alphaw * (1.0f - alphah) +
                sptr[p10 + c] * (1.0f - alphaw) * alphah +
                sptr[p11 + c] * alphaw * alphah);
    }
}

/*!
 * \brief launch the resize kernel selected by layout / interpolation mode.
 *
 * The batch is processed in chunks of at most 65535 images because the image
 * index is mapped to blockIdx.z (gridDim.z limit). After each chunk the
 * source visitor and dst are advanced past the processed images.
 *
 * \param is_nhwc  true: dense NHWC input/output; the S_* strides and
 *                 \p imode are not used (the NHWC kernel is bilinear).
 *                 false: NCHW input addressed via element strides
 *                 S_IN / S_IC / S_IH / S_IW, dense NCHW output.
 * \param imode    interpolation mode for the NCHW path; other modes are
 *                 reported through megdnn_throw.
 */
template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
                           SrcVisitor src, ctype* dst, int N, int C, int IH,
                           int IW, int OH, int OW, int S_IN, int S_IC, int S_IH,
                           int S_IW, cudaStream_t stream) {
    const int BY = 16, BX = 32;
    const int max_batch_size = 65535;

    // launch parameters that do not depend on the chunk
    const dim3 threads(BX, BY);
    const float scale_h = static_cast<float>(OH) / IH;
    const float scale_w = static_cast<float>(OW) / IW;

    // Distance between consecutive source images. It must match what each
    // kernel passes to src.get():
    // - the NHWC kernel assumes dense C * IH * IW images;
    // - the NCHW kernels use the caller-supplied batch stride S_IN.
    // Advancing by C * IH * IW on the NCHW path would mis-address every chunk
    // after the first when the input is not contiguous.
    const int src_batch_stride = is_nhwc ? C * IH * IW : S_IN;

    // N > 0 (not N != 0): a negative N would otherwise turn into an enormous
    // size_t chunk size and an invalid launch configuration.
    while (N > 0) {
        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);

        if (is_nhwc) {
            kern_general_nhwc<ctype, SrcVisitor,
                              rounding::RoundingConverter<ctype>>
                    <<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
                                                     OW, scale_h, scale_w);
        } else {
            switch (imode) {
                case InterpolationMode::INTER_LINEAR:
                    kern_general_linear<ctype, SrcVisitor,
                                        rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
                                    S_IH, S_IW, scale_h, scale_w);
                    break;
                case InterpolationMode::INTER_NEAREST:
                    kern_general_nearest<ctype, SrcVisitor,
                                         rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
                                    S_IH, S_IW, scale_h, scale_w);
                    break;
                case InterpolationMode::INTER_CUBIC:
                    kern_general_cubic<ctype, SrcVisitor,
                                       rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
                                    S_IH, S_IW, scale_h, scale_w);
                    break;
                default:
                    megdnn_throw("unsupported interpolation mode");
                    break;
            }
        }
        N -= curr_batch_size;
        src.move_batch(curr_batch_size, src_batch_stride);
        // output is always dense: C * OH * OW elements per image
        dst += curr_batch_size * C * OH * OW;
    }
}

/*!
 * \brief bilinear resize kernel for dense NCHW4 tensors.
 *
 * Addressing used below: each image is C / 4 planes of IH * IW (input) or
 * OH * OW (output) pixels, and every pixel holds 4 consecutive channel
 * values. This is why offsets are shifted by 2 and the inner loop runs
 * over c1 < 4.
 *
 * Launch layout: one thread per output pixel, blockIdx.z -> image in the
 * current chunk.
 *
 * NOTE(review): only 4 * (C >> 2) channels are processed; presumably callers
 * guarantee C % 4 == 0 -- confirm.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C,
                                   int IH, int IW, int OH, int OW,
                                   float scale_h, float scale_w) {
    OutputConverter output_converter;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
    dst += blockIdx.z * C * OH * OW;

    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        get_origin_coord(scale_h, IH, oh, alphah, ih0);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0);

        int ih1 = ih0 + 1;
        int iw1 = iw0 + 1;

        // element offsets of the 4-channel pixel groups within one plane
        int o_coor = (oh * OW + ow) << 2;
        int i_coor00 = (ih0 * IW + iw0) << 2;
        int i_coor01 = (ih0 * IW + iw1) << 2;
        int i_coor10 = (ih1 * IW + iw0) << 2;
        int i_coor11 = (ih1 * IW + iw1) << 2;
        for (int c0 = 0, nr_chan = C >> 2; c0 < nr_chan; ++c0) {
#pragma unroll
            for (int c1 = 0; c1 < 4; ++c1) {
                dst[o_coor + c1] = output_converter(
                        sptr[i_coor00 + c1] * (1.0f - alphaw) *
                                (1.0f - alphah) +
                        sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) +
                        sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah +
                        sptr[i_coor11 + c1] * alphaw * alphah);
            }
            // advance to the next group of 4 channels
            dst += OH * OW * 4;
            sptr += IH * IW * 4;
        }
    }
}

/*!
 * \brief launch kern_general_nchw4 over the batch.
 *
 * The batch is processed in chunks of at most 65535 images because the image
 * index is mapped to blockIdx.z (gridDim.z limit). After each chunk, src
 * advances by C * IH * IW and dst by C * OH * OW elements per processed
 * image (both tensors are treated as dense).
 */
template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor_nchw4(SrcVisitor src, ctype* dst, int N, int C,
                                 int IH, int IW, int OH, int OW,
                                 cudaStream_t stream) {
    const int BY = 16, BX = 32;
    const int max_batch_size = 65535;

    // launch parameters that do not depend on the chunk
    const dim3 threads(BX, BY);
    const float scale_h = static_cast<float>(OH) / IH;
    const float scale_w = static_cast<float>(OW) / IW;

    // N > 0 (not N != 0): a negative N would otherwise turn into an enormous
    // size_t chunk size and an invalid launch configuration.
    while (N > 0) {
        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);

        kern_general_nchw4<ctype, SrcVisitor,
                           rounding::RoundingConverter<ctype>>
                <<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, OW,
                                                 scale_h, scale_w);
        N -= curr_batch_size;
        src.move_batch(curr_batch_size, C * IH * IW);
        dst += curr_batch_size * C * OH * OW;
    }
}

}  // anonymous namespace

namespace megdnn {
namespace cuda {
namespace resize {

/*!
 * \brief resize entry point for NCHW (strided) or NHWC tensors given as plain
 * device pointers.
 *
 * Wraps \p src in a DirectSrcVisitor and forwards all arguments to
 * dispatch_with_visitor. Launches are enqueued on \p stream. The call then
 * invokes after_kernel_launch(), a project helper that presumably checks the
 * launch status -- confirm.
 */
template <typename ctype>
void forward_proxy(bool is_nhwc, InterpolationMode imode, const ctype* src,
                   ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
                   int S_IN, int S_IC, int S_IH, int S_IW,
                   cudaStream_t stream) {
    DirectSrcVisitor<ctype> visitor;
    visitor.ptr = src;
    dispatch_with_visitor(is_nhwc, imode, visitor, dst, N, C, IH, IW, OH, OW,
                          S_IN, S_IC, S_IH, S_IW, stream);
    after_kernel_launch();
}

/*!
 * \brief bilinear resize entry point for dense NCHW4 tensors given as plain
 * device pointers.
 *
 * Launches are enqueued on \p stream. The call then invokes
 * after_kernel_launch() (project helper; presumably checks the launch
 * status -- confirm).
 */
template <typename ctype>
void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
                         int IW, int OH, int OW, cudaStream_t stream) {
    // DirectSrcVisitor is an aggregate whose only member is the base pointer
    DirectSrcVisitor<ctype> visitor{src};
    dispatch_with_visitor_nchw4(visitor, dst, N, C, IH, IW, OH, OW, stream);
    after_kernel_launch();
}

// Explicit instantiations of forward_proxy for the supported element types.
#define INST(ctype)                                                            \
    template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \
                                int, int, int, int, int, int, int, int, int,   \
                                int, cudaStream_t);
INST(float)
INST(uint8_t)
INST(int8_t)
#undef INST

// Explicit instantiation of forward_proxy_nchw4; only int8_t is provided.
#define INST(ctype)                                                        \
    template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \
                                      int, int, int, cudaStream_t);

INST(int8_t)
#undef INST
}  // namespace resize
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen