/**
 * \file dnn/src/cuda/resize/backward.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"

#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
namespace resize {

/**
 * Backward pass of bilinear resize: scatters each incoming output-gradient
 * value onto the four input pixels it was interpolated from.
 *
 * Launch layout: 2D thread blocks tiling the OH x OW output plane, one grid
 * z-slice per batch item (blockIdx.z). Each thread owns one output pixel
 * (oh, ow) and loops over all C channels. Both tensors are indexed as
 * contiguous NCHW:
 *   hidden: [batch, C, OH, OW]  incoming gradient
 *   dst:    [batch, C, IH, IW]  input gradient, accumulated in place
 * Several output pixels can hit the same input pixel, so accumulation uses
 * atomicAdd and dst must be zero-initialised by the caller.
 *
 * scale_h / scale_w are the output/input size ratios (OH / IH, OW / IW) passed
 * straight to get_origin_coord. `N` is not read; the batch extent is implied
 * by gridDim.z.
 */
__global__ void resize_bwd_linear_kernel(const float* hidden, float* dst, int N,
                                         int C, int IH, int IW, int OH, int OW,
                                         float scale_h, float scale_w) {
    const int n = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // The grid is rounded up to whole blocks; tail threads have no pixel.
    if (ow >= OW || oh >= OH) {
        return;
    }

    // Batch offsets are computed in 64 bits: n * C * H * W can exceed INT_MAX
    // for large batches even when a single H * W plane fits in an int.
    const size_t in_plane = static_cast<size_t>(IH) * IW;
    const size_t out_plane = static_cast<size_t>(OH) * OW;
    hidden += static_cast<size_t>(n) * C * out_plane;
    dst += static_cast<size_t>(n) * C * in_plane;

    float alphah, alphaw;
    int ih0, iw0;
    get_origin_coord(scale_h, IH, oh, alphah, ih0);
    get_origin_coord(scale_w, IW, ow, alphaw, iw0);

    // Keep all four taps inside the input plane. For ih0 / iw0 in
    // [0, size - 2] these clamps change nothing; they only matter at
    // degenerate extents (e.g. IH or IW == 1), where ih0 + 1 / iw0 + 1 would
    // otherwise address memory outside the current channel plane.
    ih0 = min(max(ih0, 0), IH - 1);
    iw0 = min(max(iw0, 0), IW - 1);
    const int ih1 = min(ih0 + 1, IH - 1);
    const int iw1 = min(iw0 + 1, IW - 1);

    const float nalphaw = 1.0f - alphaw;
    const float nalphah = 1.0f - alphah;

    // Offsets within one channel plane; identical for every channel.
    const int src_off = oh * OW + ow;
    const int off00 = ih0 * IW + iw0;
    const int off01 = ih0 * IW + iw1;
    const int off10 = ih1 * IW + iw0;
    const int off11 = ih1 * IW + iw1;

    for (int c = 0; c < C; ++c) {
        // Load the gradient once per channel instead of once per tap.
        const float g = hidden[src_off];
        atomicAdd(dst + off00, g * nalphaw * nalphah);
        atomicAdd(dst + off01, g * alphaw * nalphah);
        atomicAdd(dst + off10, g * nalphaw * alphah);
        atomicAdd(dst + off11, g * alphaw * alphah);
        hidden += out_plane;
        dst += in_plane;
    }
}

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
__global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
                                          int N, int C, int IH, int IW, int OH,
                                          int OW, float scale_h,
                                          float scale_w) {
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow < OW && oh < OH) {
        int ih = get_nearest_src(scale_h, IH, oh);
        int iw = get_nearest_src(scale_w, IW, ow);

        for (int c = 0; c < C; ++c) {
            atomicAdd(dst + ih * IW + iw,
                      hidden[oh * OW + ow]);
            hidden += OH * OW;
            dst += IH * IW;
        }
    }
}
/**
 * Host entry point for the resize backward pass, enqueued on `stream`.
 *
 * Zero-fills `grad` (N * C * IH * IW floats, NCHW) and then accumulates
 * `diff` (N * C * OH * OW floats, NCHW) into it with the kernel selected by
 * `imode`. Only INTER_LINEAR and INTER_NEAREST launch a kernel. For any other
 * mode, grad is left zero-filled. NOTE(review): callers presumably reject
 * other modes before reaching here; confirm.
 *
 * Launch layout: 32x16 thread blocks tiling the OH x OW output plane, one grid
 * z-slice per batch item. gridDim.z is limited to 65535, so larger batches are
 * split into several launches on the same stream.
 */
void backward_data_proxy(InterpolationMode imode, const float* diff,
                         float* grad, int N, int C, int IH, int IW, int OH,
                         int OW, cudaStream_t stream) {
    // The kernels accumulate with atomicAdd, so grad must start at zero.
    cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW,
                               stream));

    // Nothing to scatter. Returning here also avoids a zero-sized grid (an
    // invalid launch configuration) and the division by IH / IW below.
    if (N <= 0 || C <= 0 || IH <= 0 || IW <= 0 || OH <= 0 || OW <= 0) {
        return;
    }

    const int BY = 16, BX = 32;
    // Hardware limit on gridDim.z; the batch is tiled along z.
    const int kMaxGridZ = 65535;
    const dim3 threads(BX, BY);
    const float scale_h = static_cast<float>(OH) / IH;
    const float scale_w = static_cast<float>(OW) / IW;
    // Per-batch-item strides, in elements, computed in 64 bits.
    const size_t diff_batch = static_cast<size_t>(C) * OH * OW;
    const size_t grad_batch = static_cast<size_t>(C) * IH * IW;

    for (int n0 = 0; n0 < N; n0 += kMaxGridZ) {
        const int nb = (N - n0 < kMaxGridZ) ? (N - n0) : kMaxGridZ;
        const float* diff_chunk = diff + n0 * diff_batch;
        float* grad_chunk = grad + n0 * grad_batch;
        const dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, nb);

        if (imode == InterpolationMode::INTER_LINEAR) {
            resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
                    diff_chunk, grad_chunk, nb, C, IH, IW, OH, OW, scale_h,
                    scale_w);
        } else if (imode == InterpolationMode::INTER_NEAREST) {
            resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
                    diff_chunk, grad_chunk, nb, C, IH, IW, OH, OW, scale_h,
                    scale_w);
        }
        after_kernel_launch();
    }
}

}  // namespace resize
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen