/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "paddle/utils/Logging.h"

// using namespace hppl;

/**
 * A (value, index) candidate handled by the top-k kernels.
 *
 * Ordering: a larger v_ ranks higher; when v_ ties, the smaller id_ ranks
 * higher. Slots with no candidate hold (-HL_FLOAT_MAX, -1).
 */
struct Pair {
  // Left empty on purpose; Pair is stored in __shared__ arrays (shTopK),
  // presumably why no member initialization is done here.
  __device__ __forceinline__ Pair() {}

  __device__ __forceinline__ Pair(real value, int id) : v_(value), id_(id) {}

  /** Overwrites both fields. */
  __device__ __forceinline__ void set(real value, int id) {
    v_ = value;
    id_ = id;
  }

  /** Copies both fields from `in` (returns nothing, so no chaining). */
  __device__ __forceinline__ void operator=(const Pair& in) {
    set(in.v_, in.id_);
  }

  /** Value-only comparison: true when v_ is smaller than `value`. */
  __device__ __forceinline__ bool operator<(const real value) const {
    return v_ < value;
  }

  /** True when *this ranks below `in` (smaller value, or tie with larger id). */
  __device__ __forceinline__ bool operator<(const Pair& in) const {
    if (v_ != in.v_) return v_ < in.v_;
    return id_ > in.id_;
  }

  /** True when *this ranks above `in` (larger value, or tie with smaller id). */
  __device__ __forceinline__ bool operator>(const Pair& in) const {
    if (v_ != in.v_) return v_ > in.v_;
    return id_ < in.id_;
  }

  real v_;  // candidate score
  int id_;  // candidate index; -1 marks an empty slot
};
/**
 * Inserts `p` into topK[0 .. beamSize), which is kept in descending Pair
 * order. Entries ranked below `p` move one slot toward the end, and the
 * previous topK[beamSize - 1] falls off.
 * The getTopK callers only call this after checking that p's value beats
 * topK[beamSize - 1].
 */
__device__ __forceinline__ void addTo(Pair topK[],
                                      const Pair& p,
                                      int beamSize) {
  int slot = 0;  // final position of p if it outranks every kept entry
  for (int k = beamSize - 2; k >= 0; --k) {
    if (!(topK[k] < p)) {
      slot = k + 1;
      break;
    }
    topK[k + 1] = topK[k];
  }
  topK[slot] = p;
}
/**
 * Compile-time-length variant of addTo: inserts `p` into the descending
 * list topK[0 .. beamSize). Lower-ranked entries shift toward the end and
 * the last entry is dropped.
 */
template <int beamSize>
__device__ __forceinline__ void addTo(Pair topK[], const Pair& p) {
  int k = beamSize - 2;
  while (k >= 0 && topK[k] < p) {
    topK[k + 1] = topK[k];
    --k;
  }
  // k >= 0: topK[k] does not rank below p, so p goes right after it.
  topK[k >= 0 ? k + 1 : 0] = p;
}
/**
 * Scans src[idx], src[idx + blockSize], ... (all positions < dim). Every
 * value that beats the current tail topK[beamSize - 1] is inserted into
 * topK together with its position.
 * With idx = tid, each thread of the block covers its own strided subset.
 */
template <int blockSize>
__device__ __forceinline__ void getTopK(
    Pair topK[], real* src, int idx, int dim, int beamSize) {
  for (; idx < dim; idx += blockSize) {
    const real value = src[idx];
    if (topK[beamSize - 1] < value) {
      addTo(topK, Pair(value, idx), beamSize);
    }
  }
}
/**
 * Like the unbounded dense getTopK, but only accepts candidates that rank
 * below `max` under Pair ordering. This skips values already emitted in an
 * earlier round.
 */
template <int blockSize>
__device__ __forceinline__ void getTopK(
    Pair topK[], real* src, int idx, int dim, const Pair& max, int beamSize) {
  for (; idx < dim; idx += blockSize) {
    const real value = src[idx];
    if (!(topK[beamSize - 1] < value)) continue;
    const Pair candidate(value, idx);
    if (candidate < max) addTo(topK, candidate, beamSize);
  }
}
/**
 * Sparse variant: scans val[idx], val[idx + blockSize], ... (positions
 * < dim). Qualifying values are inserted with their column id col[idx]
 * rather than their position.
 */
template <int blockSize>
__device__ __forceinline__ void getTopK(
    Pair topK[], real* val, int* col, int idx, int dim, int beamSize) {
  for (; idx < dim; idx += blockSize) {
    const real value = val[idx];
    if (topK[beamSize - 1] < value) {
      addTo(topK, Pair(value, col[idx]), beamSize);
    }
  }
}
/**
 * Sparse variant bounded by `max`: inserts (val[idx], col[idx]) only when
 * it beats the current tail AND ranks below `max` under Pair ordering.
 */
template <int blockSize>
__device__ __forceinline__ void getTopK(Pair topK[],
                                        real* val,
                                        int* col,
                                        int idx,
                                        int dim,
                                        const Pair& max,
                                        int beamSize) {
  for (; idx < dim; idx += blockSize) {
    const real value = val[idx];
    if (!(topK[beamSize - 1] < value)) continue;
    const Pair candidate(value, col[idx]);
    if (candidate < max) addTo(topK, candidate, beamSize);
  }
}
/**
 * Refills this thread's local candidate list topK[0 .. maxLength) from a
 * dense row.
 *
 * `beam` is the number of leading entries of topK that blockReduce has
 * consumed since the last refill; the kernels start it at maxLength to force
 * the initial fill. On the first call the list is filled from scratch. On
 * later calls the consumed entries are shifted out, and the freed tail is
 * refilled with values ranking below `max` (the previous tail). Afterwards
 * `max` holds the new tail and `isEmpty` latches once a tail slot stays
 * unfilled (id_ == -1). `beam` is reset to 0.
 */
template <int maxLength, int blockSize>
__device__ __forceinline__ void threadGetTopK(Pair topK[],
                                              int& beam,
                                              int beamSize,
                                              real* src,
                                              bool& firstStep,
                                              bool& isEmpty,
                                              Pair& max,
                                              int dim,
                                              const int tid) {
  if (beam <= 0) return;  // nothing consumed, nothing to refill

  const int want = (beamSize < beam) ? beamSize : beam;

  if (firstStep) {
    firstStep = false;
    getTopK<blockSize>(topK, src, tid, dim, want);
  } else {
    // Shift out the `beam` entries already emitted; blank the vacated tail.
    for (int k = 0; k < maxLength; ++k) {
      if (k + beam < maxLength) {
        topK[k] = topK[k + beam];
      } else {
        topK[k].set(-HL_FLOAT_MAX, -1);
      }
    }
    if (!isEmpty) {
      getTopK<blockSize>(topK + maxLength - beam, src, tid, dim, max, want);
    }
  }

  max = topK[maxLength - 1];
  if (max.id_ == -1) isEmpty = true;
  beam = 0;
}
/**
 * Sparse counterpart of the dense threadGetTopK: refills this thread's
 * local list topK[0 .. maxLength) from (val, col), using column ids as the
 * candidate ids. The `beam`, `firstStep`, `isEmpty` and `max` state is
 * updated exactly as in the dense version.
 */
template <int maxLength, int blockSize>
__device__ __forceinline__ void threadGetTopK(Pair topK[],
                                              int& beam,
                                              int beamSize,
                                              real* val,
                                              int* col,
                                              bool& firstStep,
                                              bool& isEmpty,
                                              Pair& max,
                                              int dim,
                                              const int tid) {
  if (beam <= 0) return;  // nothing consumed, nothing to refill

  const int want = (beamSize < beam) ? beamSize : beam;

  if (firstStep) {
    firstStep = false;
    getTopK<blockSize>(topK, val, col, tid, dim, want);
  } else {
    // Shift out the `beam` entries already emitted; blank the vacated tail.
    for (int k = 0; k < maxLength; ++k) {
      if (k + beam < maxLength) {
        topK[k] = topK[k + beam];
      } else {
        topK[k].set(-HL_FLOAT_MAX, -1);
      }
    }
    if (!isEmpty) {
      getTopK<blockSize>(
          topK + maxLength - beam, val, col, tid, dim, max, want);
    }
  }

  max = topK[maxLength - 1];
  if (max.id_ == -1) isEmpty = true;
  beam = 0;
}
/**
 * Block-wide selection loop shared by the top-k kernels.
 *
 * Each thread has published its current best candidate in shTopK[tid].
 * Every iteration does the following:
 *  - find the maximum of shTopK with a shared-memory tree reduction into
 *    maxId;
 *  - thread 0 appends it to **topVal / **topIds and advances both cursors;
 *  - the owning thread increments `beam` and republishes topK[beam].
 *
 * Returns when `beamSize` (decremented once per emitted result) reaches 0.
 * It also returns when the owning thread has emitted all maxLength of its
 * local candidates, so the caller can refill them with threadGetTopK.
 *
 * Expects blockDim.x == blockSize. The halving reduction appears to assume
 * blockSize is a power of two and a multiple of 32 — TODO confirm.
 * Uses __shfl_sync (CUDA 9+).
 */
template <int maxLength, int blockSize>
__device__ __forceinline__ void blockReduce(Pair* shTopK,
                                            int* maxId,
                                            Pair topK[],
                                            real** topVal,
                                            int** topIds,
                                            int& beam,
                                            int& beamSize,
                                            const int tid,
                                            const int warp) {
  // The shuffle below is guarded by `maxId[0] / 32 == warp`. maxId[0] lives
  // in shared memory and is not rewritten between the barrier above the
  // guard and the next loop iteration, so the guard is uniform within a warp
  // and all 32 lanes execute the shuffle. The mask must therefore name the
  // whole warp. (It used to be 0, which excludes the calling lane and is
  // undefined behavior.)
  const unsigned kFullWarpMask = 0xffffffffu;

  while (true) {
    __syncthreads();
    if (tid < blockSize / 2) {
      if (shTopK[tid] < shTopK[tid + blockSize / 2]) {
        maxId[tid] = tid + blockSize / 2;
      } else {
        maxId[tid] = tid;
      }
    }
    __syncthreads();
    for (int stride = blockSize / 4; stride > 0; stride = stride / 2) {
      if (tid < stride) {
        if (shTopK[maxId[tid]] < shTopK[maxId[tid + stride]]) {
          maxId[tid] = maxId[tid + stride];
        }
      }
      __syncthreads();
    }
    __syncthreads();

    // maxId[0] now indexes the block-wide best candidate.
    if (tid == 0) {
      **topVal = shTopK[maxId[0]].v_;
      **topIds = shTopK[maxId[0]].id_;
      (*topVal)++;
      (*topIds)++;
    }
    if (tid == maxId[0]) beam++;
    if (--beamSize == 0) break;
    __syncthreads();

    // The winner republishes its next local candidate, if it has one left.
    if (tid == maxId[0]) {
      if (beam < maxLength) {
        shTopK[tid] = topK[beam];
      }
    }
    // If the winner has used up its local list, its whole warp returns so
    // the caller can refill it.
    if (maxId[0] / 32 == warp) {
      if (__shfl_sync(kFullWarpMask, beam, (maxId[0]) % 32, 32) == maxLength)
        break;
    }
  }
}

/**
 * Top-k over the rows of a dense matrix. Each block computes one sample.
 * Within a block:
 * 1. every thread collects its top maxLength values;
 * 2. those are merged into shTopK, and a block reduction extracts the max;
 * 3. step 2 repeats until some thread's local top-k list is exhausted;
 * 4. the kernel then returns to step 1 until beamSize values are produced.
 *
 * Expects gridDim.x == number of rows and blockDim.x == blockSize.
 * Row r reads src + r * lds and writes topVal + r * ldv and
 * topIds + r * beamSize.
 */
template <int maxLength, int blockSize>
__global__ void KeMatrixTopK(real* topVal,
                             int ldv,
                             int* topIds,
                             real* src,
                             int lds,
                             int dim,
                             int beamSize) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];

  const int tid = threadIdx.x;
  const int warp = tid / 32;

  // Per-row input and output cursors (blockReduce advances the outputs).
  real* rowSrc = src + blockIdx.x * lds;
  real* outVal = topVal + blockIdx.x * ldv;
  int* outIds = topIds + blockIdx.x * beamSize;

  Pair cand[maxLength];  // NOLINT
  for (int k = 0; k < maxLength; ++k) cand[k].set(-HL_FLOAT_MAX, -1);

  int beam = maxLength;  // forces a full fill on the first pass
  Pair bound;
  bool isEmpty = false;
  bool firstStep = true;

  while (beamSize != 0) {
    threadGetTopK<maxLength, blockSize>(
        cand, beam, beamSize, rowSrc, firstStep, isEmpty, bound, dim, tid);
    shTopK[tid] = cand[0];
    blockReduce<maxLength, blockSize>(
        shTopK, maxId, cand, &outVal, &outIds, beam, beamSize, tid, warp);
  }
}
/**
 * Top-k over the rows of a CSR sparse matrix. Each block computes one row.
 *
 * Row r covers the non-zeros val[row[r] .. row[r + 1]). The reported ids are
 * the matching column ids from `col`. Outputs go to topVal + r * ldv and
 * topIds + r * beamSize (stride taken before any clamping).
 * If the row has fewer non-zeros than beamSize, only that many results are
 * produced and -1 is written right after them in topIds as an end marker.
 * Expects blockDim.x == blockSize.
 */
template <int maxLength, int blockSize>
__global__ void KeSMatrixTopK(real* topVal,
                              int ldv,
                              int* topIds,
                              real* val,
                              int* row,
                              int* col,
                              int beamSize) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];

  const int tid = threadIdx.x;
  const int warp = tid / 32;

  // Output cursors for this row; blockReduce advances them.
  real* outVal = topVal + blockIdx.x * ldv;
  int* outIds = topIds + blockIdx.x * beamSize;

  // Slice of the CSR arrays holding this row's non-zeros.
  const int rowBegin = row[blockIdx.x];
  const int nnz = row[blockIdx.x + 1] - rowBegin;
  real* rowVal = val + rowBegin;
  int* rowCol = col + rowBegin;

  if (beamSize > nnz) {
    // Fewer candidates than requested outputs: mark the end of the valid
    // results with -1 and only produce nnz of them.
    if (tid == 0) outIds[nnz] = -1;
    beamSize = nnz;
  }

  Pair cand[maxLength];  // NOLINT
  for (int k = 0; k < maxLength; ++k) cand[k].set(-HL_FLOAT_MAX, -1);

  int beam = maxLength;  // forces a full fill on the first pass
  Pair bound;
  bool isEmpty = false;
  bool firstStep = true;

  while (beamSize != 0) {
    threadGetTopK<maxLength, blockSize>(cand,
                                        beam,
                                        beamSize,
                                        rowVal,
                                        rowCol,
                                        firstStep,
                                        isEmpty,
                                        bound,
                                        nnz,
                                        tid);
    shTopK[tid] = cand[0];
    blockReduce<maxLength, blockSize>(
        shTopK, maxId, cand, &outVal, &outIds, beam, beamSize, tid, warp);
  }
}
/**
 * Host launcher for KeMatrixTopK. For each of the numSamples rows of `src`
 * (row stride lds, `dim` columns), writes the beamSize largest values to
 * topVal (row stride ldv) and their column indices to topIds.
 *
 * beamSize is clamped to dim.
 * NOTE(review): KeMatrixTopK strides topIds by the clamped beamSize, so when
 * beamSize > dim the id rows are packed with stride dim. Confirm this matches
 * what callers expect.
 */
void hl_matrix_top_k(real* topVal,
                     int ldv,
                     int* topIds,
                     real* src,
                     int lds,
                     int dim,
                     int beamSize,
                     int numSamples) {
  CHECK_NOTNULL(topVal);
  CHECK_NOTNULL(topIds);
  CHECK_NOTNULL(src);
  CHECK_GE(dim, 0) << "dim must be non-negative";
  CHECK_GE(beamSize, 0) << "beamSize must be non-negative";
  CHECK_GE(numSamples, 0) << "numSamples must be non-negative";

  if (beamSize > dim) beamSize = dim;

  // A zero-sized grid is an invalid launch configuration, and with
  // beamSize == 0 the kernel would write nothing anyway.
  if (numSamples == 0 || beamSize == 0) return;

  const int kMaxLength = 5;    // per-thread candidate list length
  const int kBlockSize = 256;  // threads per block; one block per sample
  dim3 threads(kBlockSize, 1);
  dim3 grid(numSamples, 1);
  KeMatrixTopK<kMaxLength, kBlockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
      topVal, ldv, topIds, src, lds, dim, beamSize);

  CHECK_SYNC("hl_matrix_top_k failed");
}
/**
 * Host launcher for KeSMatrixTopK over a CSR matrix. For each row, writes up
 * to beamSize largest non-zero values to topVal (row stride ldv) and their
 * column ids to topIds (row stride beamSize). Rows with fewer non-zeros are
 * terminated by a -1 id.
 *
 * NOTE(review): numSamples is presumably the CSR row count; the kernel reads
 * csr_row[r] and csr_row[r + 1] for every r < numSamples.
 */
void hl_sparse_matrix_top_k(real* topVal,
                            int ldv,
                            int* topIds,
                            hl_sparse_matrix_s src,
                            int beamSize,
                            int numSamples) {
  CHECK_NOTNULL(topVal);
  CHECK_NOTNULL(topIds);
  CHECK_NOTNULL(src);
  CHECK_EQ(src->format, HL_SPARSE_CSR) << "sparse matrix format error!";
  CHECK_GE(beamSize, 0) << "beamSize must be non-negative";
  CHECK_GE(numSamples, 0) << "numSamples must be non-negative";

  hl_csr_matrix csr = (hl_csr_matrix)src->matrix;
  if (csr->csr_val == NULL || csr->csr_row == NULL || csr->csr_col == NULL) {
    LOG(FATAL) << "parameter src is null!";
  }

  // A zero-sized grid is an invalid launch configuration, and with
  // beamSize == 0 the kernel would write nothing anyway.
  if (numSamples == 0 || beamSize == 0) return;

  const int kMaxLength = 5;    // per-thread candidate list length
  const int kBlockSize = 256;  // threads per block; one block per row
  dim3 threads(kBlockSize, 1);
  dim3 grid(numSamples, 1);
  KeSMatrixTopK<kMaxLength, kBlockSize><<<grid, threads, 0, STREAM_DEFAULT>>>(
      topVal, ldv, topIds, csr->csr_val, csr->csr_row, csr->csr_col, beamSize);

  CHECK_SYNC("hl_sparse_matrix_top_k failed");
}
/**
 * Top-k classification error. Each block computes one sample.
 * Within a block:
 * 1. every thread collects its top maxLength values;
 * 2. those are merged into shTopK, and a block reduction extracts the max;
 * 3. step 2 repeats until some thread's local top-k list is exhausted;
 * 4. the kernel then returns to step 1 until beamSize values are produced
 *    (values to topVal + r * ldv, ids to topIds + r * beamSize);
 * 5. thread 0 sets recResult[r] to 0 if label[r] is among those ids, and to
 *    1 otherwise (including when beamSize == 0).
 *
 * Expects blockDim.x == blockSize and gridDim.x == number of samples.
 */
template <int maxLength, int blockSize>
__global__ void KeMatrixTopKClassificationError(real* topVal,
                                                int ldv,
                                                int* topIds,
                                                real* src,
                                                int lds,
                                                int dim,
                                                int beamSize,
                                                int* label,
                                                real* recResult) {
  __shared__ Pair shTopK[blockSize];
  __shared__ int maxId[blockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;
  src += blockIdx.x * lds;
  topVal += blockIdx.x * ldv;
  topIds += blockIdx.x * beamSize;
  // blockReduce advances topIds, so remember where this row's ids start.
  const int* rowIds = topIds;

  Pair topK[maxLength];  // NOLINT
  int beam = maxLength;
  Pair max;
  bool isEmpty = false;
  bool firstStep = true;
  const int topkSize = beamSize;  // beamSize is counted down to 0 below

  for (int k = 0; k < maxLength; k++) {
    topK[k].set(-HL_FLOAT_MAX, -1);
  }

  while (beamSize) {
    threadGetTopK<maxLength, blockSize>(
        topK, beam, beamSize, src, firstStep, isEmpty, max, dim, tid);

    shTopK[tid] = topK[0];
    blockReduce<maxLength, blockSize>(
        shTopK, maxId, topK, &topVal, &topIds, beam, beamSize, tid, warp);
  }

  __syncthreads();
  if (tid == 0) {
    // Thread 0 wrote all topkSize ids itself inside blockReduce.
    const int target = label[blockIdx.x];
    // Start from "miss" so the result is defined even when topkSize == 0.
    real error = 1.0f;
    for (int i = 0; i < topkSize; i++) {
      if (rowIds[i] == target) {
        error = 0;
        break;
      }
    }
    recResult[blockIdx.x] = error;
  }
}
/**
 * Host launcher for KeMatrixTopKClassificationError. For each of the
 * numSamples rows of `src` (row stride lds, `dim` columns), it writes:
 *  - the topkSize best-scoring values to topVal (row stride ldv);
 *  - their column ids to topIds (row stride topkSize);
 *  - recResult[row] = 0 if label[row] is among those ids, 1 if it is not.
 *
 * topkSize is clamped to dim.
 */
void hl_matrix_classification_error(real* topVal,
                                    int ldv,
                                    int* topIds,
                                    real* src,
                                    int lds,
                                    int dim,
                                    int topkSize,
                                    int numSamples,
                                    int* label,
                                    real* recResult) {
  CHECK_NOTNULL(topVal);
  CHECK_NOTNULL(topIds);
  CHECK_NOTNULL(src);
  CHECK_NOTNULL(label);
  CHECK_NOTNULL(recResult);
  CHECK_GE(dim, 0) << "dim must be non-negative";
  CHECK_GE(topkSize, 0) << "topkSize must be non-negative";
  CHECK_GE(numSamples, 0) << "numSamples must be non-negative";

  if (topkSize > dim) topkSize = dim;

  // A zero-sized grid is an invalid launch configuration; nothing to do.
  if (numSamples == 0) return;

  const int kMaxLength = 5;    // per-thread candidate list length
  const int kBlockSize = 256;  // threads per block; one block per sample
  dim3 threads(kBlockSize, 1);
  dim3 grid(numSamples, 1);
  KeMatrixTopKClassificationError<kMaxLength, kBlockSize>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(
          topVal, ldv, topIds, src, lds, dim, topkSize, label, recResult);

  CHECK_SYNC("hl_matrix_top_k classification error failed");
}