backward_data.cu 6.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
/**
 * \file dnn/src/cuda/warp_perspective/backward_data.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/warp_perspective/common.h"

#include "src/cuda/utils.cuh"
#include "src/cuda/warp_perspective/common.cuh"

namespace megdnn {
namespace cuda {
namespace warp_perspective {

//! Number of partial gradient copies the scatter kernels spread their
//! atomicAdd()s over. A thread picks its copy with `ow & (factor - 1)`, so
//! this must be a power of two; add_up_kernel sums the copies afterwards.
constexpr int factor = 4;

/*!
 * \brief Scatter the output gradient back onto the source image for the
 * non-constant border modes.
 *
 * Launch layout: blockIdx/threadIdx .x cover the OW output columns, .y cover
 * the OH output rows, and blockIdx.z selects the sample within \p hidden /
 * \p mat / \p dst. Each thread handles one output pixel for all C channels.
 *
 * \tparam Getter  functor called as getter(coord, extent); its result is used
 *                 as an index without further checks, so it must return a
 *                 value in [0, extent). Presumably implements the border rule
 *                 named by its type (replicate / reflect / wrap).
 * \tparam factor  number of partial copies in \p dst; must be a power of two
 *                 because the copy is selected with `ow & (factor - 1)`.
 *
 * \param hidden   output gradient, C planes of OH*OW per sample
 * \param mat      3x3 matrix per sample (9 floats) mapping (ow, oh) to
 *                 homogeneous source coordinates
 * \param dst      accumulation target, C*factor planes of IH*IW per sample;
 *                 contributions are atomicAdd()ed, so the caller must zero it
 * \param N        unused by the kernel body
 */
template <typename Getter, int factor>
__global__ void warp_perspective_bwd_data_kernel(const float *hidden,
        const float *mat, float *dst,
        int N, int C, int IH, int IW, int OH, int OW)
{
    Getter getter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    if (ow >= OW || oh >= OH)
        return;

    // 64-bit sample offsets: n * C * plane can exceed INT_MAX on big batches.
    const size_t n = blockIdx.z;
    const size_t in_plane = static_cast<size_t>(IH) * IW;
    const size_t out_plane = static_cast<size_t>(OH) * OW;
    mat += n * 9;

    const float denominator = mat[6]*ow + mat[7]*oh + mat[8];
    const float iw = (mat[0]*ow + mat[1]*oh + mat[2]) / denominator;
    const float ih = (mat[3]*ow + mat[4]*oh + mat[5]) / denominator;
    const float fw = floorf(iw);
    const float fh = floorf(ih);
    const int iw0 = getter(fw, IW);
    const int iw1 = getter(fw + 1, IW);
    const int ih0 = getter(fh, IH);
    const int ih1 = getter(fh + 1, IH);
    const float palpha = ih - fh;
    const float pbeta = iw - fw;
    const float nalpha = 1.0f - palpha;
    const float nbeta = 1.0f - pbeta;

    // In-plane offsets of the four bilinear taps (Getter keeps them in range).
    const int o00 = ih0*IW + iw0;
    const int o01 = ih0*IW + iw1;
    const int o10 = ih1*IW + iw0;
    const int o11 = ih1*IW + iw1;

    hidden += n * C * out_plane + static_cast<size_t>(oh) * OW + ow;
    // Neighbouring columns write to different partial copies, which spreads
    // atomic contention; add_up_kernel later folds the copies together.
    float *d = dst + n * C * factor * in_plane
             + static_cast<size_t>(ow & (factor - 1)) * in_plane;

    for (int c = 0; c < C; ++c) {
        // Load once: the atomics below would otherwise force reloads.
        const float g = *hidden;
        atomicAdd(d + o00, g*nalpha*nbeta);
        atomicAdd(d + o01, g*nalpha*pbeta);
        atomicAdd(d + o10, g*palpha*nbeta);
        atomicAdd(d + o11, g*palpha*pbeta);
        hidden += out_plane;
        d += factor * in_plane;
    }
}

/*!
 * \brief Sum the \p factor partial copies of each plane into \p dst.
 *
 * Launch layout: blockIdx.y selects the plane (an n*C+c index relative to the
 * passed pointers); blockIdx/threadIdx .x cover the IP elements of the plane.
 *
 * \param src  factor consecutive planes of IP floats per blockIdx.y
 * \param dst  one plane of IP floats per blockIdx.y; overwritten
 * \param IP   number of elements in one plane (IH * IW)
 */
template <int factor>
__global__ void add_up_kernel(const float *src, float *dst,
        int IP)
{
    const int ip = blockIdx.x * blockDim.x + threadIdx.x;
    if (ip >= IP)
        return;
    // 64-bit plane offset: nc * IP * factor can exceed INT_MAX.
    const size_t nc = blockIdx.y;
    src += nc * factor * IP + ip;
    dst += nc * IP + ip;

    // Accumulate in a register and store once; the summation order
    // (copy 0, then 1, 2, ...) is unchanged.
    float sum = src[0];
#pragma unroll
    for (int i = 1; i < factor; ++i)
        sum += src[static_cast<size_t>(i) * IP];
    *dst = sum;
}

/*!
 * \brief Scatter the output gradient back onto the source image for
 * BORDER_CONSTANT.
 *
 * Bilinear taps that fall outside the IH x IW image are dropped, and output
 * pixels whose source coordinate is not finite contribute nothing.
 *
 * Launch layout: blockIdx/threadIdx .x cover the OW output columns, .y cover
 * the OH output rows, and blockIdx.z selects the sample within \p hidden /
 * \p mat / \p dst. Each thread handles one output pixel for all C channels.
 *
 * \tparam factor  number of partial copies in \p dst; must be a power of two
 *                 because the copy is selected with `ow & (factor - 1)`.
 *
 * \param hidden   output gradient, C planes of OH*OW per sample
 * \param mat      3x3 matrix per sample (9 floats) mapping (ow, oh) to
 *                 homogeneous source coordinates
 * \param dst      accumulation target, C*factor planes of IH*IW per sample;
 *                 contributions are atomicAdd()ed, so the caller must zero it
 * \param N        unused by the kernel body
 */
template <int factor>
__global__ void warp_perspective_bwd_data_constant_kernel(const float *hidden,
        const float *mat, float *dst,
        int N, int C, int IH, int IW, int OH, int OW)
{
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    if (ow >= OW || oh >= OH)
        return;

    // 64-bit sample offsets: n * C * plane can exceed INT_MAX on big batches.
    const size_t n = blockIdx.z;
    const size_t in_plane = static_cast<size_t>(IH) * IW;
    const size_t out_plane = static_cast<size_t>(OH) * OW;
    mat += n * 9;

    const float denominator = mat[6]*ow + mat[7]*oh + mat[8];
    const float iw = (mat[0]*ow + mat[1]*oh + mat[2]) / denominator;
    const float ih = (mat[3]*ow + mat[4]*oh + mat[5]) / denominator;
    // A degenerate projection (e.g. zero denominator) yields inf/NaN; such
    // pixels get no gradient. Check before the float->int conversions below.
    if (!isfinite(iw) || !isfinite(ih))
        return;

    const float fw = floorf(iw);
    const float fh = floorf(ih);
    const int iw0 = static_cast<int>(fw);
    const int iw1 = static_cast<int>(fw + 1);
    const int ih0 = static_cast<int>(fh);
    const int ih1 = static_cast<int>(fh + 1);
    const bool okw0 = (iw0 >= 0 && iw0 < IW);
    const bool okw1 = (iw1 >= 0 && iw1 < IW);
    const bool okh0 = (ih0 >= 0 && ih0 < IH);
    const bool okh1 = (ih1 >= 0 && ih1 < IH);
    // A tap is written only if both its row and its column are inside the
    // image; if no row or no column is, this pixel contributes nothing.
    if (!((okh0 || okh1) && (okw0 || okw1)))
        return;

    const float palpha = ih - fh;
    const float pbeta = iw - fw;
    const float nalpha = 1.0f - palpha;
    const float nbeta = 1.0f - pbeta;

    hidden += n * C * out_plane + static_cast<size_t>(oh) * OW + ow;
    // Neighbouring columns write to different partial copies, which spreads
    // atomic contention; add_up_kernel later folds the copies together.
    float *d = dst + n * C * factor * in_plane
             + static_cast<size_t>(ow & (factor - 1)) * in_plane;

    for (int c = 0; c < C; ++c) {
        // Load once: the atomics below would otherwise force reloads.
        const float g = *hidden;
        if (okh0 && okw0)
            atomicAdd(d + ih0*IW + iw0, g*nalpha*nbeta);
        if (okh0 && okw1)
            atomicAdd(d + ih0*IW + iw1, g*nalpha*pbeta);
        if (okh1 && okw0)
            atomicAdd(d + ih1*IW + iw0, g*palpha*nbeta);
        if (okh1 && okw1)
            atomicAdd(d + ih1*IW + iw1, g*palpha*pbeta);
        hidden += out_plane;
        d += factor * in_plane;
    }
}

/*!
 * \brief Bytes of scratch needed by backward_data_proxy: \p factor partial
 * copies of the N*C*IH*IW float gradient. Independent of the output size and
 * the border mode.
 */
size_t get_backward_data_workspace_in_bytes(
        int N, int C, int IH, int IW, int /* OH */, int /* OW */,
        BorderMode /* bmode */)
{
    // Widen before multiplying: the element count can exceed INT_MAX, and an
    // int product would overflow before reaching the size_t multiply.
    return static_cast<size_t>(N) * C * IH * IW * factor * sizeof(float);
}

/*!
 * \brief Compute the gradient of warp_perspective w.r.t. its source image.
 *
 * Runs asynchronously on \p stream in two phases:
 *  1. \p workspace (factor partial copies of the N*C*IH*IW gradient) is
 *     zeroed, then every element of \p diff (N*C*OH*OW) is scattered into it
 *     with bilinear weights through atomicAdd.
 *  2. The factor partial copies are summed into \p grad (N*C*IH*IW).
 *
 * Batch sizes and N*C plane counts above the 65535 limit of gridDim.y /
 * gridDim.z are processed in chunks.
 *
 * \param mat       N matrices of 9 floats mapping output pixels to source
 *                  coordinates
 * \param workspace at least get_backward_data_workspace_in_bytes() bytes
 * \param bval      unused: in BORDER_CONSTANT mode out-of-range taps are
 *                  simply dropped, so the border value never reaches grad
 * \param mode      border mode; an unrecognised value leaves grad zero
 */
void backward_data_proxy(const float *mat, const float *diff,
        float *grad, float *workspace,
        int N, int C, int IH, int IW, int OH, int OW, float bval,
        BorderMode mode, cudaStream_t stream)
{
    (void)bval;
    // Hardware limit on gridDim.y and gridDim.z.
    const int MAX_GRID_YZ = 65535;
    const int IP = IH * IW;
    const int NC = N * C;
    if (NC <= 0 || IP <= 0)
        return;  // grad has no elements; a zero-sized grid would fail to launch

    const size_t ws_elems = static_cast<size_t>(factor) * NC * IP;
    cuda_check(cudaMemsetAsync(workspace, 0, sizeof(float) * ws_elems,
                stream));

    // Phase 1: scatter diff into the partial copies held in workspace. An
    // empty output plane contributes nothing (and would be a zero-sized grid).
    if (OH > 0 && OW > 0) {
        const int BY = 16, BX = 32;
        dim3 threads(BX, BY);
        for (int n0 = 0; n0 < N; n0 += MAX_GRID_YZ) {
            const int nb = (N - n0 < MAX_GRID_YZ) ? N - n0 : MAX_GRID_YZ;
            dim3 blocks((OW+BX-1)/BX, (OH+BY-1)/BY, nb);
            // The kernels take the sample index from blockIdx.z, so shift
            // every per-sample buffer to the first sample of this chunk.
            const float *chunk_mat = mat + static_cast<size_t>(n0) * 9;
            const float *chunk_diff =
                    diff + static_cast<size_t>(n0) * C * OH * OW;
            float *chunk_ws =
                    workspace + static_cast<size_t>(n0) * C * factor * IP;
#define DISPATCH(Getter) \
            warp_perspective_bwd_data_kernel<Getter, factor><<<blocks, \
                threads, 0, stream>>>(chunk_diff, chunk_mat, chunk_ws, \
                        nb, C, IH, IW, OH, OW)
            switch (mode) {
                case BORDER_REPLICATE:
                    DISPATCH(ReplicateGetter);
                    break;
                case BORDER_REFLECT:
                    DISPATCH(ReflectGetter);
                    break;
                case BORDER_REFLECT_101:
                    DISPATCH(Reflect101Getter);
                    break;
                case BORDER_WRAP:
                    DISPATCH(WrapGetter);
                    break;
                case BORDER_CONSTANT:
                    warp_perspective_bwd_data_constant_kernel<factor>
                        <<<blocks, threads, 0, stream>>>
                        (chunk_diff, chunk_mat, chunk_ws,
                         nb, C, IH, IW, OH, OW);
                    break;
                default:
                    // Unknown mode: workspace stays zero, so grad ends up 0.
                    break;
            }
#undef DISPATCH
        }
    }

    // Phase 2: grad = sum of the partial copies; one plane per blockIdx.y.
    {
        const int THREADS = 512;
        dim3 threads(THREADS);
        for (int nc0 = 0; nc0 < NC; nc0 += MAX_GRID_YZ) {
            const int cnt = (NC - nc0 < MAX_GRID_YZ) ? NC - nc0 : MAX_GRID_YZ;
            dim3 blocks((IP+THREADS-1)/THREADS, cnt);
            add_up_kernel<factor><<<blocks, threads, 0, stream>>>(
                    workspace + static_cast<size_t>(nc0) * factor * IP,
                    grad + static_cast<size_t>(nc0) * IP, IP);
        }
    }
    after_kernel_launch();
}

} // namespace warp_perspective
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen