backward.cu 5.5 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/resize/backward.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8 9 10 11 12 13 14
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"

#include "src/cuda/utils.cuh"
15 16 17 18
#include "src/cuda/cv/kernel_common.cuh"

using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
19 20 21 22 23

namespace megdnn {
namespace cuda {
namespace resize {

24 25 26
/**
 * \brief Backward-data kernel for bilinear resize.
 *
 * Every thread owns one output pixel (oh, ow) of one batch element and, for
 * each channel, scatters the incoming gradient to the four source pixels that
 * surround the corresponding source coordinate, weighted by the bilinear
 * coefficients. Contributions are accumulated with atomicAdd, so \p dst must
 * be zero-initialised before the launch.
 *
 * Launch layout: blockIdx.z selects the batch element, the x dimension of the
 * grid covers OW and the y dimension covers OH.
 *
 * \param hidden  output gradient, indexed as [n][c][oh][ow] (contiguous)
 * \param dst     input gradient, indexed as [n][c][ih][iw] (contiguous)
 * \param N       batch size (not read by the kernel; the batch index comes
 *                from blockIdx.z)
 * \param C       number of channels
 * \param IH,IW   spatial size of \p dst
 * \param OH,OW   spatial size of \p hidden
 * \param scale_h,scale_w  scale factors forwarded to get_origin_coord
 */
__global__ void resize_bwd_linear_kernel(const float* hidden, float* dst, int N,
                                         int C, int IH, int IW, int OH, int OW,
                                         float scale_h, float scale_w) {
    const int n = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // Threads of the grid tail that fall outside the output plane do nothing.
    if (ow >= OW || oh >= OH)
        return;

    // Plane and batch offsets are computed in 64 bit: N * C * H * W can exceed
    // the range of int for large tensors.
    const long long hidden_plane = static_cast<long long>(OH) * OW;
    const long long dst_plane = static_cast<long long>(IH) * IW;
    hidden += static_cast<long long>(n) * C * hidden_plane;
    dst += static_cast<long long>(n) * C * dst_plane;

    float alphah, alphaw;
    int ih0, iw0;
    get_origin_coord(scale_h, IH, oh, alphah, ih0);
    get_origin_coord(scale_w, IW, ow, alphaw, iw0);

    // Clamp all four taps into the plane, as the cubic kernel does.
    // NOTE(review): get_origin_coord is defined elsewhere; if it already
    // guarantees 0 <= ih0 and ih0 + 1 < IH this is a no-op. The clamp only
    // guards degenerate sizes such as IH == 1 or IW == 1 against
    // out-of-bounds atomics.
    const int ih1 = saturate(ih0 + 1, 0, IH - 1);
    const int iw1 = saturate(iw0 + 1, 0, IW - 1);
    ih0 = saturate(ih0, 0, IH - 1);
    iw0 = saturate(iw0, 0, IW - 1);

    const float nalphaw = 1.0f - alphaw;
    const float nalphah = 1.0f - alphah;

    // Offsets inside one plane are identical for every channel.
    const int out_idx = oh * OW + ow;
    const int idx00 = ih0 * IW + iw0;
    const int idx01 = ih0 * IW + iw1;
    const int idx10 = ih1 * IW + iw0;
    const int idx11 = ih1 * IW + iw1;

    for (int c = 0; c < C; ++c) {
        // Load the gradient once per channel instead of once per tap. The
        // multiplication order is kept so rounding matches a per-tap load.
        const float g = hidden[out_idx];
        atomicAdd(dst + idx00, g * nalphaw * nalphah);
        atomicAdd(dst + idx01, g * alphaw * nalphah);
        atomicAdd(dst + idx10, g * nalphaw * alphah);
        atomicAdd(dst + idx11, g * alphaw * alphah);
        hidden += hidden_plane;
        dst += dst_plane;
    }
}

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
/**
 * \brief Backward-data kernel for nearest-neighbour resize.
 *
 * Every thread owns one output pixel (oh, ow) of one batch element and, for
 * each channel, adds the incoming gradient to the single source pixel chosen
 * by get_nearest_src. Several output pixels may map to the same source pixel,
 * hence the atomicAdd; \p dst must be zero-initialised before the launch.
 *
 * Launch layout: blockIdx.z selects the batch element, the x dimension of the
 * grid covers OW and the y dimension covers OH.
 *
 * \param hidden  output gradient, indexed as [n][c][oh][ow] (contiguous)
 * \param dst     input gradient, indexed as [n][c][ih][iw] (contiguous)
 * \param N       batch size (not read by the kernel; the batch index comes
 *                from blockIdx.z)
 * \param C       number of channels
 * \param IH,IW   spatial size of \p dst
 * \param OH,OW   spatial size of \p hidden
 * \param scale_h,scale_w  scale factors forwarded to get_nearest_src
 */
__global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
                                          int N, int C, int IH, int IW, int OH,
                                          int OW, float scale_h,
                                          float scale_w) {
    const int n = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // Threads of the grid tail that fall outside the output plane do nothing.
    if (ow >= OW || oh >= OH)
        return;

    // Plane and batch offsets are computed in 64 bit: N * C * H * W can exceed
    // the range of int for large tensors.
    const long long hidden_plane = static_cast<long long>(OH) * OW;
    const long long dst_plane = static_cast<long long>(IH) * IW;
    hidden += static_cast<long long>(n) * C * hidden_plane;
    dst += static_cast<long long>(n) * C * dst_plane;

    // NOTE(review): assumes get_nearest_src (defined elsewhere) returns an
    // index inside [0, IH) / [0, IW) -- confirm against its definition.
    const int ih = get_nearest_src(scale_h, IH, oh);
    const int iw = get_nearest_src(scale_w, IW, ow);

    // Offsets inside one plane are identical for every channel.
    const int out_idx = oh * OW + ow;
    const int in_idx = ih * IW + iw;

    for (int c = 0; c < C; ++c) {
        atomicAdd(dst + in_idx, hidden[out_idx]);
        hidden += hidden_plane;
        dst += dst_plane;
    }
}
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114

/**
 * \brief Backward-data kernel for bicubic resize.
 *
 * Every thread owns one output pixel (oh, ow) of one batch element and, for
 * each channel, scatters the incoming gradient over a 4x4 window of source
 * pixels. The window starts one pixel before the origin returned by
 * get_origin_coord; the per-axis weights come from interpolate_cubic. Window
 * coordinates that fall outside the plane are clamped to the border, so
 * border pixels may receive several taps. Contributions are accumulated with
 * atomicAdd; \p dst must be zero-initialised before the launch.
 *
 * Launch layout: blockIdx.z selects the batch element, the x dimension of the
 * grid covers OW and the y dimension covers OH.
 *
 * \param hidden  output gradient, indexed as [n][c][oh][ow] (contiguous)
 * \param dst     input gradient, indexed as [n][c][ih][iw] (contiguous)
 * \param N       batch size (not read by the kernel; the batch index comes
 *                from blockIdx.z)
 * \param C       number of channels
 * \param IH,IW   spatial size of \p dst
 * \param OH,OW   spatial size of \p hidden
 * \param scale_h,scale_w  scale factors forwarded to get_origin_coord
 */
__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N,
                                        int C, int IH, int IW, int OH, int OW,
                                        float scale_h, float scale_w) {
    constexpr int ksize = 4;

    const int n = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // Threads of the grid tail that fall outside the output plane do nothing.
    if (ow >= OW || oh >= OH)
        return;

    // Plane and batch offsets are computed in 64 bit: N * C * H * W can exceed
    // the range of int for large tensors.
    const long long hidden_plane = static_cast<long long>(OH) * OW;
    const long long dst_plane = static_cast<long long>(IH) * IW;
    hidden += static_cast<long long>(n) * C * hidden_plane;
    dst += static_cast<long long>(n) * C * dst_plane;

    float alphah, alphaw;
    int ih0, iw0;
    get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
    get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
    // The 4-tap window begins one pixel before the returned origin.
    ih0--;
    iw0--;

    float h_coeff[ksize], w_coeff[ksize];
    interpolate_cubic(alphah, h_coeff);
    interpolate_cubic(alphaw, w_coeff);

    // The clamped tap positions do not depend on the channel, so resolve them
    // once up front instead of inside the channel loop.
    int row_off[ksize], col[ksize];
    for (int k = 0; k < ksize; k++) {
        row_off[k] = saturate(ih0 + k, 0, IH - 1) * IW;
        col[k] = saturate(iw0 + k, 0, IW - 1);
    }

    const int out_idx = oh * OW + ow;
    for (int c = 0; c < C; ++c) {
        // One global load per channel; the product is still formed as
        // (g * h_coeff) * w_coeff so rounding matches a per-tap evaluation.
        const float g = hidden[out_idx];
        for (int kh = 0; kh < ksize; kh++) {
            const float gh = g * h_coeff[kh];
            for (int kw = 0; kw < ksize; kw++) {
                atomicAdd(dst + row_off[kh] + col[kw], gh * w_coeff[kw]);
            }
        }

        hidden += hidden_plane;
        dst += dst_plane;
    }
}

115 116 117
/**
 * \brief Host entry point of the resize backward-data pass.
 *
 * Zero-fills \p grad on \p stream and then launches the kernel matching
 * \p imode, which accumulates the gradient of the resize input into \p grad.
 * All work is enqueued asynchronously on \p stream.
 *
 * \param imode   interpolation mode; INTER_LINEAR, INTER_NEAREST and
 *                INTER_CUBIC are supported, anything else raises through
 *                megdnn_throw
 * \param diff    gradient w.r.t. the resize output, N x C x OH x OW floats
 * \param grad    gradient w.r.t. the resize input, N x C x IH x IW floats
 * \param N,C     batch size and channel count
 * \param IH,IW   spatial size of \p grad
 * \param OH,OW   spatial size of \p diff
 * \param stream  CUDA stream used for the memset and the kernel launches
 */
void backward_data_proxy(InterpolationMode imode, const float* diff,
                         float* grad, int N, int C, int IH, int IW, int OH,
                         int OW, cudaStream_t stream) {
    // All three kernels share this signature, which lets the launch code
    // below be written once.
    using kernel_t = void (*)(const float*, float*, int, int, int, int, int,
                              int, float, float);
    const int BY = 16, BX = 32;
    // gridDim.z is limited to 65535 blocks; larger batches are split into
    // several launches.
    constexpr int MAX_GRID_Z = 65535;

    // An empty grad has nothing to clear or accumulate into; returning here
    // also avoids dividing by zero in the scale computation below.
    if (N <= 0 || C <= 0 || IH <= 0 || IW <= 0)
        return;

    // The kernels only accumulate (atomicAdd), so grad has to start at zero.
    cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW,
                               stream));

    // No incoming gradient: grad stays zero, and a launch with an empty grid
    // would be an invalid configuration.
    if (OH <= 0 || OW <= 0)
        return;

    kernel_t kernel = nullptr;
    switch (imode) {
        case InterpolationMode::INTER_LINEAR: {
            kernel = resize_bwd_linear_kernel;
            break;
        }
        case InterpolationMode::INTER_NEAREST: {
            kernel = resize_bwd_nearest_kernel;
            break;
        }
        case InterpolationMode::INTER_CUBIC: {
            kernel = resize_bwd_cubic_kernel;
            break;
        }
        default: {
            megdnn_throw("unsupported interpolation mode");
            break;
        }
    }

    const float scale_h = static_cast<float>(OH) / IH;
    const float scale_w = static_cast<float>(OW) / IW;

    // Element count of one batch item, in 64 bit to survive large tensors.
    const long long diff_batch_stride = static_cast<long long>(C) * OH * OW;
    const long long grad_batch_stride = static_cast<long long>(C) * IH * IW;

    const dim3 threads(BX, BY);
    for (int done = 0; done < N;) {
        const int left = N - done;
        const int nb = left < MAX_GRID_Z ? left : MAX_GRID_Z;
        // The kernels take the batch index from blockIdx.z relative to the
        // pointers they are given, so each chunk gets shifted base pointers.
        const dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, nb);
        kernel<<<blocks, threads, 0, stream>>>(
                diff + done * diff_batch_stride,
                grad + done * grad_batch_stride, nb, C, IH, IW, OH, OW,
                scale_h, scale_w);
        after_kernel_launch();
        done += nb;
    }
}

}  // namespace resize
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen