/**
 * \file dnn/src/cuda/warp_perspective/backward_mat.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/warp_perspective/common.h"

#include "src/cuda/utils.cuh"
#include "src/cuda/warp_perspective/common.cuh"
#include <cstdio>
#include "src/cuda/cub/util_ptx.cuh"

namespace megdnn {
namespace cuda {
namespace warp_perspective {

/**
 * Gradient of warp_perspective w.r.t. the 3x3 transform matrices, for the
 * border modes that remap coordinates through a Getter functor.
 *
 * Forward model for output pixel (oh, ow) of batch element n:
 *     iw  = (m0*ow + m1*oh + m2) / (m6*ow + m7*oh + m8)
 *     ih  = (m3*ow + m4*oh + m5) / (m6*ow + m7*oh + m8)
 *     out = bilinear sample of `in` at (ih, iw)
 * Each thread accumulates, over all C channels and for i in [0, 9),
 *     hidden * (d out/d ih * d ih/d m_i + d out/d iw * d iw/d m_i).
 * The block then sums these partials in shared memory, and thread (0, 0)
 * atomically adds the block total into grad[n * 9 + i].
 *
 * Launch contract (matches backward_mat_proxy):
 *   - blockDim == (32, 16, 1): the static shared buffer and both reduction
 *     loops below are hard-coded for that shape.
 *   - gridDim.z == N (one z-slice per batch element); gridDim.x/y cover OW/OH.
 *   - grad must be zeroed before launch, since blocks atomicAdd into it.
 *
 * \param hidden  output gradient, laid out as [n][c][oh][ow]
 * \param in      source image, laid out as [midx ? midx[n] : n][c][ih][iw]
 * \param mat     N 3x3 matrices (9 floats each) mapping output to input coords
 * \param midx    optional per-batch index into `in`; nullptr means use n
 * \param grad    N 3x3 matrices receiving the accumulated gradient
 * \param N       batch size (not read by the kernel body)
 *
 * Getter is default-constructed and called as getter(coord, size); it
 * presumably folds out-of-range coordinates back into [0, size) for the
 * chosen border mode (defined in common.cuh, not visible here).
 */
template <typename Getter>
__global__ void warp_perspective_bwd_mat_kernel(
        const float* hidden, const float* in, const float* mat, const int* midx,
        float* grad, int N, int C, int IH, int IW, int OH, int OW) {
    Getter getter;
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // Batch offsets are formed in size_t: n * C * H * W overflows int once a
    // tensor holds more than 2^31 elements.
    hidden += static_cast<size_t>(n) * C * OH * OW;
    size_t src_n =
            midx ? static_cast<size_t>(midx[n]) : static_cast<size_t>(n);
    in += src_n * C * IH * IW;
    mat += n * 3 * 3;
    grad += n * 3 * 3;

    float grad_local[3 * 3] = {};
    if (ow < OW && oh < OH) {
        float numeratorw = mat[0] * ow + mat[1] * oh + mat[2];
        float numeratorh = mat[3] * ow + mat[4] * oh + mat[5];
        float denominator = mat[6] * ow + mat[7] * oh + mat[8];
        float denominator2 = sqr(denominator);
        float iw = numeratorw / denominator;
        float ih = numeratorh / denominator;
        float iw_floor = floor(iw);
        float ih_floor = floor(ih);
        int iw0 = getter(iw_floor + 0, IW);
        int iw1 = getter(iw_floor + 1, IW);
        int ih0 = getter(ih_floor + 0, IH);
        int ih1 = getter(ih_floor + 1, IH);
        // Bilinear weights: alpha along h, beta along w.
        float palpha = ih - ih_floor;
        float pbeta = iw - iw_floor;
        float nalpha = 1.0f - palpha;
        float nbeta = 1.0f - pbeta;

        // dw[i] = d(iw)/d(mat[i]), dh[i] = d(ih)/d(mat[i]). They depend only
        // on the pixel, so compute them once instead of once per channel.
        float ddenominatorw = -numeratorw / denominator2;
        float ddenominatorh = -numeratorh / denominator2;
        float dw[9] = {ow / denominator,  oh / denominator,
                       1.0f / denominator, 0.0f,
                       0.0f,               0.0f,
                       ow * ddenominatorw, oh * ddenominatorw,
                       1.0f * ddenominatorw};
        float dh[9] = {0.0f,               0.0f,
                       0.0f,               ow / denominator,
                       oh / denominator,   1.0f / denominator,
                       ow * ddenominatorh, oh * ddenominatorh,
                       1.0f * ddenominatorh};

        int pixel = oh * OW + ow;
        for (int c = 0; c < C; ++c) {
            float v00 = in[ih0 * IW + iw0];
            float v01 = in[ih0 * IW + iw1];
            float v10 = in[ih1 * IW + iw0];
            float v11 = in[ih1 * IW + iw1];
            // dalpha = d(out)/d(ih), dbeta = d(out)/d(iw).
            float dalpha = 0, dbeta = 0;
            dalpha -= v00 * nbeta;
            dalpha -= v01 * pbeta;
            dalpha += v10 * nbeta;
            dalpha += v11 * pbeta;
            dbeta -= v00 * nalpha;
            dbeta += v01 * nalpha;
            dbeta -= v10 * palpha;
            dbeta += v11 * palpha;
            float h = hidden[pixel];
#pragma unroll
            for (int i = 0; i < 9; ++i) {
                grad_local[i] += h * dalpha * dh[i] + h * dbeta * dw[i];
            }
            hidden += OH * OW;
            in += IH * IW;
        }
    }

    // Block reduction: fold the 16 rows onto row 0, then fold the 32 columns
    // of row 0 onto element (0, 0).
    volatile __shared__ float grad_shared[16][32][3 * 3];
    int tidy = threadIdx.y, tidx = threadIdx.x;
#pragma unroll
    for (int i = 0; i < 9; ++i)
        grad_shared[tidy][tidx][i] = grad_local[i];
    __syncthreads();
    for (int k = 8; k >= 1; k >>= 1) {
        if (tidy < k) {
#pragma unroll
            for (int i = 0; i < 9; ++i) {
                grad_shared[tidy][tidx][i] += grad_shared[tidy + k][tidx][i];
            }
        }
        __syncthreads();
    }
    // With blockDim.x == 32, row 0 is exactly warp 0, but only lanes 0..15
    // enter this branch. The warp barrier must therefore name only those
    // lanes: a full 0xffffffff mask is undefined because lanes 16..31 never
    // reach it.
    if (tidy == 0 && tidx < 16) {
        for (int k = 16; k >= 1; k >>= 1) {
            if (tidx < k) {
#pragma unroll
                for (int i = 0; i < 9; ++i) {
                    grad_shared[tidy][tidx][i] +=
                            grad_shared[tidy][tidx + k][i];
                }
            }
            cub::WARP_SYNC(0x0000ffffu);
        }
    }
    if (tidy == 0 && tidx == 0) {
#pragma unroll
        for (int i = 0; i < 9; ++i)
            atomicAdd(grad + i, grad_shared[0][0][i]);
    }
}

/**
 * Gradient of warp_perspective w.r.t. the 3x3 transform matrices for
 * BORDER_CONSTANT: bilinear taps that fall outside the source image read the
 * constant `bval` instead of image data.
 *
 * Forward model for output pixel (oh, ow) of batch element n:
 *     iw  = (m0*ow + m1*oh + m2) / (m6*ow + m7*oh + m8)
 *     ih  = (m3*ow + m4*oh + m5) / (m6*ow + m7*oh + m8)
 *     out = bilinear sample at (ih, iw), with out-of-image taps = bval
 * Each thread accumulates, over all C channels and for i in [0, 9),
 *     hidden * (d out/d ih * d ih/d m_i + d out/d iw * d iw/d m_i).
 * Per-channel contributions that are not finite (e.g. when the denominator
 * is 0) are skipped. The block then sums the partials in shared memory, and
 * thread (0, 0) atomically adds the block total into grad[n * 9 + i].
 *
 * Launch contract (matches backward_mat_proxy):
 *   - blockDim == (32, 16, 1): the static shared buffer and both reduction
 *     loops below are hard-coded for that shape.
 *   - gridDim.z == N (one z-slice per batch element); gridDim.x/y cover OW/OH.
 *   - grad must be zeroed before launch, since blocks atomicAdd into it.
 *
 * \param hidden  output gradient, laid out as [n][c][oh][ow]
 * \param in      source image, laid out as [midx ? midx[n] : n][c][ih][iw]
 * \param mat     N 3x3 matrices (9 floats each) mapping output to input coords
 * \param midx    optional per-batch index into `in`; nullptr means use n
 * \param grad    N 3x3 matrices receiving the accumulated gradient
 * \param N       batch size (not read by the kernel body)
 * \param bval    value used for taps outside the source image
 */
__global__ void warp_perspective_bwd_mat_constant_kernel(
        const float* hidden, const float* in, const float* mat, const int* midx,
        float* grad, int N, int C, int IH, int IW, int OH, int OW, float bval) {
    int n = blockIdx.z;
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // Batch offsets are formed in size_t: n * C * H * W overflows int once a
    // tensor holds more than 2^31 elements.
    hidden += static_cast<size_t>(n) * C * OH * OW;
    size_t src_n =
            midx ? static_cast<size_t>(midx[n]) : static_cast<size_t>(n);
    in += src_n * C * IH * IW;
    mat += n * 3 * 3;
    grad += n * 3 * 3;

    float grad_local[3 * 3] = {};
    if (ow < OW && oh < OH) {
        float numeratorw = mat[0] * ow + mat[1] * oh + mat[2];
        float numeratorh = mat[3] * ow + mat[4] * oh + mat[5];
        float denominator = mat[6] * ow + mat[7] * oh + mat[8];
        float denominator2 = sqr(denominator);
        float iw = numeratorw / denominator;
        float ih = numeratorh / denominator;
        float iw_floor = floor(iw);
        float ih_floor = floor(ih);
        int iw0 = iw_floor + 0;
        int iw1 = iw_floor + 1;
        int ih0 = ih_floor + 0;
        int ih1 = ih_floor + 1;
        // Record which taps lie inside the image before clamping; the
        // clamped indices only keep the (unused) loads in bounds.
        bool okw0 = (iw0 >= 0 && iw0 < IW);
        bool okw1 = (iw1 >= 0 && iw1 < IW);
        bool okh0 = (ih0 >= 0 && ih0 < IH);
        bool okh1 = (ih1 >= 0 && ih1 < IH);
        iw0 = min(max(iw0, 0), IW - 1);
        iw1 = min(max(iw1, 0), IW - 1);
        ih0 = min(max(ih0, 0), IH - 1);
        ih1 = min(max(ih1, 0), IH - 1);
        // Bilinear weights: alpha along h, beta along w.
        float palpha = ih - ih_floor;
        float pbeta = iw - iw_floor;
        float nalpha = 1.0f - palpha;
        float nbeta = 1.0f - pbeta;

        // dw[i] = d(iw)/d(mat[i]), dh[i] = d(ih)/d(mat[i]). They depend only
        // on the pixel, so compute them once instead of once per channel.
        float ddenominatorw = -numeratorw / denominator2;
        float ddenominatorh = -numeratorh / denominator2;
        float dw[9] = {ow / denominator,  oh / denominator,
                       1.0f / denominator, 0.0f,
                       0.0f,               0.0f,
                       ow * ddenominatorw, oh * ddenominatorw,
                       1.0f * ddenominatorw};
        float dh[9] = {0.0f,               0.0f,
                       0.0f,               ow / denominator,
                       oh / denominator,   1.0f / denominator,
                       ow * ddenominatorh, oh * ddenominatorh,
                       1.0f * ddenominatorh};

        int pixel = oh * OW + ow;
        for (int c = 0; c < C; ++c) {
            float v00 = (okh0 && okw0 ? in[ih0 * IW + iw0] : bval);
            float v01 = (okh0 && okw1 ? in[ih0 * IW + iw1] : bval);
            float v10 = (okh1 && okw0 ? in[ih1 * IW + iw0] : bval);
            float v11 = (okh1 && okw1 ? in[ih1 * IW + iw1] : bval);
            // dalpha = d(out)/d(ih), dbeta = d(out)/d(iw).
            float dalpha = 0, dbeta = 0;
            dalpha -= v00 * nbeta;
            dalpha -= v01 * pbeta;
            dalpha += v10 * nbeta;
            dalpha += v11 * pbeta;
            dbeta -= v00 * nalpha;
            dbeta += v01 * nalpha;
            dbeta -= v10 * palpha;
            dbeta += v11 * palpha;
            float h = hidden[pixel];
#pragma unroll
            for (int i = 0; i < 9; ++i) {
                float delta = h * dalpha * dh[i] + h * dbeta * dw[i];
                // Drop inf/nan terms so a single degenerate pixel does not
                // poison the whole matrix gradient.
                if (isfinite(delta))
                    grad_local[i] += delta;
            }
            hidden += OH * OW;
            in += IH * IW;
        }
    }

    // Block reduction: fold the 16 rows onto row 0, then fold the 32 columns
    // of row 0 onto element (0, 0).
    volatile __shared__ float grad_shared[16][32][3 * 3];
    int tidy = threadIdx.y, tidx = threadIdx.x;
#pragma unroll
    for (int i = 0; i < 9; ++i)
        grad_shared[tidy][tidx][i] = grad_local[i];
    __syncthreads();
    for (int k = 8; k >= 1; k >>= 1) {
        if (tidy < k) {
#pragma unroll
            for (int i = 0; i < 9; ++i) {
                grad_shared[tidy][tidx][i] += grad_shared[tidy + k][tidx][i];
            }
        }
        __syncthreads();
    }
    // With blockDim.x == 32, row 0 is exactly warp 0, but only lanes 0..15
    // enter this branch. The warp barrier must therefore name only those
    // lanes: a full 0xffffffff mask is undefined because lanes 16..31 never
    // reach it.
    if (tidy == 0 && tidx < 16) {
        for (int k = 16; k >= 1; k >>= 1) {
            if (tidx < k) {
#pragma unroll
                for (int i = 0; i < 9; ++i)
                    grad_shared[tidy][tidx][i] +=
                            grad_shared[tidy][tidx + k][i];
            }
            cub::WARP_SYNC(0x0000ffffu);
        }
    }
    if (tidy == 0 && tidx == 0) {
#pragma unroll
        for (int i = 0; i < 9; ++i)
            atomicAdd(grad + i, grad_shared[0][0][i]);
    }
}

/**
 * Host entry point: computes the gradient of warp_perspective w.r.t. the N
 * 3x3 transform matrices into `grad` (N * 9 floats). All work is enqueued
 * asynchronously on `stream`.
 *
 * `grad` is cleared first because the kernels accumulate into it with
 * atomicAdd. Threads are launched in (32, 16) blocks, the shape that the
 * kernels' shared-memory reduction is hard-coded for, with one grid z-slice
 * per batch element. BORDER_CONSTANT dispatches to the constant-border kernel
 * using `bval`. REPLICATE / REFLECT / REFLECT_101 / WRAP dispatch to the
 * Getter-templated kernel. Any other mode launches nothing, so `grad` is
 * left all zeros.
 *
 * \param src   source image
 * \param diff  gradient w.r.t. the warped output
 */
void backward_mat_proxy(const float* src, const float* mat, const int* midx,
                        const float* diff, float* grad, int N, int C, int IH,
                        int IW, int OH, int OW, float bval, BorderMode mode,
                        cudaStream_t stream) {
    // Nothing to clear or compute for an empty batch. This also keeps a
    // negative N from turning into a huge memset size.
    if (N <= 0)
        return;
    const int BY = 16, BX = 32;
    cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * 3 * 3, stream));
    // An empty output image yields a zero gradient (already written above).
    // Launching would use a zero grid dimension, which is an invalid launch
    // configuration.
    if (OH <= 0 || OW <= 0)
        return;
    dim3 threads(BX, BY);
    dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);

#define DISPATCH(Getter)                                                     \
    warp_perspective_bwd_mat_kernel<Getter><<<blocks, threads, 0, stream>>>( \
            diff, src, mat, midx, grad, N, C, IH, IW, OH, OW);

    switch (mode) {
        case BORDER_REPLICATE:
            DISPATCH(ReplicateGetter);
            break;
        case BORDER_REFLECT:
            DISPATCH(ReflectGetter);
            break;
        case BORDER_REFLECT_101:
            DISPATCH(Reflect101Getter);
            break;
        case BORDER_WRAP:
            DISPATCH(WrapGetter);
            break;
        case BORDER_CONSTANT:
            warp_perspective_bwd_mat_constant_kernel<<<blocks, threads, 0,
                                                       stream>>>(
                    diff, src, mat, midx, grad, N, C, IH, IW, OH, OW, bval);
            break;
        default:
            // NOTE(review): unsupported modes silently leave grad zeroed;
            // presumably validated by the caller -- confirm.
            break;
    }
#undef DISPATCH
    after_kernel_launch();
}

} // namespace warp_perspective
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen