#include "src/common/rounding_converter.cuh"
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"

#include "src/cuda/cv/kernel_common.cuh"
#include "src/cuda/utils.cuh"

using megdnn::megcv::saturate;
using megdnn::resize::interpolate_cubic;

namespace megdnn {
namespace cuda {
namespace resize {
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_nhwc_kernel(
        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
        float scale_h, float scale_w) {
    OutputConverter output_converter;
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        get_origin_coord(scale_h, IH, oh, alphah, ih0);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0);

        int ih1 = ih0 + 1;
        int iw1 = iw0 + 1;

        float nalphaw = 1.0f - alphaw;
        float nalphah = 1.0f - alphah;
        for (int c = 0; c < C; ++c) {
            atomic_add(
                    dst + (ih0 * IW + iw0) * C + c,
                    output_converter(
                            hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah));
            atomic_add(
                    dst + (ih0 * IW + iw1) * C + c,
                    output_converter(
                            hidden[(oh * OW + ow) * C + c] * alphaw * nalphah));
            atomic_add(
                    dst + (ih1 * IW + iw0) * C + c,
                    output_converter(
                            hidden[(oh * OW + ow) * C + c] * nalphaw * alphah));
            atomic_add(
                    dst + (ih1 * IW + iw1) * C + c,
                    output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah));
        }
    }
}

template <typename ctype, typename OutputConverter>
M
Megvii Engine Team 已提交
57
__global__ void resize_bwd_linear_kernel(
58
        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
M
Megvii Engine Team 已提交
59
        float scale_h, float scale_w) {
60
    OutputConverter output_converter;
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        get_origin_coord(scale_h, IH, oh, alphah, ih0);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0);

        int ih1 = ih0 + 1;
        int iw1 = iw0 + 1;

        float nalphaw = 1.0f - alphaw;
        float nalphah = 1.0f - alphah;
        for (int c = 0; c < C; ++c) {
78 79 80 81 82 83 84 85 86 87 88 89
            atomic_add(
                    dst + ih0 * IW + iw0,
                    output_converter(hidden[oh * OW + ow] * nalphaw * nalphah));
            atomic_add(
                    dst + ih0 * IW + iw1,
                    output_converter(hidden[oh * OW + ow] * alphaw * nalphah));
            atomic_add(
                    dst + ih1 * IW + iw0,
                    output_converter(hidden[oh * OW + ow] * nalphaw * alphah));
            atomic_add(
                    dst + ih1 * IW + iw1,
                    output_converter(hidden[oh * OW + ow] * alphaw * alphah));
90 91 92 93 94 95
            hidden += OH * OW;
            dst += IH * IW;
        }
    }
}

template <typename ctype, typename OutputConverter>
M
Megvii Engine Team 已提交
97
__global__ void resize_bwd_nearest_kernel(
98
        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
M
Megvii Engine Team 已提交
99
        float scale_h, float scale_w) {
100
    OutputConverter output_converter;
101 102 103 104 105 106 107 108 109 110
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        int ih = get_nearest_src(scale_h, IH, oh);
        int iw = get_nearest_src(scale_w, IW, ow);

        for (int c = 0; c < C; ++c) {
111
            atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow]));
112 113 114 115 116
            hidden += OH * OW;
            dst += IH * IW;
        }
    }
}
template <typename ctype, typename OutputConverter>
M
Megvii Engine Team 已提交
119
__global__ void resize_bwd_cubic_kernel(
120
        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
M
Megvii Engine Team 已提交
121
        float scale_h, float scale_w) {
122
    OutputConverter output_converter;
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        float alphah, alphaw;
        int ih0, iw0;
        get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
        get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
        ih0--;
        iw0--;
        float h_coeff[4], w_coeff[4];
        interpolate_cubic(alphah, h_coeff);
        interpolate_cubic(alphaw, w_coeff);
        for (int c = 0; c < C; ++c) {
            constexpr int ksize = 4;
            for (int kh = 0; kh < ksize; kh++) {
                int ih = saturate(ih0 + kh, 0, IH - 1);
                for (int kw = 0; kw < ksize; kw++) {
                    int iw = saturate(iw0 + kw, 0, IW - 1);
144
                    atomic_add(
M
Megvii Engine Team 已提交
145
                            dst + ih * IW + iw,
146 147
                            output_converter(
                                    hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]));
148 149 150 151 152 153 154 155 156
                }
            }

            hidden += OH * OW;
            dst += IH * IW;
        }
    }
}

template <typename ctype>
M
Megvii Engine Team 已提交
158
void backward_data_proxy(
159 160
        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
        int C, int IH, int IW, int OH, int OW, cudaStream_t stream) {
161 162 163 164
    const int BY = 16, BX = 32;
    {
        dim3 threads(BX, BY);
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);
165
        cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream));
166 167
        float scale_h = static_cast<float>(OH) / IH;
        float scale_w = static_cast<float>(OW) / IW;
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
        if (is_nhwc) {
            resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>>
                    <<<blocks, threads, 0, stream>>>(
                            diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
        } else {
            switch (imode) {
                case InterpolationMode::INTER_LINEAR: {
                    resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
                    break;
                }
                case InterpolationMode::INTER_NEAREST: {
                    resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
                    break;
                }
                case InterpolationMode::INTER_CUBIC: {
                    resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>>
                            <<<blocks, threads, 0, stream>>>(
                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
                    break;
                }
                default: {
                    megdnn_throw("unsupported interpolation mode");
                    break;
                }
196
            }
197
        }
198 199 200 201
    }
    after_kernel_launch();
}

#define INST(ctype)                                                                 \
    template void backward_data_proxy(                                              \
            bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
            int, cudaStream_t);
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST

}  // namespace resize
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen