hl_cuda_aggregate.cu 7.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

L
liaogang 已提交
15
#include "hl_aggregate.h"
Z
zhangjinchao01 已提交
16 17 18 19
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_cuda.ph"
#include "hl_matrix_base.cuh"
L
liaogang 已提交
20
#include "hl_thread.ph"
Z
zhangjinchao01 已提交
21 22 23 24 25
#include "paddle/utils/Logging.h"

/**
 * @brief   matrix row operator.
 *
 * Reduces each row of the row-major matrix E (dimN elements per row) with
 * `agg` and writes the per-row result to Sum[rowId].
 *
 * Launch layout: one block per row, rowId = blockIdx.x + blockIdx.y *
 * gridDim.x, with blockDim.x == blockSize. `blockSize` must be a power of
 * two for the shared-memory tree reduction. Uses blockSize reals of static
 * shared memory. There is no rowId bounds check, so the grid must contain
 * exactly one block per row.
 *
 * @param agg   aggregation functor providing init() and operator()(a, b).
 * @param E     input matrix (device memory).
 * @param Sum   output (device memory), one element per row.
 * @param dimN  number of columns.
 */
template <class Agg, int blockSize>
__global__ void KeMatrixRowOp(Agg agg, real *E, real *Sum, int dimN) {
  __shared__ real sum_s[blockSize];
  int rowId = blockIdx.x + blockIdx.y * gridDim.x;
  int tid = threadIdx.x;

  // Build the row base pointer with a 64-bit offset: rowId * dimN in plain
  // int arithmetic overflows once the matrix exceeds INT_MAX elements.
  real *row = E + static_cast<size_t>(rowId) * dimN;

  // Each thread folds columns tid, tid + blockSize, tid + 2*blockSize, ...
  real tmp = agg.init();
  for (int col = tid; col < dimN; col += blockSize) {
    tmp = agg(tmp, row[col]);
  }
  sum_s[tid] = tmp;
  __syncthreads();

  // Tree reduction in shared memory. The barrier closing each step also
  // orders the final read of sum_s[0] below, so no extra barrier is needed.
  for (int stride = blockSize / 2; stride > 0; stride /= 2) {
    if (tid < stride) {
      sum_s[tid] = agg(sum_s[tid], sum_s[tid + stride]);
    }
    __syncthreads();
  }

  if (tid == 0) {
    Sum[rowId] = sum_s[0];
  }
}

/**
 * @brief   Reduce every row of a dimM x dimN device matrix with `agg`.
 *
 * Launches KeMatrixRowOp on STREAM_DEFAULT with one 128-thread block per
 * row. The launch is asynchronous; callers are responsible for checking
 * errors and synchronizing.
 *
 * @param agg   aggregation functor.
 * @param A_d   input matrix (device), row-major, dimM rows of dimN elements.
 * @param C_d   output (device), dimM elements.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
template <class Agg>
void hl_matrix_row_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
  // With no rows there is nothing to compute, and a grid with zero blocks
  // is an invalid launch configuration, so skip the launch entirely.
  if (dimM <= 0) {
    return;
  }

  const int kThreads = 128;
  dim3 threads(kThreads, 1);
  dim3 grid(dimM, 1);

  KeMatrixRowOp<Agg, kThreads><<<grid, threads, 0, STREAM_DEFAULT>>>(
      agg, A_d, C_d, dimN);
}

/**
 * @brief   Row-wise sum of a dimM x dimN device matrix (aggregate::sum).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimM elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_row_sum(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::sum op = aggregate::sum();
  hl_matrix_row_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_row_sum failed");
}

/**
 * @brief   Row-wise maximum of a dimM x dimN device matrix (aggregate::max).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimM elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_row_max(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::max op = aggregate::max();
  hl_matrix_row_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_row_max failed");
}

/**
 * @brief   Row-wise minimum of a dimM x dimN device matrix (aggregate::min).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimM elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_row_min(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::min op = aggregate::min();
  hl_matrix_row_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_row_min failed");
}

/**
 * @brief   matrix column operator.
 *
 * Simple variant with one thread per column. The thread for column `colIdx`
 * folds the dimM elements of that column of the row-major dimM x dimN
 * matrix E with `agg` and writes the result to Sum[colIdx]. Launch a 1-D
 * grid covering at least dimN threads in x. Threads past the last column do
 * nothing.
 */
template <class Agg>
__global__ void KeMatrixColumnOp(
    Agg agg, real *E, real *Sum, int dimM, int dimN) {
  int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (colIdx < dimN) {
    real tmp = agg.init();
    // Step down the column one row (dimN elements) at a time using a 64-bit
    // offset; dimN * index in int arithmetic overflows for matrices larger
    // than INT_MAX elements. Adjacent threads read adjacent addresses.
    size_t offset = colIdx;
    for (int index = 0; index < dimM; index++, offset += dimN) {
      tmp = agg(tmp, E[offset]);
    }
    Sum[colIdx] = tmp;
  }
}
/**
 * @brief   Shared-memory column operator.
 *
 * Block layout: blockDimX x blockDimY threads, which must match blockDim.
 * Thread (x, y) folds rows y, y + blockDimY, y + 2*blockDimY, ... of column
 * blockIdx.x * blockDim.x + x. The blockDimY partials per column are staged
 * in shared memory (blockDimX * blockDimY reals). The y == 0 thread combines
 * them and writes Sum[column].
 */
template <class Agg, int blockDimX, int blockDimY>
__global__ void KeMatrixColumnOp_S(
    Agg agg, real *E, real *Sum, int dimM, int dimN) {
  __shared__ real _sum[blockDimX * blockDimY];
  int colIdx = blockIdx.x * blockDim.x + threadIdx.x;

  real tmp = agg.init();
  if (colIdx < dimN) {
    // 64-bit offsets: dimN * row in int arithmetic overflows for matrices
    // with more than INT_MAX elements.
    size_t offset = static_cast<size_t>(threadIdx.y) * dimN + colIdx;
    const size_t step = static_cast<size_t>(blockDimY) * dimN;
    for (int row = threadIdx.y; row < dimM; row += blockDimY, offset += step) {
      tmp = agg(tmp, E[offset]);
    }
  }
  // Every thread, including those past the last column (which store
  // agg.init()), writes its slot before the barrier, so the barrier is
  // reached by the whole block.
  _sum[threadIdx.x + threadIdx.y * blockDimX] = tmp;
  __syncthreads();

  if (colIdx < dimN && threadIdx.y == 0) {
    real acc = agg.init();
    for (int i = 0; i < blockDimY; i++) {
      acc = agg(acc, _sum[threadIdx.x + i * blockDimX]);
    }
    Sum[colIdx] = acc;
  }
}

/**
 * @brief   Reduce every column of a dimM x dimN device matrix with `agg`.
 *
 * Wide matrices (dimN >= 8192) use KeMatrixColumnOp with one thread per
 * column in 128-thread blocks. Narrower matrices use KeMatrixColumnOp_S with
 * 32 x 32 blocks, where 32 threads cooperate on each column. Launches on
 * STREAM_DEFAULT; callers are responsible for checking errors and
 * synchronizing.
 *
 * @param agg   aggregation functor.
 * @param A_d   input matrix (device), row-major, dimM rows of dimN elements.
 * @param C_d   output (device), dimN elements.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
template <class Agg>
void hl_matrix_column_op(Agg agg, real *A_d, real *C_d, int dimM, int dimN) {
  // No columns means no output. It would also produce a zero-block grid,
  // which is an invalid launch configuration, so skip the launch.
  if (dimN <= 0) {
    return;
  }

  if (dimN >= 8192) {
    const int kThreads = 128;
    dim3 threads(kThreads, 1);
    dim3 grid((dimN + kThreads - 1) / kThreads, 1);
    KeMatrixColumnOp<Agg><<<grid, threads, 0, STREAM_DEFAULT>>>(
        agg, A_d, C_d, dimM, dimN);
  } else {
    const int kTile = 32;
    dim3 threads(kTile, kTile);
    dim3 grid((dimN + kTile - 1) / kTile, 1);
    KeMatrixColumnOp_S<Agg, kTile, kTile>
        <<<grid, threads, 0, STREAM_DEFAULT>>>(agg, A_d, C_d, dimM, dimN);
  }
}

/**
 * @brief   Column-wise sum of a dimM x dimN device matrix (aggregate::sum).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimN elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_column_sum(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::sum op = aggregate::sum();
  hl_matrix_column_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_column_sum failed");
}

/**
 * @brief   Column-wise maximum of a dimM x dimN device matrix
 *          (aggregate::max).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimN elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_column_max(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::max op = aggregate::max();
  hl_matrix_column_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_column_max failed");
}

/**
 * @brief   Column-wise minimum of a dimM x dimN device matrix
 *          (aggregate::min).
 *
 * @param A_d   input matrix (device), checked with CHECK_NOTNULL.
 * @param C_d   output (device), dimN elements, checked with CHECK_NOTNULL.
 * @param dimM  number of rows.
 * @param dimN  number of columns.
 */
void hl_matrix_column_min(real *A_d, real *C_d, int dimM, int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  aggregate::min op = aggregate::min();
  hl_matrix_column_op(op, A_d, C_d, dimM, dimN);
  CHECK_SYNC("hl_matrix_column_min failed");
}

/**
 * @brief   Per-block partial sums of a vector.
 *
 * Blocks are laid out along y. Thread x of block y starts at element
 * y * blockDim.x + x and strides by blockDim.x * gridDim.y, so the grid as a
 * whole visits all dimM elements. Each block writes its partial sum to
 * Sum[blockIdx.y]. Accumulation is done in double precision. `blockSize` must
 * equal blockDim.x and be a power of two; the kernel uses blockSize doubles of
 * static shared memory.
 */
template <int blockSize>
__global__ void KeVectorSum(real *E, real *Sum, int dimM) {
  __shared__ double sum_s[blockSize];
  const int lane = threadIdx.x;
  const int step = blockDim.x * gridDim.y;

  // Accumulate this thread's strided elements in a register first.
  double acc = 0.0;
  for (int i = blockIdx.y * blockDim.x + lane; i < dimM; i += step) {
    acc += E[i];
  }
  sum_s[lane] = acc;
  __syncthreads();

  // Halving tree reduction; the barrier closing each step also orders the
  // final read of sum_s[0].
  for (int half = blockSize >> 1; half > 0; half >>= 1) {
    if (lane < half) {
      sum_s[lane] += sum_s[lane + half];
    }
    __syncthreads();
  }

  if (lane == 0) {
    Sum[blockIdx.y] = sum_s[0];
  }
}

/**
 * @brief   Sum all dimM elements of device vector A_d into host scalar *C_h.
 *
 * Two-pass reduction on STREAM_DEFAULT using the scratch buffers in
 * t_resource:
 *   1. kGridSize blocks write one partial sum each to t_resource.gpu_mem.
 *   2. A single block folds those kGridSize partials into t_resource.cpu_mem.
 * The result is then copied to C_h. NOTE(review): pass 2 writes cpu_mem from
 * the device, so it presumably is mapped (zero-copy) host memory — confirm in
 * hl_thread.ph.
 *
 * Before reusing the scratch buffers, the function spins until
 * t_resource.event, which is recorded at the end of the previous reduction
 * that used them, is ready. It blocks until the stream is idle and
 * CHECK-fails on any CUDA error.
 */
void hl_vector_sum(real *A_d, real *C_h, int dimM) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_h);

  // Compile-time constants: kBlockSize is the kernel's template argument,
  // and kGridSize is both the number of partials produced by pass 1 and the
  // element count consumed by pass 2. Previously these were separate
  // hard-coded 128s that could silently drift apart.
  const int kBlockSize = 128;
  const int kGridSize = 128;
  dim3 threads(kBlockSize, 1);
  dim3 grid(1, kGridSize);

  struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
  hl_event_t hl_event = &hl_event_st;
  // Wait until the previous user of the scratch buffers has finished.
  while (!hl_cuda_event_is_ready(hl_event)) {
  }

  KeVectorSum<kBlockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
      A_d, t_resource.gpu_mem, dimM);
  KeVectorSum<kBlockSize><<<1, threads, 0, STREAM_DEFAULT>>>(
      t_resource.gpu_mem, t_resource.cpu_mem, kGridSize);

  hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
  hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);

  hl_stream_synchronize(HPPL_STREAM_DEFAULT);
  cudaError_t err = (cudaError_t)hl_get_device_last_error();
  CHECK_EQ(cudaSuccess, err) << "CUDA error: "
                             << hl_get_device_error_string((size_t)err);
}

/**
 * @brief   Per-block partial sums of absolute values of a vector.
 *
 * Same layout as the plain vector-sum kernel: blocks along y, thread x of
 * block y starts at y * blockDim.x + x and strides by blockDim.x * gridDim.y.
 * Each block writes sum(|E[i]|) over its elements to Sum[blockIdx.y].
 * Accumulation is done in double precision. `blockSize` must equal
 * blockDim.x and be a power of two; the kernel uses blockSize doubles of
 * static shared memory.
 */
template <int blockSize>
__global__ void KeVectorAbsSum(real *E, real *Sum, int dimM) {
  __shared__ double sum_s[blockSize];
  const int tid = threadIdx.x;
  const int step = blockDim.x * gridDim.y;

  // fabs() is the unambiguous floating-point absolute value. A bare abs()
  // depends on overload resolution picking the floating overload rather than
  // the integer ::abs(int).
  double acc = 0.0;
  for (int i = blockIdx.y * blockDim.x + tid; i < dimM; i += step) {
    acc += fabs(E[i]);
  }
  sum_s[tid] = acc;
  __syncthreads();

  // Tree reduction; the barrier closing each step also orders the final
  // read of sum_s[0].
  for (int stride = blockSize / 2; stride > 0; stride /= 2) {
    if (tid < stride) {
      sum_s[tid] += sum_s[tid + stride];
    }
    __syncthreads();
  }

  if (tid == 0) {
    Sum[blockIdx.y] = sum_s[0];
  }
}

/**
 * @brief   Sum of absolute values of device vector A_d into host scalar *C_h.
 *
 * Two-pass reduction on STREAM_DEFAULT using the scratch buffers in
 * t_resource:
 *   1. kGridSize blocks write one partial |x| sum each to t_resource.gpu_mem.
 *   2. A single block folds those kGridSize partials (already non-negative)
 *      into t_resource.cpu_mem.
 * The result is then copied to C_h. NOTE(review): pass 2 writes cpu_mem from
 * the device, so it presumably is mapped (zero-copy) host memory — confirm in
 * hl_thread.ph.
 *
 * Before reusing the scratch buffers, the function spins until
 * t_resource.event, which is recorded at the end of the previous reduction
 * that used them, is ready. It blocks until the stream is idle and
 * CHECK-fails on any CUDA error.
 */
void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_h);

  // Compile-time constants: kBlockSize is the kernel's template argument,
  // and kGridSize is both the number of partials produced by pass 1 and the
  // element count consumed by pass 2. Previously these were separate
  // hard-coded 128s that could silently drift apart.
  const int kBlockSize = 128;
  const int kGridSize = 128;
  dim3 threads(kBlockSize, 1);
  dim3 grid(1, kGridSize);

  struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
  hl_event_t hl_event = &hl_event_st;
  // Wait until the previous user of the scratch buffers has finished.
  while (!hl_cuda_event_is_ready(hl_event)) {
  }

  KeVectorAbsSum<kBlockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
      A_d, t_resource.gpu_mem, dimM);
  KeVectorAbsSum<kBlockSize><<<1, threads, 0, STREAM_DEFAULT>>>(
      t_resource.gpu_mem, t_resource.cpu_mem, kGridSize);

  hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
  hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);

  hl_stream_synchronize(HPPL_STREAM_DEFAULT);
  cudaError_t err = (cudaError_t)hl_get_device_last_error();
  CHECK_EQ(cudaSuccess, err) << "CUDA error: "
                             << hl_get_device_error_string((size_t)err);
}