hl_cuda_cnn.cu 27.1 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"

// im2col: one thread per (channel, h_out, w_out) column position.
// The thread copies the blockH x blockW receptive field whose top-left corner,
// in padded coordinates, is (h_out * strideH, w_out * strideW). Tap (i, j) of
// channel c lands at
//   data_col[((c * blockH * blockW + i * blockW + j) * height_col + h_out)
//            * width_col + w_out].
// Taps that fall in the padding border (outside the height x width image) are
// written as 0.
// data_im is addressed as [channel][height][width].
// Launch layout: 2-D grid of 1-D blocks, with the block id flattened as
// blockIdx.x * gridDim.y + blockIdx.y.
__global__ void KeFeature2col(size_t n, size_t height, const real* data_im,
                              size_t blockH, size_t blockW, size_t width,
                              size_t strideH, size_t strideW,
                              size_t paddingH, size_t paddingW,
                              size_t height_col, size_t width_col,
                              real* data_col) {
  size_t tid =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= n) {
    return;
  }

  // Decompose the flat id into (channel, output row, output column).
  const size_t outX = tid % width_col;
  const size_t outY = (tid / width_col) % height_col;
  const size_t chan = (tid / width_col) / height_col;

  // Receptive-field origin in padded coordinates.
  const size_t top = outY * strideH;
  const size_t left = outX * strideW;
  // Consecutive taps are one (height_col x width_col) plane apart.
  const size_t planeSize = height_col * width_col;

  real* dst = data_col +
      ((chan * blockH * blockW) * height_col + outY) * width_col + outX;

  for (size_t ky = 0; ky < blockH; ++ky) {
    const int rowPadded = int(top + ky);
    // Row inside the unpadded image; negative or >= height means padding.
    const int row = rowPadded - (int)paddingH;
    for (size_t kx = 0; kx < blockW; ++kx) {
      const int col = int(left + kx) - (int)paddingW;
      const bool inside = row >= 0 && row < (int)height &&
                          col >= 0 && col < (int)width;
      if (inside) {
        // Row index across all channels of data_im.
        const int srcRow = int(rowPadded + chan * height - paddingH);
        *dst = data_im[srcRow * width + col];
      } else {
        *dst = 0;
      }
      dst += planeSize;
    }
  }
}

void hl_expand_feature2col(const real* dataIm, size_t channels,
                           size_t height, size_t width,
                           size_t blockH, size_t blockW,
                           size_t strideH, size_t strideW,
                           size_t paddingH, size_t paddingW,
                           size_t outputH, size_t outputW,
                           real* dataCol) {
  // Host launcher for KeFeature2col (im2col): one thread per
  // (channel, outputH, outputW) column position.
  const size_t kThreadsPerBlock = 1024;
  const size_t kGridX = 512;
  size_t numKernels = channels * outputH * outputW;
  size_t numBlocks = (numKernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  // Blocks are spread over a 2-D grid with a fixed x-extent of kGridX.
  // The kernel flattens (blockIdx.x, blockIdx.y) back into a linear index.
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(kGridX, (numBlocks + kGridX - 1) / kGridX);

  KeFeature2col<<< grid, threads, 0, STREAM_DEFAULT >>>
           (numKernels, height, dataIm, blockH, blockW, width,
           strideH, strideW, paddingH, paddingW,
           outputH, outputW, dataCol);
  CHECK_SYNC("hl_expand_feature2col failed");
}

// col2im: the inverse of KeFeature2col. It gathers column-buffer entries back
// into the feature map.
// One thread per pixel of the *padded* image: `height` and `width` here already
// include 2 * padding (hl_shrink_col2feature passes height + 2*paddingH and
// width + 2*paddingW). Threads whose pixel lies in the padding border write
// nothing. For an interior pixel (c, h, w), the thread sums every data_col
// entry that KeFeature2col would have copied from this pixel. It then writes
//   data_im = alpha * sum + beta * data_im
// into the unpadded (width - 2*paddingW) x (height - 2*paddingH) image.
// Because each thread owns exactly one output element, no atomics are needed.
// Launch layout: 2-D grid of 1-D blocks, with the block id flattened as
// blockIdx.x * gridDim.y + blockIdx.y.
__global__ void KeCol2Feature(size_t n, const real* data_col, size_t height,
                              size_t width, size_t channels,
                              size_t blockH, size_t blockW,
                              size_t strideH, size_t strideW,
                              size_t paddingH, size_t paddingW,
                              size_t height_col, size_t width_col,
                              real* data_im, real alpha, real beta) {
  size_t index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < n) {
    real val = 0;
    // (c, h, w) in padded coordinates.
    int w = int(index % width);
    int h = int((index / width) % height);
    int c = int(index / (width * height));
    // Skip pixels in the padding border. The ">= 0" checks come first so the
    // later signed/unsigned comparisons only see non-negative values.
    if ((w - (int)paddingW) >= 0 &&
        (w - (int)paddingW) < (width-2 * paddingW) &&
        (h - (int)paddingH) >= 0 &&
        (h - paddingH) < (height - 2 * paddingH)) {
      // Half-open range [start, end) of output (column) positions whose
      // receptive field covers this padded pixel.
      int w_col_start =
        (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
      int w_col_end =
        min((int)(w / (int)strideW + 1), (int)(width_col));
      int h_col_start =
        (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
      int h_col_end = min(int(h / strideH + 1), int(height_col));
      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
          // Row of data_col holding this pixel for output (h_col, w_col):
          // channel c's blockH x blockW block, at kernel tap
          // (h - h_col * strideH, w - w_col * strideW).
          int c_col = int(c * blockH* blockW) + \
            (h - h_col * (int)strideH) * (int)blockW +
            (w - w_col * (int)strideW);
          val += data_col[(c_col * height_col + h_col) * width_col + w_col];
        }
      }
      // Convert padded coordinates to unpadded output coordinates.
      h -= paddingH;
      w -= paddingW;
      // Blend into the existing value: alpha * gathered + beta * previous.
      real tD = data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
                          h*(width-2*paddingW) + w];
      data_im[c*((width-2*paddingW) * (height-2*paddingH)) +
              h*(width-2*paddingW) + w] = alpha * val + beta*tD;
    }
  }
}

void hl_shrink_col2feature(const real * dataCol, size_t channels,
                           size_t height, size_t width,
                           size_t blockH, size_t blockW,
                           size_t strideH, size_t strideW,
                           size_t paddingH, size_t paddingW,
                           size_t outputH, size_t outputW,
                           real* dataIm, real alpha, real beta) {
  // Host launcher for KeCol2Feature (col2im).
  // The kernel works in padded coordinates, so launch one thread per padded
  // pixel and pass the padded extents as its height/width.
  const size_t paddedH = height + 2 * paddingH;
  const size_t paddedW = width + 2 * paddingW;
  const size_t numKernels = channels * paddedH * paddedW;

  const size_t kThreadsPerBlock = 1024;
  const size_t kGridX = 512;
  const size_t numBlocks =
      (numKernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(kGridX, (numBlocks + kGridX - 1) / kGridX);

  // Each thread owns one bottom (image) element and sums the top (column)
  // entries that map onto it, so no atomic operations are required.
  KeCol2Feature<<< grid, threads, 0, STREAM_DEFAULT >>>
           (numKernels, dataCol, paddedH, paddedW,
           channels, blockH, blockW, strideH, strideW, paddingH, paddingW,
           outputH, outputW, dataIm, alpha, beta);
  CHECK_SYNC("hl_shrink_col2feature failed");
}

// Max pooling forward pass.
// One thread per pooled output element, indexed as [frame][channel][ph][pw]
// (pw fastest). The window origin is (ph * strideH - offsetH,
// pw * strideW - offsetW). The ksizeH x ksizeW window is clipped to the
// height x width image, and the thread writes its maximum to tgtData[index].
// Expects a 1-D grid of 1-D blocks.
__global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
                                 const int channels, const int height,
                                 const int width,
                                 const int pooledH, const int pooledW,
                                 const int ksizeW, const int ksizeH,
                                 const int strideH, const int strideW,
                                 const int offsetH, const int offsetW,
                                 real* tgtData) {
  int index =  blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int c = (index / pooledW / pooledH) % channels;
    int frameNum = index / pooledW / pooledH / channels;
    // Pooling window in input coordinates, clipped to the image.
    int hstart = ph * strideH - offsetH;
    int wstart = pw * strideW - offsetW;
    int hend = min(hstart + ksizeH, height);
    int wend = min(wstart + ksizeW, width);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    real maxval = -FLT_MAX;
    // Move to the (frame, channel) plane of the input.
    inputData += (frameNum * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        if (maxval < inputData[h * width + w])
          maxval = inputData[h * width + w];
      }
    }
    tgtData[index] = maxval;
  }
}

// Launches KeMaxPoolForward with one thread per pooled output element
// (frameCnt * channels * pooledH * pooledW threads, 1024 per block).
// sizeX/sizeY are the window width/height (the kernel's ksizeW/ksizeH).
// paddingH/paddingW are subtracted from each window origin.
void hl_maxpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real* tgtData) {
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  int blocks = (num_kernels + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KeMaxPoolForward<<< grid, threads, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels, height, width,
           pooledH, pooledW, sizeX, sizeY, strideH, strideW,
           paddingH, paddingW, tgtData);
  CHECK_SYNC("hl_maxpool_forward failed");
}

// Max pooling backward pass.
// One thread per input element, indexed as [frame][channel][h][w].
// Each thread visits every pooled output whose window covers it (padded
// coordinates). It sums that output's gradient whenever its own input value
// equals the pooled maximum, then writes
//   targetGrad = scaleB * targetGrad + scaleA * gradient.
// Ties are not broken: every input equal to a window's maximum receives that
// window's full gradient.
// Expects a 1-D grid of 1-D blocks.
__global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
                                  const real* outData, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* targetGrad) {
  int index = blockIdx.x  * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Position of this input element, shifted into padded coordinates.
    int offsetW = index % width + padW;
    int offsetH = (index / width) % height + padH;
    int offsetC = (index / width / height) % channels;
    int frameNum = index / width / height / channels;
    // Pooled outputs [phstart, phend) x [pwstart, pwend) whose window
    // contains this element.
    int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1;
    int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1;
    int phend = offsetH >= 0 ? min(offsetH / strideH + 1, pooledH) : 0;
    int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0;
    real gradient = 0;
    real input = inputData[index];
    // Move to the (frame, channel) plane of the pooled tensors.
    outData += (frameNum * channels + offsetC) * pooledH * pooledW;
    outGrad += (frameNum * channels + offsetC) * pooledH * pooledW;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        if (input == outData[ph * pooledW + pw]) {
          gradient += outGrad[ph * pooledW + pw];
        }
      }
    }
    targetGrad[index] =
      scaleB * targetGrad[index] + scaleA * gradient;
  }
}

// Launches KeMaxPoolBackward with one thread per input element
// (frameCnt * channels * height * width threads, 1024 per block).
// targetGrad is blended as scaleB * targetGrad + scaleA * gradient.
void hl_maxpool_backward(const int frameCnt, const real* inputData,
                        const real* outData, const real* outGrad,
                        const int channels, const int height,
                        const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real scaleA, real scaleB,
                        real* targetGrad) {
  int num_kernels = height * width * channels * frameCnt;
  int blocks = (num_kernels + 1024 - 1) / 1024;

  KeMaxPoolBackward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, outData, outGrad, channels,
           height, width, pooledH, pooledW, sizeX, sizeY,
           strideH, strideW,
           paddingH, paddingW,
           scaleA, scaleB,
           targetGrad);
  CHECK_SYNC("hl_maxpool_backward");
}

// Average pooling forward pass.
// One thread per pooled output element, indexed as [frame][channel][ph][pw].
// The divisor pool_size is the window area with its far edge clipped to
// height + padH / width + padW, so padding cells inside the window count
// toward the average. The sum itself only covers real image pixels.
// Expects a 1-D grid of 1-D blocks.
__global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
                                 const int channels,
                                 const int height, const int width,
                                 const int pooledH, const int pooledW,
                                 const int sizeX, const int sizeY,
                                 const int strideH, const int strideW,
                                 const int padH, const int padW,
                                 real* tgtData) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int c = (index / pooledW / pooledH) % channels;
    int frameNum = index / pooledW / pooledH / channels;

    int hstart = ph * strideH - padH;
    int wstart = pw * strideW - padW;
    int hend = min(hstart + sizeY, height + padH);
    int wend = min(wstart + sizeX, width + padW);
    // Divisor includes the padding cells covered by the window.
    int pool_size = (hend - hstart) * (wend - wstart);
    // Clip to the real image for the summation.
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    hend = min(hend, height);
    wend = min(wend, width);

    real aveval = 0;
    inputData += (frameNum * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        aveval += inputData[h * width + w];
      }
    }
    tgtData[index] = aveval / pool_size;
  }
}

// Launches KeAvgPoolForward with one thread per pooled output element
// (frameCnt * channels * pooledH * pooledW threads, 1024 per block).
void hl_avgpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW, real* tgtData) {
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  int blocks = (num_kernels + 1024 - 1) / 1024;
  KeAvgPoolForward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels,
           height, width, pooledH, pooledW,
           sizeX, sizeY, strideH, strideW,
           paddingH, paddingW, tgtData);
  CHECK_SYNC("hl_avgpool_forward failed");
}

// Average pooling backward pass.
// One thread per input element, indexed as [frame][channel][h][w].
// Each thread sums outGrad / poolsize over every pooled output whose window
// covers it. poolsize is recomputed exactly as in KeAvgPoolForward, so
// padding cells are counted. The thread then writes
//   tgtGrad = scaleB * tgtGrad + scaleA * gradient.
// Expects a 1-D grid of 1-D blocks.
__global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* tgtGrad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Position of this input element, shifted into padded coordinates.
    int offsetW = index % width + padW;
    int offsetH = (index / width) % height + padH;
    int offsetC = (index / width / height) % channels;
    int frameNum = index / width / height / channels;

    // Pooled outputs whose window contains this element.
    int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1;
    int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1;
    int phend = offsetH >= 0 ? min(offsetH / strideH + 1, pooledH) : 0;
    int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0;
    real gradient = 0;
    outGrad += (frameNum * channels + offsetC) * pooledH * pooledW;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Same divisor as the forward pass (padding included).
        int hstart = ph * strideH - padH;
        int wstart = pw * strideW - padW;
        int hend = min(hstart + sizeY, height + padH);
        int wend = min(wstart + sizeX, width + padW);
        int poolsize = (hend - hstart) * (wend - wstart);
        gradient += outGrad[ph * pooledW + pw]/poolsize;
      }
    }
    tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
  }
}

// Launches KeAvgPoolBackward with one thread per input element
// (frameCnt * channels * height * width threads, 1024 per block).
// backGrad is blended as scaleB * backGrad + scaleA * gradient.
void hl_avgpool_backward(const int frameCnt, const real* outGrad,
                         const int channels,
                         const int height, const int width,
                         const int pooledH, const int pooledW,
                         const int sizeX, const int sizeY,
                         const int strideH, const int strideW,
                         const int paddingH, const int paddingW,
                         real scaleA, real scaleB,
                         real* backGrad) {
  int num_kernels = height * width * channels * frameCnt;
  int blocks = (num_kernels + 1024 - 1) / 1024;

  KeAvgPoolBackward <<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, outGrad, channels, height, width,
           pooledH, pooledW, sizeX, sizeY,
           strideH, strideW,
           paddingH, paddingW,
           scaleA, scaleB,
           backGrad);
  CHECK_SYNC("hl_avgpool_backward failed");
}

// Cross-map (cross-channel) response normalization, pass 1.
// One thread per (n, h, w) pixel of an NCHW tensor (fixed by the offset/step
// arithmetic below). The thread slides a window of `size` channels along the
// channel axis and writes
//   scale[n, c, h, w] = 1 + alpha * sum_{k in W(c)} in[n, k, h, w]^2,
//   W(c) = [c - pre_pad, c + post_pad] clipped to [0, channels),
// where pre_pad = (size - 1) / 2 and post_pad = size - pre_pad - 1.
// Windows wider than `channels` are handled: all reads and writes stay within
// channels [0, channels) of the current pixel.
// Assumes size >= 1.
// NOTE(review): alpha is applied as-is; presumably the caller has already
// divided it by `size` -- confirm at the call site.
__global__ void KeCMRNormFillScale(size_t nthreads, const real* in,
                                   real* scale, size_t channels,
                                   size_t height, size_t width, size_t size,
                                   real alpha) {
  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset
    size_t w = index % width;
    size_t h = (index / width) % height;
    size_t n = index / width / height;
    size_t offset = (n * channels * height + h) * width + w;
    size_t step = height * width;
    in += offset;
    scale += offset;
    size_t head = 0;
    size_t pre_pad = (size - 1) / 2;
    size_t post_pad = size - pre_pad - 1;
    real accum_scale = 0;
    // Prime the running sum with the channels ahead of channel 0. The
    // `head < channels` bound keeps reads inside this pixel's channels when
    // the window is wider than the channel count.
    while (head < post_pad && head < channels) {
      accum_scale += in[head * step] * in[head * step];
      ++head;
    }
    // Slide: add channel `head`, drop channel `head - size` once it has left
    // the window, and emit the result for channel `head - post_pad`.
    // (head >= post_pad always holds inside this loop.)
    while (head < channels) {
      accum_scale += in[head * step] * in[head * step];
      if (head >= size) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      scale[(head - post_pad) * step] = 1.f + accum_scale * alpha;
      ++head;
    }
    // Drain: nothing left to add; keep dropping and emitting. Heads whose
    // target channel would be negative (possible only when
    // channels < post_pad) are skipped.
    while (head < channels + post_pad) {
      if (head >= size) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      if (head >= post_pad) {
        scale[(head - post_pad) * step] = 1.f + accum_scale * alpha;
      }
      ++head;
    }
  }
}

 __global__ void KeCMRNormOutput(size_t nthreads, const real* in,
                                 const real* scale, real negative_beta,
                                 real* out) {
  // Element-wise: out[i] = in[i] * scale[i]^negative_beta, one thread per
  // element, 1-D grid of 1-D blocks.
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nthreads) {
    return;
  }
  out[i] = in[i] * pow(scale[i], negative_beta);
}

void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale,
                        real* out, size_t channels,
                        size_t height, size_t width, size_t sizeX,
                        real alpha, real beta) {
  // Two-pass cross-map normalization.
  // NOTE(review): beta is passed straight through as KeCMRNormOutput's
  // exponent (negative_beta), so callers presumably pass a negative
  // value -- confirm.
  const size_t kThreads = 1024;
  dim3 block(kThreads, 1);

  // Pass 1: one thread per (frame, h, w) pixel fills `scale` across channels.
  const size_t pixelCount = frameCnt * height * width;
  dim3 scaleGrid((pixelCount + kThreads - 1) / kThreads, 1);
  KeCMRNormFillScale<<<scaleGrid, block, 0, STREAM_DEFAULT>>>
      (pixelCount, in, scale, channels, height, width, sizeX, alpha);

  // Pass 2: one thread per element computes out = in * scale^beta.
  const size_t elemCount = pixelCount * channels;
  dim3 outGrid((elemCount + kThreads - 1) / kThreads, 1);
  KeCMRNormOutput<<<outGrid, block, 0, STREAM_DEFAULT>>>
           (elemCount, in, scale, beta, out);
  CHECK_SYNC("hl_CMRNorm_forward");
}

// Cross-map response normalization, backward pass.
// One thread per (n, h, w) pixel of an NCHW tensor. For every channel c the
// thread ACCUMULATES (+=) into bottom_diff:
//   top_diff[c] * scale[c]^negative_beta
//     - cache_ratio * bottom_data[c]
//       * sum_{k in W(c)} top_diff[k] * top_data[k] / scale[k],
// where W(c) = [c - pre_pad, c + post_pad] clipped to [0, channels),
// pre_pad = size - (size + 1) / 2 and post_pad = size - pre_pad - 1.
// For even `size` this window mirrors the one in KeCMRNormFillScale. It is
// exactly the set of channels whose scale included channel c.
// Windows wider than `channels` are handled without out-of-bounds accesses.
// Assumes size >= 1.
__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data,
                              const real* top_data, const real* scale,
                              const real* top_diff, size_t channels,
                              size_t height, size_t width, size_t size,
                              real negative_beta, real cache_ratio,
                              real* bottom_diff ) {
  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset
    size_t w = index % width;
    size_t h = (index / width) % height;
    size_t n = index / width / height;
    size_t offset = (n * channels * height + h) * width + w;
    size_t step = height * width;
    bottom_data += offset;
    top_data += offset;
    scale += offset;
    top_diff += offset;
    bottom_diff += offset;
    size_t head = 0;
    size_t pre_pad = size - (size + 1) / 2;
    size_t post_pad = size - pre_pad - 1;
    real accum_ratio = 0;
    // Prime the running sum with the channels ahead of channel 0, never
    // reading past this pixel's last channel.
    while (head < post_pad && head < channels) {
      accum_ratio += top_diff[head * step] *
        top_data[head * step] / scale[head * step];
      ++head;
    }
    // Slide: add channel `head`, drop channel `head - size` once it has left
    // the window, and emit the gradient for channel `head - post_pad`.
    // (head >= post_pad always holds inside this loop.)
    while (head < channels) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      if (head >= size) {
        accum_ratio -= top_diff[(head - size) * step] *
            top_data[(head - size) * step] / scale[(head - size) * step];
      }
      bottom_diff[(head - post_pad) * step] +=
        top_diff[(head - post_pad) * step] *
        pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
        bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
    // Drain: nothing left to add; keep dropping and emitting. Heads whose
    // target channel would be negative (only when channels < post_pad) are
    // skipped.
    while (head < channels + post_pad) {
      if (head >= size) {
        accum_ratio -= top_diff[(head - size) * step] *
            top_data[(head - size) * step] / scale[(head - size) * step];
      }
      if (head >= post_pad) {
        bottom_diff[(head - post_pad) * step] +=
          top_diff[(head - post_pad) * step] *
          pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      }
      ++head;
    }
  }
}

// Launches KeCMRNormDiff with one thread per (frame, h, w) pixel
// (frameCnt * height * width threads, 1024 per block). Each thread walks all
// `channels` and accumulates (+=) into inDiff.
// NOTE(review): alpha and beta are forwarded positionally as the kernel's
// negative_beta and cache_ratio, so the caller presumably passes those
// already-derived values under these names -- confirm at the call site.
void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
                         const real* scale,
                         const real* outV, const real* outDiff,
                         real *inDiff, size_t channels,
                         size_t height, size_t width, size_t sizeX,
                         real alpha, real beta) {
  size_t threadsNum = frameCnt * height * width;
  size_t blocksX = (threadsNum + 1024 - 1) / 1024;
  size_t blocksY = 1;
  dim3 threads(1024, 1);
  dim3 grid(blocksX, blocksY);
  KeCMRNormDiff <<<grid, threads, 0, STREAM_DEFAULT>>>
           (threadsNum, inV, outV, scale, outDiff, channels,
           height, width, sizeX, alpha, beta, inDiff);
  CHECK_SYNC("hl_CMRNorm_backward");
}

__global__ void KeBilinearInterpFw(const size_t nthreads,
                                   const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  // Bilinear upsampling/downsampling forward pass.
  // One thread per (sample, output pixel). The thread loops over all
  // numChannels planes. Rows of `in`/`out` are samples of width
  // inputW/outputW, each holding numChannels consecutive image planes.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Pixels in one output image plane.
  const size_t outPixels = outputW / numChannels;
  const int sample = tid / outPixels;
  const int pixel = tid % outPixels;

  // Vertical source coordinate: integer part, neighbour step, fraction.
  const size_t outY = pixel / outImgW;
  const int srcY = ratioH * outY;
  const int dy = (srcY < inImgH - 1) ? 1 : 0;
  const real fy = ratioH * outY - srcY;

  // Horizontal source coordinate.
  const size_t outX = tid % outImgW;
  const int srcX = ratioW * outX;
  const int dx = (srcX < inImgW - 1) ? 1 : 0;
  const real fx = ratioW * outX - srcX;

  const real* src = in + (sample * inputW + srcY * inImgW + srcX);
  real* dst = out + (sample * outputW + pixel);
  const size_t inPlane = inImgH * inImgW;
  const size_t outPlane = outImgH * outImgW;
  for (int c = 0; c < numChannels; ++c) {
    dst[0] = (1.f - fy) *
      ((1.f - fx) * src[0] + fx * src[dx]) +
      fy * ((1.f - fx) * src[dy * inImgW] +
      fx * src[dy * inImgW + dx]);
    src += inPlane;
    dst += outPlane;
  }
}

void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels) {
  // One thread per (sample, output pixel), 1024 threads per block.
  int threadNum = outputH * outImgH * outImgW;
  int blocks = (threadNum + 1024 - 1) / 1024;

  // Scale from output to source coordinates. The first and last pixels of
  // input and output coincide (align-corners mapping); a single-pixel output
  // axis uses ratio 0.
  real ratioH = 0.f;
  if (outImgH > 1) {
    ratioH = static_cast<float>(inImgH - 1) / (outImgH - 1);
  }
  real ratioW = 0.f;
  if (outImgW > 1) {
    ratioW = static_cast<float>(inImgW - 1) / (outImgW - 1);
  }

  KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
    threadNum, inData, inImgH, inImgW, inputH, inputW, outData,
    outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

__global__ void KeBilinearInterpBw(const size_t nthreads,
                                   real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  // Bilinear interpolation backward pass.
  // One thread per (sample, output pixel). The thread scatters that pixel's
  // gradient, for every channel, onto its four source pixels with atomicAdd,
  // since neighbouring output pixels share sources. `in` is accumulated into,
  // not overwritten.
  // NOTE(review): if `real` is double, atomicAdd(double*) needs SM60+.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const size_t outPixels = outputW / numChannels;
  const int sample = tid / outPixels;
  const int pixel = tid % outPixels;

  const size_t outY = pixel / outImgW;
  const int srcY = ratioH * outY;
  const int dy = (srcY < inImgH - 1) ? 1 : 0;
  const real fy = ratioH * outY - srcY;

  const size_t outX = tid % outImgW;
  const int srcX = ratioW * outX;
  const int dx = (srcX < inImgW - 1) ? 1 : 0;
  const real fx = ratioW * outX - srcX;

  const real* grad = out + (sample * outputW + pixel);
  real* acc = in + (sample * inputW + srcY * inImgW + srcX);
  const size_t inPlane = inImgH * inImgW;
  const size_t outPlane = outImgH * outImgW;
  for (int c = 0; c < numChannels; ++c) {
    atomicAdd(&acc[0], (1.f - fy) * (1.f - fx) * grad[0]);
    atomicAdd(&acc[dx], (1.f - fy) * fx * grad[0]);
    atomicAdd(&acc[dy * inImgW], fy * (1.f - fx) * grad[0]);
    atomicAdd(&acc[dy * inImgW + dx], fy * fx * grad[0]);
    acc += inPlane;
    grad += outPlane;
  }
}

void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels) {
  // One thread per (sample, output pixel), 1024 threads per block. The
  // kernel accumulates into inGrad; it is not cleared here.
  int threadNum = outputH * outImgH * outImgW;
  int blocks = (threadNum + 1024 - 1) / 1024;

  // Same align-corners coordinate scaling as hl_bilinear_forward.
  real ratioH = 0.f;
  if (outImgH > 1) {
    ratioH = static_cast<float>(inImgH - 1) / (outImgH - 1);
  }
  real ratioW = 0.f;
  if (outImgW > 1) {
    ratioW = static_cast<float>(inImgW - 1) / (outImgW - 1);
  }

  KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
    threadNum, inGrad, inImgH, inImgW, inputH, inputW, outGrad,
    outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}