/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "hl_gpu_matrix_kernel.cuh"
#include "hl_matrix.h"
#include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh"
#include "hl_sequence.h"
#include "hl_sparse.ph"
#include "paddle/utils/Logging.h"

// Unary element-wise functor that assigns 0; used by hl_matrix_zero_mem.
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
// Ternary functor with two scalar parameters: c = p1 * a + p2 * b.
DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1 * a + p2 * b);

/**
 * Element-wise weighted sum of two dimM x dimN matrices:
 *   C_d = alpha * A_d + beta * B_d
 *
 * All three matrices are passed to hl_gpu_apply_ternary_op with a leading
 * dimension of dimN, i.e. they are treated as contiguous row-major buffers.
 *
 * @param A_d    first input matrix (GPU buffer).
 * @param B_d    second input matrix (GPU buffer).
 * @param C_d    output matrix (GPU buffer).
 * @param dimM   number of rows.
 * @param dimN   number of columns; also the leading dimension of A, B, C.
 * @param alpha  scale applied to A_d.
 * @param beta   scale applied to B_d.
 */
void hl_matrix_add(real* A_d,
                   real* B_d,
                   real* C_d,
                   int dimM,
                   int dimN,
                   real alpha,
                   real beta) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  hl_gpu_apply_ternary_op<real, ternary::_add<real>, 0, 0>(
      ternary::_add<real>(alpha, beta),
      A_d,
      B_d,
      C_d,
      dimM,
      dimN,
      dimN,
      dimN,
      dimN);
  CHECK_SYNC("hl_matrix_add failed");
}

// Lower clamp applied to (x - rowMax) before exponentiation in the softmax
// helpers below (see subMaxAndExp), so the smallest pre-normalization value
// is exp(-THRESHOLD). The double-precision build uses a wider bound,
// presumably because exp(-128) is still comfortably representable in double.
#ifdef PADDLE_TYPE_DOUBLE
#define THRESHOLD 128
#else
#define THRESHOLD 64
#endif

/**
 * Block-wide maximum of one row of length dimN.
 *
 * Each thread scans I[nextIdx], I[nextIdx + blockSize], ... while its
 * column index curIdx < dimN, keeping a running max in dfMax_s[base]. The
 * blockSize partial maxima are then tree-reduced in shared memory and
 * thread 0 writes the result to *max. The trailing __syncthreads() makes
 * *max visible to the whole block.
 *
 * Preconditions (from the reduction loop): every thread of the block must
 * call this function, blockSize must equal blockDim.x and be a power of
 * two, and dfMax_s must hold at least blockSize elements.
 *
 * @param I          row data, indexed through nextIdx.
 * @param dfMax_s    shared scratch buffer of blockSize elements.
 * @param blockSize  threads per block; also the scan stride.
 * @param base       this thread's index within the block.
 * @param curIdx     starting column of this thread within the row.
 * @param nextIdx    starting flat offset of this thread into I.
 * @param dimN       row length.
 * @param max        output location (written by thread 0 only).
 */
__device__ __forceinline__ void findMax(real* I,
                                        real* dfMax_s,
                                        int blockSize,
                                        int base,
                                        int curIdx,
                                        int nextIdx,
                                        int dimN,
                                        real* max) {
  // Sentinel lower than any expected input value.
  dfMax_s[base] = -1.0e20;
  while (curIdx < dimN) {
    if (dfMax_s[base] < I[nextIdx]) {
      dfMax_s[base] = I[nextIdx];
    }
    nextIdx += blockSize;
    curIdx += blockSize;
  }
  __syncthreads();

  // Tree reduction: halve the active range each step.
  for (int stride = blockSize >> 1; stride > 0; stride >>= 1) {
    __syncthreads();
    if (base < stride) {
      nextIdx = base + stride;
      if (dfMax_s[base] < dfMax_s[nextIdx]) {
        dfMax_s[base] = dfMax_s[nextIdx];
      }
    }
  }

  if (0 == base) {
    max[0] = dfMax_s[0];
  }
  __syncthreads();
}

/**
 * For this thread's strided share of the row: subtracts the row maximum,
 * clamps the result from below at -THRESHOLD, and stores exp() of it in O.
 *
 * Side effect: the shifted and clamped value is written back into I, so
 * the input buffer is modified in place.
 *
 * Single precision uses the fast __expf intrinsic; double precision uses
 * exp(). Ends with __syncthreads(), so every thread of the block must call
 * it.
 *
 * @param I          input row (overwritten as described above).
 * @param O          output row receiving the exponentials.
 * @param curIdx     starting column of this thread.
 * @param nextIdx    starting flat offset of this thread into I / O.
 * @param blockSize  scan stride (threads per block).
 * @param dimN       row length.
 * @param max        row maximum produced by findMax.
 */
__device__ __forceinline__ void subMaxAndExp(real* I,
                                             real* O,
                                             int curIdx,
                                             int nextIdx,
                                             int blockSize,
                                             int dimN,
                                             real max) {
  real val;
  while (curIdx < dimN) {
    val = I[nextIdx] - max;
    // Keep very negative exponents bounded.
    if (val < -THRESHOLD) {
      val = -THRESHOLD;
    }
    // The shifted, clamped value replaces the input element.
    I[nextIdx] = val;
#ifndef PADDLE_TYPE_DOUBLE
    O[nextIdx] = __expf(val);
#else
    O[nextIdx] = exp(val);
#endif
    nextIdx += blockSize;
    curIdx += blockSize;
  }
  __syncthreads();
}

/**
 * Block-wide sum of one row of O (length dimN).
 *
 * Each thread accumulates a strided partial sum into dfMax_s[base]. The
 * partials are then tree-reduced, so after the trailing __syncthreads()
 * dfMax_s[0] holds the row sum and is visible to every thread.
 *
 * Same preconditions as findMax: called by every thread of the block,
 * blockSize == blockDim.x and a power of two, and dfMax_s holds at least
 * blockSize elements.
 */
__device__ __forceinline__ void valueSum(real* O,
                                         real* dfMax_s,
                                         int blockSize,
                                         int base,
                                         int curIdx,
                                         int nextIdx,
                                         int dimN) {
  dfMax_s[base] = 0;
  while (curIdx < dimN) {
    dfMax_s[base] += O[nextIdx];
    nextIdx += blockSize;
    curIdx += blockSize;
  }
  __syncthreads();

  for (int stride = blockSize >> 1; stride > 0; stride >>= 1) {
    __syncthreads();
    if (base < stride) {
      nextIdx = base + stride;
      dfMax_s[base] += dfMax_s[nextIdx];
    }
  }
  __syncthreads();
}

/**
 * Divides this thread's strided share of the row in O by sum.
 * Performs no synchronization.
 */
__device__ __forceinline__ void divSum(
    real* O, real sum, int curIdx, int nextIdx, int blockSize, int dimN) {
  while (curIdx < dimN) {
    O[nextIdx] /= sum;
    nextIdx += blockSize;
    curIdx += blockSize;
  }
}

/**
 * Softmax of one row, computed cooperatively by a whole thread block:
 *   O = exp(max(I - rowMax, -THRESHOLD)) / sum(...)
 *
 * Every thread of the block must call this, because the helpers
 * synchronize. I is overwritten with the shifted, clamped values (see
 * subMaxAndExp).
 *
 * @param I          input row data (modified in place).
 * @param O          output row data.
 * @param dfMax_s    shared scratch buffer of blockSize elements.
 * @param blockSize  threads per block (power of two, == blockDim.x).
 * @param base       this thread's index within the block.
 * @param curIdx     starting column of this thread.
 * @param nextIdx    starting flat offset of this thread into I / O.
 * @param dimN       row length.
 */
__device__ __forceinline__ void softmax(real* I,
                                        real* O,
                                        real* dfMax_s,
                                        int blockSize,
                                        int base,
                                        int curIdx,
                                        int nextIdx,
                                        int dimN) {
  // Row maximum, shared by the whole block.
  __shared__ real max;

  // find the max number
  findMax(I, dfMax_s, blockSize, base, curIdx, nextIdx, dimN, &max);

  // Subtract the max and exponentiate. Pass curIdx (as the other helpers
  // do) rather than base, so the scan starts at the caller's column.
  subMaxAndExp(I, O, curIdx, nextIdx, blockSize, dimN, max);

  // Add the dimN values into the blockDim.x buffer; the sum ends up in
  // dfMax_s[0].
  valueSum(O, dfMax_s, blockSize, base, curIdx, nextIdx, dimN);

  // divided by sum
  divSum(O, dfMax_s[0], curIdx, nextIdx, blockSize, dimN);
}

/**
 * Row-wise softmax of a row-major matrix with dimN columns.
 *
 * Block blockIdx.x processes row blockIdx.x. The kernel must be launched
 * with one block per row and blockDim.x == blockSize (a power of two).
 * I is modified in place (see subMaxAndExp).
 */
template <int blockSize>
__global__ void KeMatrixSoftMax(real* O, real* I, int dimN) {
  int base = threadIdx.x;
  __shared__ real dfMax_s[blockSize];
  int nextIdx = blockIdx.x * dimN + base;
  int curIdx = base;

  softmax(I, O, dfMax_s, blockSize, base, curIdx, nextIdx, dimN);
}

/**
 * Row-wise softmax: C_d[i, :] = softmax(A_d[i, :]) for each of the dimM
 * rows of the row-major dimM x dimN matrix A_d.
 *
 * NOTE: the kernel overwrites A_d with the max-shifted, clamped inputs.
 *
 * @param A_d   input matrix (modified in place).
 * @param C_d   output matrix, same shape as A_d.
 * @param dimM  number of rows (one thread block per row).
 * @param dimN  number of columns.
 */
void hl_matrix_softmax(real* A_d, real* C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  // The kernel's shared buffer and its reduction are sized by the template
  // argument, so the launch width must match it exactly.
  const int blockSize = 512;
  dim3 block(blockSize, 1);
  dim3 grid(dimM, 1);
  KeMatrixSoftMax<blockSize><<<grid, block, 0, STREAM_DEFAULT>>>(
      C_d, A_d, dimN);
  CHECK_SYNC("hl_matrix_softmax failed");
}

/**
 * Softmax over variable-length sequences stored back to back.
 *
 * Block blockIdx.x handles sequence blockIdx.x, which spans
 * [index[bid], index[bid + 1]) in I / O. The kernel must be launched with
 * one block per sequence and blockDim.x == blockSize (a power of two).
 * I is modified in place (see subMaxAndExp).
 */
template <int blockSize>
__global__ void KeSequenceSoftMax(real* O, real* I, const int* index) {
  int base = threadIdx.x;
  int bid = blockIdx.x;
  __shared__ real dfMax_s[blockSize];

  int start = index[bid];
  int dimN = index[bid + 1] - start;

  int nextIdx = start + base;
  int curIdx = base;

  softmax(I, O, dfMax_s, blockSize, base, curIdx, nextIdx, dimN);
}

/**
 * Softmax over each of numSequence variable-length sequences packed
 * contiguously in A_d.
 *
 * Sequence s spans A_d[index[s] .. index[s + 1]), so index must hold
 * numSequence + 1 offsets and be readable by the GPU. Results are written
 * to the same positions in C_d.
 *
 * NOTE: as with hl_matrix_softmax, A_d is overwritten with the shifted,
 * clamped inputs.
 */
void hl_sequence_softmax_forward(real* A_d,
                                 real* C_d,
                                 const int* index,
                                 int numSequence) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);
  // The kernel dereferences index for every block.
  CHECK_NOTNULL(index);

  // Must match the kernel's template argument (see KeSequenceSoftMax).
  const int blockSize = 512;
  dim3 block(blockSize, 1);
  dim3 grid(numSequence, 1);
  KeSequenceSoftMax<blockSize><<<grid, block, 0, STREAM_DEFAULT>>>(
      C_d, A_d, index);
  CHECK_SYNC("hl_sequence_softmax_forward failed");
}

/**
 * In-place update for a row-major dimM x dimN matrix:
 *   grad[r, c] = output[r, c] * (grad[r, c] - sftmaxSum[r])
 *
 * This is the softmax backward step, assuming the caller computed
 * sftmaxSum_d[r] = sum_c grad[r, c] * output[r, c] (not verified here).
 * The kernel expects a 2-D launch with x covering rows and y covering
 * columns; threads outside the matrix do nothing.
 */
__global__ void KeMatrixDerivative(
    real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {
  int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
  int colIdx = blockIdx.y * blockDim.y + threadIdx.y;

  if (rowIdx < dimM && colIdx < dimN) {
    int index = rowIdx * dimN + colIdx;
    grad_d[index] = output_d[index] * (grad_d[index] - sftmaxSum_d[rowIdx]);
  }
}

/**
 * Launches KeMatrixDerivative over a dimM x dimN matrix, updating grad_d
 * in place.
 *
 * @param grad_d       incoming gradient, overwritten with the result.
 * @param output_d     softmax output.
 * @param sftmaxSum_d  one value per row (dimM entries).
 */
void hl_matrix_softmax_derivative(
    real* grad_d, real* output_d, real* sftmaxSum_d, int dimM, int dimN) {
  CHECK_NOTNULL(grad_d);
  CHECK_NOTNULL(output_d);
  CHECK_NOTNULL(sftmaxSum_d);

  // Grid x: one block per row (blockDim.x == 1).
  // Grid y: 1024 columns per block.
  int blocksX = dimM;
  int blocksY = (dimN + 1024 - 1) / 1024;
  dim3 threads(1, 1024);
  dim3 grid(blocksX, blocksY);

  KeMatrixDerivative<<<grid, threads, 0, STREAM_DEFAULT>>>(
      grad_d, output_d, sftmaxSum_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_softmax_derivative failed");
}

/**
 * Multi-label binary cross entropy, one thread per row of the row-major
 * dimM x dimN matrix `output`.
 *
 * For row r the kernel first subtracts log(1 - o) for every column. Then,
 * for every positive column listed in the CSR pattern (row offsets `row`,
 * column indices `col`), it subtracts log(o / (1 - o)). The net
 * contribution is therefore
 *   -sum_{negative c} log(1 - o_c) - sum_{positive c} log(o_c).
 * The value is accumulated into entropy[r] rather than assigned, so the
 * caller controls its initial value.
 */
__global__ void KeMatrixMultiBinaryCrossEntropy(
    real* output, real* entropy, int* row, int* col, int dimM, int dimN) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < dimM) {
    // Treat every column as negative first.
    for (int i = 0; i < dimN; i++) {
      entropy[index] -= log(1 - output[index * dimN + i]);
    }
    // Then correct the positive columns of this row.
    int* row_col = col + row[index];
    int col_num = row[index + 1] - row[index];
    for (int i = 0; i < col_num; i++) {
      real o = output[index * dimN + row_col[i]];
      entropy[index] -= log(o / (1 - o));
    }
  }
}

/**
 * Multi-label binary cross entropy of a dense dimM x dimN prediction
 * matrix against a sparse 0/1 label matrix given in CSR format.
 *
 * The result is accumulated into entropy (one value per row).
 *
 * @param output   dense predictions, row-major dimM x dimN.
 * @param entropy  per-row accumulator (dimM entries).
 * @param csr_mat  label pattern; must be HL_SPARSE_CSR. Its row count is
 *                 assumed to equal dimM (not checked here).
 */
void hl_matrix_multi_binary_cross_entropy(real* output,
                                          real* entropy,
                                          hl_sparse_matrix_s csr_mat,
                                          int dimM,
                                          int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(entropy);
  CHECK_NOTNULL(csr_mat);
  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);

  // One thread per row.
  int n_threads = 1024;
  int blocks = (dimM + n_threads - 1) / n_threads;
  dim3 threads(n_threads);
  dim3 grid(blocks);
  hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropy<<<grid, threads, 0, STREAM_DEFAULT>>>(
      output, entropy, mat->csr_row, mat->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
}

/**
 * Gradient of the multi-label binary cross entropy with respect to
 * `output`, one thread per row, accumulated into `grad`.
 *
 * Every column first receives +1 / (1 - o). Each positive column in the
 * CSR pattern then receives -1 / (o * (1 - o)), for a net -1 / o.
 */
__global__ void KeMatrixMultiBinaryCrossEntropyBp(
    real* output, real* grad, int* row, int* col, int dimM, int dimN) {
  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (row_idx < dimM) {
    // Typed constant: a bare 1.0 literal would force double-precision
    // division in single-precision builds.
    const real one = 1;
    for (int i = 0; i < dimN; i++) {
      int index = row_idx * dimN + i;
      grad[index] += one / (one - output[index]);
    }
    int col_num = row[row_idx + 1] - row[row_idx];
    int* row_col = col + row[row_idx];
    for (int i = 0; i < col_num; i++) {
      int index = row_idx * dimN + row_col[i];
      grad[index] -= one / (output[index] * (one - output[index]));
    }
  }
}

/**
 * Backward pass of hl_matrix_multi_binary_cross_entropy; the gradient is
 * accumulated into grad (row-major dimM x dimN).
 *
 * @param csr_mat  label pattern; must be HL_SPARSE_CSR.
 */
void hl_matrix_multi_binary_cross_entropy_bp(
    real* output, real* grad, hl_sparse_matrix_s csr_mat, int dimM, int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(csr_mat);
  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);

  // One thread per row.
  int n_threads = 1024;
  int blocks = (dimM + n_threads - 1) / n_threads;
  dim3 threads(n_threads);
  dim3 grid(blocks);
  hl_csr_matrix mat = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropyBp<<<grid, threads, 0, STREAM_DEFAULT>>>(
      output, grad, mat->csr_row, mat->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
}

/**
 * Per-row cross entropy against an integer class label:
 *   E[r] = -log(O[r, label[r] % dimN])
 * One thread per row of the row-major dimM x dimN matrix O.
 *
 * NOTE(review): a negative label yields a negative remainder and an
 * out-of-bounds read; labels are presumably always non-negative.
 */
__global__ void KeMatrixCrossEntropy(
    real* O, real* E, int* label, int dimM, int dimN) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < dimM) {
    int newBase = label[index] % dimN;
    E[index] = -log(O[index * dimN + newBase]);
  }
}

/**
 * Computes C_d[r] = -log(A_d[r, label_d[r] % dimN]) for each of the dimM
 * rows of the row-major dimM x dimN matrix A_d.
 *
 * @param A_d      predicted probabilities.
 * @param C_d      output, one value per row.
 * @param label_d  class label per row (read by the kernel).
 */
void hl_matrix_cross_entropy(
    real* A_d, real* C_d, int* label_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);
  // The kernel dereferences label_d; checked as in hl_matrix_cross_entropy_bp.
  CHECK_NOTNULL(label_d);

  // One thread per row.
  int blocks = (dimM + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);
  KeMatrixCrossEntropy<<<grid, threads, 0, STREAM_DEFAULT>>>(
      A_d, C_d, label_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_cross_entropy failed");
}

/**
 * Cross-entropy backward: for each row r, subtracts
 * 1 / output[r, label[r]] from grad[r, label[r]]. All other entries are
 * left untouched.
 *
 * The kernel expects a 2-D launch with x covering rows and y covering
 * columns. Unlike the forward kernel, labels are not reduced modulo dimN,
 * so a label outside [0, dimN) matches no column and contributes nothing.
 */
__global__ void KeMatrixCrossEntropyBp(
    real* grad_d, real* output_d, int* label_d, int dimM, int dimN) {
  int rowIdx = blockIdx.x * blockDim.x + threadIdx.x;
  int colIdx = blockIdx.y * blockDim.y + threadIdx.y;
  if (rowIdx < dimM && colIdx < dimN) {
    int index = rowIdx * dimN + colIdx;
    if (label_d[rowIdx] == colIdx) {
      grad_d[index] -= 1.0f / output_d[index];
    }
  }
}

/**
 * Launches KeMatrixCrossEntropyBp over a dimM x dimN matrix; grad_d is
 * updated in place.
 */
void hl_matrix_cross_entropy_bp(
    real* grad_d, real* output_d, int* label_d, int dimM, int dimN) {
  CHECK_NOTNULL(grad_d);
  CHECK_NOTNULL(output_d);
  CHECK_NOTNULL(label_d);

  // Grid x: one block per row (blockDim.x == 1).
  // Grid y: 1024 columns per block.
  int blocksX = dimM;
  int blocksY = (dimN + 1024 - 1) / 1024;
  dim3 threads(1, 1024);
  dim3 grid(blocksX, blocksY);
  KeMatrixCrossEntropyBp<<<grid, threads, 0, STREAM_DEFAULT>>>(
      grad_d, output_d, label_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_cross_entropy_bp failed");
}

/**
 * Sets the num elements of data to zero.
 *
 * The buffer is handed to hl_gpu_apply_unary_op as a 1 x num matrix with
 * leading dimension num (assuming that helper's (op, A, dimM, dimN, lda)
 * argument order).
 */
void hl_matrix_zero_mem(real* data, int num) {
  hl_gpu_apply_unary_op(unary::Zero<real>(), data, 1, num, num);
}

/**
 * Parametric ReLU forward over a row-major height x width matrix:
 *   output = input > 0 ? input : input * w[col / partial_sum]
 *
 * Each run of partial_sum consecutive columns shares one weight. The
 * kernel expects a 2-D launch with x covering columns and y covering rows.
 */
__global__ void KeParamReluForward(real* output,
                                   real* input,
                                   real* w,
                                   int width,
                                   int height,
                                   int partial_sum) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int ty = blockIdx.y * blockDim.y + threadIdx.y;
  if (tx < width && ty < height) {
    int index = ty * width + tx;
    output[index] =
        input[index] > 0 ? input[index] : input[index] * w[tx / partial_sum];
  }
}

/**
 * Parametric ReLU forward (see KeParamReluForward), tiled as 16 x 16
 * thread blocks over the height x width matrix.
 *
 * @param w            one weight per group of partial_sum columns.
 * @param partial_sum  number of consecutive columns sharing a weight.
 */
void hl_param_relu_forward(real* output,
                           real* input,
                           real* w,
                           int width,
                           int height,
                           int partial_sum) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(w);
  dim3 threads(16, 16);
  int blockX = (width + 16 - 1) / 16;
  int blockY = (height + 16 - 1) / 16;
  dim3 grid(blockX, blockY);
  KeParamReluForward<<<grid, threads, 0, STREAM_DEFAULT>>>(
      output, input, w, width, height, partial_sum);
  CHECK_SYNC("hl_param_relu_forward failed");
}

/**
 * Gradient of the PReLU weights.
 *
 * Block b owns weight b, i.e. columns [b * partial_sum, (b + 1) *
 * partial_sum) of every row. It computes the sum of grad_o * input over
 * those elements where input < 0, and adds that sum to grad_w[b].
 *
 * Must be launched with one block per weight and blockDim.x == blockSize.
 * The shared-memory tree reduction requires blockSize to be a power of
 * two.
 */
template <int blockSize>
__global__ void KeParamReluBackWardW(real* grad_w,
                                     real* grad_o,
                                     real* input,
                                     int width,
                                     int height,
                                     int partial_sum) {
  const int tid = threadIdx.x;
  __shared__ real temp[blockSize];
  // Shift both matrices to the first column of this block's group.
  grad_o += partial_sum * blockIdx.x;
  input += partial_sum * blockIdx.x;
  real tmp = 0.0;
  // Walk the height x partial_sum sub-matrix as a flat range.
  for (int index = tid; index < partial_sum * height; index += blockSize) {
    int row = index / partial_sum;
    int offset = row * width + (index - row * partial_sum);
    if (input[offset] < 0) {
      tmp += grad_o[offset] * input[offset];
    }
  }
  temp[tid] = tmp;
  __syncthreads();
  for (int s = blockSize / 2; s > 0; s >>= 1) {
    if (tid < s) {
      temp[tid] += temp[tid + s];
    }
    __syncthreads();
  }
  if (tid == 0) {
    grad_w[blockIdx.x] += temp[0];
  }
}

/**
 * Accumulates the PReLU weight gradient into grad_w (width / partial_sum
 * entries).
 *
 * NOTE(review): the grid size is width / partial_sum (integer division).
 * If width is not a multiple of partial_sum, the trailing columns are
 * ignored. If width < partial_sum, the grid is empty and the launch fails.
 */
void hl_param_relu_backward_w(real* grad_w,
                              real* grad_o,
                              real* input,
                              int width,
                              int height,
                              int partial_sum) {
  CHECK_NOTNULL(grad_w);
  CHECK_NOTNULL(grad_o);
  CHECK_NOTNULL(input);
  const int blockSize = 1024;
  int grid_num = width / partial_sum;
  dim3 threads(blockSize, 1);
  dim3 grid(grid_num, 1);
  KeParamReluBackWardW<blockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
      grad_w, grad_o, input, width, height, partial_sum);
  CHECK_SYNC("hl_param_relu_backward_w failed");
}

/**
 * PReLU backward with respect to the input, accumulated into diff:
 *   diff += grad_o * (input > 0 ? 1 : w[col / partial_sum])
 * The kernel expects a 2-D launch with x covering columns and y covering
 * rows of the row-major height x width matrices.
 */
__global__ void KeParamReluBackwardDiff(real* grad_o,
                                        real* input,
                                        real* w,
                                        real* diff,
                                        int width,
                                        int height,
                                        int partial_sum) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (col >= width || row >= height) {
    return;
  }
  const int idx = row * width + col;
  // Positive inputs pass the gradient through unchanged; the others are
  // scaled by the weight of their column group.
  real slope;
  if (input[idx] > 0) {
    slope = 1;
  } else {
    slope = w[col / partial_sum];
  }
  diff[idx] += grad_o[idx] * slope;
}

/**
 * PReLU backward with respect to the input (see KeParamReluBackwardDiff),
 * tiled as 16 x 16 thread blocks. The result is accumulated into diff.
 *
 * @param data  the forward-pass input (passed to the kernel as `input`).
 */
void hl_param_relu_backward_diff(real* grad_o,
                                 real* data,
                                 real* w,
                                 real* diff,
                                 int width,
                                 int height,
                                 int partial_sum) {
  CHECK_NOTNULL(grad_o);
  CHECK_NOTNULL(data);
  CHECK_NOTNULL(w);
  CHECK_NOTNULL(diff);
  dim3 threads(16, 16);
  int blockX = (width + 16 - 1) / 16;
  int blockY = (height + 16 - 1) / 16;
  dim3 grid(blockX, blockY);
  KeParamReluBackwardDiff<<<grid, threads, 0, STREAM_DEFAULT>>>(
      grad_o, data, w, diff, width, height, partial_sum);
  CHECK_SYNC("hl_param_relu_backward_diff failed");
}

/**
 * Adds a per-channel bias to a row-major M x N matrix whose columns are
 * grouped into `channel` groups of dim = N / channel columns:
 *   A[r, c] += scale * B[c / dim]
 * One thread per element.
 */
__global__ void KeMatrixAddSharedBias(
    real* A, real* B, const int channel, const int M, const int N, real scale) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int dim = N / channel;
  if (index < M * N) {
    // Column within the row, then the channel that column belongs to.
    int i = index % N;
    i = i / dim;
    A[index] += scale * B[i];
  }
}

/**
 * A_d[r, c] += scale * B_d[c / (dimN / channel)] for a row-major
 * dimM x dimN matrix A_d and a bias vector B_d of `channel` entries.
 */
void hl_matrix_add_shared_bias(real* A_d,
                               real* B_d,
                               const int channel,
                               const int dimM,
                               const int dimN,
                               real scale) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  // One thread per element of A_d.
  const int threads = 512;
  const int blocks = DIVUP(dimM * dimN, threads);
  KeMatrixAddSharedBias<<<blocks, threads, 0, STREAM_DEFAULT>>>(
      A_d, B_d, channel, dimM, dimN, scale);
  CHECK_SYNC("hl_matrix_add_shared_bias failed");
}

/**
 * Per-channel column reduction. A is a row-major M x N matrix with
 * N = channel * dim. For each channel c the kernel adds
 *   scale * sum_{m < M, k < dim} A[m * N + c * dim + k]
 * to B[c].
 *
 * The strategy is selected by the per-channel element count M * dim. This
 * is the same quantity hl_matrix_collect_shared_bias compares against
 * `limit` to size the grid, so the branch always matches the launch:
 *  - M * dim < limit: one thread per channel sums serially (expects at
 *    least `channel` threads in total).
 *  - otherwise: block `bid` reduces channel `bid` in blockSize-element
 *    chunks through shared memory (expects one block per channel and
 *    blockDim.x == blockSize; simpleReduce presumably needs blockSize to be
 *    a power of two).
 *
 * The branch condition and the loop trip count are uniform across a block,
 * so the barriers in the reduction path are reached by all of its threads.
 */
template <int blockSize>
__global__ void KeMatrixCollectSharedBias(real* B,
                                          real* A,
                                          const int channel,
                                          const int M,
                                          const int N,
                                          const int dim,
                                          const int limit,
                                          real scale) {
  if (M * dim < limit) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < channel) {
      real sum = 0.0;
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < dim; ++j) {
          sum += A[i * N + index * dim + j];
        }
      }
      B[index] += scale * sum;
    }
  } else {
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    __shared__ real smem[blockSize];
    real sum = 0.0;
    for (int j = 0; j < ((dim * M + blockSize - 1) / blockSize); ++j) {
      int n = j * blockSize + tid;
      int m = n / dim;
      int w = n % dim;
      smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0;
      __syncthreads();
      simpleReduce(smem, tid, blockSize);
      sum += smem[0];
      // Every thread must finish reading smem[0] before any thread
      // overwrites smem in the next iteration.
      __syncthreads();
    }
    if (tid == 0) {
      B[bid] += scale * sum;
    }
  }
}

/**
 * Per-channel bias gradient:
 *   B_d[c] += scale * sum of A_d[r, c * dim + k] over r < dimM, k < dim,
 * where dim = dimN / channel and A_d is row-major dimM x dimN.
 *
 * The grid size depends on the per-channel element count dimM * dim.
 * Below `limit`, DIVUP(channel, blocks) blocks are launched; otherwise one
 * block per channel.
 */
void hl_matrix_collect_shared_bias(real* B_d,
                                   real* A_d,
                                   const int channel,
                                   const int dimM,
                                   const int dimN,
                                   real scale) {
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(A_d);
  const int dim = dimN / channel;
  const int blocks = 256;
  const int limit = 64;
  int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel;

  KeMatrixCollectSharedBias<blocks><<<grids, blocks, 0, STREAM_DEFAULT>>>(
      B_d, A_d, channel, dimM, dimN, dim, limit, scale);
  CHECK_SYNC("hl_matrix_collect_shared_bias failed");
}

/**
 * Writes a 90-degree rotation of the row-major dimM x dimN matrix `mat`
 * into matRot, which is dimN x dimM. One thread per element:
 *   clockWise:  matRot[j][i] = mat[dimM - 1 - i][j]
 *   otherwise:  matRot[j][i] = mat[i][dimN - 1 - j]
 */
__global__ void keMatrixRotate(
    real* mat, real* matRot, int dimM, int dimN, bool clockWise) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < dimM * dimN) {
    int i = idx / dimN;
    int j = idx % dimN;
    if (clockWise) {
      matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
    } else {
      matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
    }
  }
}

/**
 * Rotates the dimM x dimN matrix `mat` by 90 degrees into matRot
 * (dimN x dimM), clockwise or counter-clockwise. matRot must hold
 * dimM * dimN elements and should not alias mat.
 */
void hl_matrix_rotate(
    real* mat, real* matRot, int dimM, int dimN, bool clockWise) {
  CHECK_NOTNULL(mat);
  CHECK_NOTNULL(matRot);
  const int threads = 512;
  const int blocks = DIVUP(dimM * dimN, threads);
  keMatrixRotate<<<blocks, threads, 0, STREAM_DEFAULT>>>(
      mat, matRot, dimM, dimN, clockWise);
  CHECK_SYNC("hl_matrix_rotate failed");
}

/**
 * vol2col for 3-D convolution.
 *
 * Grid-stride loop over num_kernels = channels * depth_col * height_col *
 * width_col work items. Each item handles one (channel_in, d_out, h_out,
 * w_out) output position. It copies the filterD x filterH x filterW input
 * patch seen by that position into dataDst, writing zeros for taps that
 * fall into the padding.
 *
 * dataDst layout (row-major):
 *   [channels * filterD * filterH * filterW] x [depth_col * height_col *
 *   width_col]
 *
 * The per-item source/destination pointers are kept in locals. Advancing
 * the dataSrc/dataDst parameters themselves would corrupt the addresses of
 * any later iteration of the grid-stride loop.
 */
__global__ void keMatrixVol2Col(int num_kernels,
                                real* dataSrc,
                                real* dataDst,
                                int depth,
                                int height,
                                int width,
                                int filterD,
                                int filterH,
                                int filterW,
                                int strideD,
                                int strideH,
                                int strideW,
                                int paddingD,
                                int paddingH,
                                int paddingW,
                                int depth_col,
                                int height_col,
                                int width_col) {
  // Distance between consecutive filter taps in dataDst.
  const int colSize = depth_col * height_col * width_col;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int d_out = (index / width_col / height_col) % depth_col;
    int channel_in = index / width_col / height_col / depth_col;
    int channel_out = channel_in * filterD * filterH * filterW;
    int w_in = w_out * strideW - paddingW;
    int h_in = h_out * strideH - paddingH;
    int d_in = d_out * strideD - paddingD;

    real* dst = dataDst +
                ((channel_out * depth_col + d_out) * height_col + h_out) *
                    width_col +
                w_out;
    const real* src =
        dataSrc + ((channel_in * depth + d_in) * height + h_in) * width + w_in;
    for (int k = 0; k < filterD; ++k) {
      for (int i = 0; i < filterH; ++i) {
        for (int j = 0; j < filterW; ++j) {
          int d = d_in + k;
          int h = h_in + i;
          int w = w_in + j;
          // Taps outside the input volume read as zero (padding).
          *dst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
                  w < width)
                     ? src[(k * height + i) * width + j]
                     : 0;
          dst += colSize;
        }
      }
    }
  }
}

/**
 * Expands a (channels x depth x height x width) volume into the column
 * matrix used by 3-D convolution (see keMatrixVol2Col for the layout).
 *
 * Output spatial extents follow the usual convolution formula:
 *   *_col = (size + 2 * padding - filter) / stride + 1
 *
 * The kernel is launched on STREAM_DEFAULT, like every other kernel in
 * this file, so CHECK_SYNC and callers observe the same stream ordering.
 */
void hl_matrix_vol2Col(real* dataSrc,
                       int channels,
                       int depth,
                       int height,
                       int width,
                       int filterD,
                       int filterH,
                       int filterW,
                       int strideD,
                       int strideH,
                       int strideW,
                       int paddingD,
                       int paddingH,
                       int paddingW,
                       real* dataDst) {
  CHECK_NOTNULL(dataSrc);
  CHECK_NOTNULL(dataDst);
  int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
  int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
  int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
  int num_kernels = channels * depth_col * height_col * width_col;

  const int threads = 512;
  const int blocks = DIVUP(num_kernels, threads);

  keMatrixVol2Col<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                          dataSrc,
                                                          dataDst,
                                                          depth,
                                                          height,
                                                          width,
                                                          filterD,
                                                          filterH,
                                                          filterW,
                                                          strideD,
                                                          strideH,
                                                          strideW,
                                                          paddingD,
                                                          paddingH,
                                                          paddingW,
                                                          depth_col,
                                                          height_col,
                                                          width_col);
  CHECK_SYNC("hl_matrix_vol2Col failed");
}

/**
 * col2vol, the inverse scatter of vol2col.
 *
 * Grid-stride loop with one work item per input voxel (num_kernels =
 * channels * depth * height * width). Each item accumulates every dataSrc
 * column entry that corresponds to its voxel and stores
 *   dataDst[index] = alpha * sum + beta * dataDst[index].
 *
 * w/h/d are voxel coordinates shifted into the padded frame. The
 * [*_col_start, *_col_end) ranges select the output positions whose filter
 * window covers the voxel. offset and coeff_* fold the dataSrc index into
 * one linear expression in (d_col, h_col, w_col).
 */
__global__ void keMatrixCol2Vol(int num_kernels,
                                real* dataDst,
                                real* dataSrc,
                                int depth,
                                int height,
                                int width,
                                int filterD,
                                int filterH,
                                int filterW,
                                int strideD,
                                int strideH,
                                int strideW,
                                int paddingD,
                                int paddingH,
                                int paddingW,
                                int depth_col,
                                int height_col,
                                int width_col,
                                real alpha,
                                real beta) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    real srcVal = 0;
    real dstVal = dataDst[index];
    int w = index % width + paddingW;
    int h = (index / width) % height + paddingH;
    int d = (index / width / height) % depth + paddingD;
    int c = index / width / height / depth;
    // compute the start and end of the output
    int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
    int w_col_end = min(w / strideW + 1, width_col);
    int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1;
    int h_col_end = min(h / strideH + 1, height_col);
    int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
    int d_col_end = min(d / strideD + 1, depth_col);

    int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
                  h * filterW + w) *
                 depth_col * height_col * width_col;

    int coeff_d_col =
        (1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
    int coeff_h_col =
        (1 - strideH * filterW * depth_col * height_col) * width_col;
    int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);

    for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
          srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
                            w_col * coeff_w_col];
        }
      }
    }
    dataDst[index] = alpha * srcVal + beta * dstVal;
  }
}

/**
 * Folds a vol2col column matrix (dataSrc) back into a (channels x depth x
 * height x width) volume:
 *   dataDst = alpha * col2vol(dataSrc) + beta * dataDst
 *
 * The kernel is launched on STREAM_DEFAULT, like every other kernel in
 * this file.
 */
void hl_matrix_col2Vol(real* dataDst,
                       int channels,
                       int depth,
                       int height,
                       int width,
                       int filterD,
                       int filterH,
                       int filterW,
                       int strideD,
                       int strideH,
                       int strideW,
                       int paddingD,
                       int paddingH,
                       int paddingW,
                       real* dataSrc,
                       real alpha,
                       real beta) {
  CHECK_NOTNULL(dataDst);
  CHECK_NOTNULL(dataSrc);
  int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
  int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
  int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
  int num_kernels = channels * depth * height * width;

  const int threads = 512;
  const int blocks = DIVUP(num_kernels, threads);

  keMatrixCol2Vol<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                          dataDst,
                                                          dataSrc,
                                                          depth,
                                                          height,
                                                          width,
                                                          filterD,
                                                          filterH,
                                                          filterW,
                                                          strideD,
                                                          strideH,
                                                          strideW,
                                                          paddingD,
                                                          paddingH,
                                                          paddingW,
                                                          depth_col,
                                                          height_col,
                                                          width_col,
                                                          alpha,
                                                          beta);

  CHECK_SYNC("hl_matrix_col2Vol failed");
}