graph_sample_neighbors_kernel.cu 17.1 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#else
#include <cuda_runtime.h>
#include <curand_kernel.h>
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h"

namespace phi {

// Returns the number of neighbors of node `i`, read from an offset array.
//
// `col_ptr` holds `len_col_ptr` offsets (CSC-style, judging by the
// `row`/`col_ptr` naming). Node `i` owns the range
// [col_ptr[i], col_ptr[i + 1]), so the only ids with a complete range are
// [0, len_col_ptr - 2]. Every other id yields 0, including negative ids and
// `len_col_ptr - 1`, whose `i + 1` would index one past the end of `col_ptr`.
template <typename T>
struct DegreeFunctor {
  const T* col_ptr;
  int64_t len_col_ptr;
  HOSTDEVICE explicit inline DegreeFunctor(const T* x, int64_t len_col_ptr) {
    this->col_ptr = x;
    this->len_col_ptr = len_col_ptr;
  }
  HOSTDEVICE inline int operator()(T i) const {
    // Reject any id for which col_ptr[i] or col_ptr[i + 1] is out of bounds.
    if (i < 0 || static_cast<int64_t>(i) >= len_col_ptr - 1) {
      return 0;
    }
    return col_ptr[i + 1] - col_ptr[i];
  }
};

// Caps a per-node neighbor count at `cap` and floors it at 0.
// For a non-negative `cap` the result always lies in [0, cap].
struct MaxFunctor {
  int cap;
  HOSTDEVICE explicit inline MaxFunctor(int cap) : cap(cap) {}
  HOSTDEVICE inline int operator()(int x) const {
    if (x > cap) return cap;  // the upper bound wins first
    if (x < 0) return 0;
    return x;
  }
};

// Samples up to `k` neighbors, without replacement, for each id in `nodes`.
//
// Launch layout: blockDim must be (WARP_SIZE, BLOCK_WARPS). Block b owns
// entries [b * TILE_SIZE, (b + 1) * TILE_SIZE) of `nodes`, and warp
// threadIdx.y visits them with a stride of BLOCK_WARPS, one node per warp at
// a time.
//
// - Invalid ids (outside [0, len_col_ptr - 2]) are skipped; they own no
//   output slots.
// - If k < 0 or the node's degree is <= k, every neighbor (and optionally
//   every edge id) is copied. A negative k is the "keep everything" mode,
//   for which the counts are left uncapped.
// - Otherwise a warp-parallel reservoir sample is taken:
//   - slots 0..k-1 are seeded with indices 0..k-1;
//   - each candidate idx >= k draws a slot uniformly from [0, idx] and claims
//     it with atomicMax (the largest idx aiming at a slot is the one a
//     sequential reservoir pass would keep);
//   - the winning indices are mapped back through `row` / `eids`.
//
// `output_ptr[i]` is the write offset for entry i of `nodes`.
// NOTE(review): the launcher in this file passes rand_seed = 0, so results
// are deterministic across calls. Confirm that this is intended.
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void SampleKernel(const uint64_t rand_seed,
                             int k,
                             const int64_t num_nodes,
                             const int64_t len_col_ptr,
                             const T* nodes,
                             const T* row,
                             const T* col_ptr,
                             const T* eids,
                             T* output,
                             T* output_eids,
                             int* output_ptr,
                             bool return_eids) {
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  // Widen blockIdx.x before multiplying so the row index is 64-bit throughout.
  int64_t out_row =
      static_cast<int64_t>(blockIdx.x) * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_nodes);
#ifdef PADDLE_WITH_HIP
  hiprandState rng;
  hiprand_init(rand_seed * gridDim.x + blockIdx.x,
               threadIdx.y * WARP_SIZE + threadIdx.x,
               0,
               &rng);
#else
  curandState rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x,
              threadIdx.y * WARP_SIZE + threadIdx.x,
              0,
              &rng);
#endif

  while (out_row < last_row) {
    const T node = nodes[out_row];
    // Only ids in [0, len_col_ptr - 2] have a col_ptr[node + 1] entry; skip
    // everything else instead of reading past the end of col_ptr.
    if (node < 0 || static_cast<int64_t>(node) >= len_col_ptr - 1) {
      out_row += BLOCK_WARPS;
      continue;
    }
    const T in_row_start = col_ptr[node];
    const int deg = col_ptr[node + 1] - in_row_start;
    const int out_row_start = output_ptr[out_row];

    if (k < 0 || deg <= k) {
      // Keep every neighbor. This also guards the sampling path below, which
      // needs k >= 0 (with k < 0 it would compute `% 0` at idx == -1).
      for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
        output[out_row_start + idx] = row[in_row_start + idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[in_row_start + idx];
        }
      }
    } else {
      // Seed the reservoir with the first k local edge indices.
      for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
        output[out_row_start + idx] = idx;
      }
#ifdef PADDLE_WITH_CUDA
      __syncwarp();
#endif

      for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) {
#ifdef PADDLE_WITH_HIP
        const int num = hiprand(&rng) % (idx + 1);
#else
        const int num = curand(&rng) % (idx + 1);
#endif
        if (num < k) {
          // Keep the largest candidate aimed at slot `num`. The slot is
          // updated through its low 32 bits, which is enough for int-sized
          // indices. For 64-bit T this assumes a little-endian layout, with
          // the upper half zeroed by the seeding loop above.
          atomicMax(reinterpret_cast<unsigned int*>(  // NOLINT
                        output + out_row_start + num),
                    static_cast<unsigned int>(idx));  // NOLINT
        }
      }
#ifdef PADDLE_WITH_CUDA
      __syncwarp();
#endif

      // Translate the chosen local indices into neighbor ids / edge ids.
      for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
        const T perm_idx = output[out_row_start + idx] + in_row_start;
        output[out_row_start + idx] = row[perm_idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[perm_idx];
        }
      }
    }

    out_row += BLOCK_WARPS;
  }
}

// Fills `output_count` with how many neighbors will be sampled for each of
// the `bs` ids in `input`, and returns the sum of those counts.
//
// Each count starts as the node degree (DegreeFunctor). When `sample_size`
// is non-negative, it is then clamped to [0, sample_size]. A negative
// `sample_size` leaves every degree uncapped.
//
// NOTE(review): these thrust calls use the default execution policy rather
// than an explicit stream. The returned sum is read back on the host.
template <typename T, typename Context>
int GetTotalSampleNum(const thrust::device_ptr<const T> input,
                      const T* col_ptr,
                      int64_t len_col_ptr,
                      thrust::device_ptr<int> output_count,
                      int sample_size,
                      int bs) {
  const thrust::device_ptr<int> counts_begin = output_count;
  const thrust::device_ptr<int> counts_end = output_count + bs;

  const DegreeFunctor<T> degree_of(col_ptr, len_col_ptr);
  thrust::transform(input, input + bs, counts_begin, degree_of);

  const bool cap_counts = sample_size >= 0;
  if (cap_counts) {
    thrust::transform(
        counts_begin, counts_end, counts_begin, MaxFunctor(sample_size));
  }

  return thrust::reduce(counts_begin, counts_end);
}

// Launches SampleKernel to sample up to `sample_size` neighbors for each of
// the `bs` ids in `input`. A negative `sample_size` keeps all neighbors.
//
// `output_count` must already hold the per-node sample counts. Its
// exclusive prefix sum gives each node's write offset into `output` and
// `output_eids`. `output_eids` / `eids` are only touched when `return_eids`
// is true. `total_sample_num` is not used here.
// The scan and the kernel are both issued on dev_ctx.stream(), so they are
// ordered with respect to each other.
template <typename T, typename Context>
void SampleNeighbors(const Context& dev_ctx,
                     const T* row,
                     const T* col_ptr,
                     const T* eids,
                     const thrust::device_ptr<const T> input,
                     thrust::device_ptr<T> output,
                     thrust::device_ptr<int> output_count,
                     thrust::device_ptr<T> output_eids,
                     int sample_size,
                     int bs,
                     int total_sample_num,
                     int64_t len_col_ptr,
                     bool return_eids) {
  // Nothing to sample. This also avoids launching a kernel with an empty
  // grid, which is rejected as an invalid launch configuration.
  if (bs <= 0) {
    return;
  }

#ifdef PADDLE_WITH_HIP
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif

  // output_ptr[i] = output_count[0] + ... + output_count[i - 1].
  thrust::device_vector<int> output_ptr(bs);
  thrust::exclusive_scan(exec_policy,
                         output_count,
                         output_count + bs,
                         output_ptr.begin(),
                         0);

  constexpr int WARP_SIZE = 32;
  constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
  constexpr int TILE_SIZE = BLOCK_WARPS * 16;
  const dim3 block(WARP_SIZE, BLOCK_WARPS);
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
  SampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(
          0,
          sample_size,
          bs,
          len_col_ptr,
          thrust::raw_pointer_cast(input),
          row,
          col_ptr,
          eids,
          thrust::raw_pointer_cast(output),
          thrust::raw_pointer_cast(output_eids),
          thrust::raw_pointer_cast(output_ptr.data()),
          return_eids);
}

205
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
S
Siming Dai 已提交
206 207 208
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                        int k,
                                        const int64_t num_rows,
209
                                        const int64_t len_col_ptr,
S
Siming Dai 已提交
210 211 212
                                        const T* in_rows,
                                        T* src,
                                        const T* dst_count) {
213 214 215 216 217 218
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
S
Siming Dai 已提交
219 220 221 222 223 224 225 226 227
#ifdef PADDLE_WITH_HIP
  hiprandState rng;
  hiprand_init(
      rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#else
  curandState rng;
  curand_init(
      rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif
228 229

  while (out_row < last_row) {
S
Siming Dai 已提交
230
    const T row = in_rows[out_row];
231 232 233 234
    if (row > len_col_ptr - 1) {
      out_row += BLOCK_WARPS;
      continue;
    }
S
Siming Dai 已提交
235 236 237 238 239 240 241 242 243
    const T in_row_start = dst_count[row];
    const int deg = dst_count[row + 1] - in_row_start;
    int split;
    if (k < deg) {
      if (deg < 2 * k) {
        split = k;
      } else {
        split = deg - k;
      }
244
      for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) {
S
Siming Dai 已提交
245 246 247 248 249 250 251 252 253 254 255
#ifdef PADDLE_WITH_HIP
        const int num = hiprand(&rng) % (idx + 1);
#else
        const int num = curand(&rng) % (idx + 1);
#endif
        src[in_row_start + idx] = static_cast<T>(
            atomicExch(reinterpret_cast<unsigned long long int*>(  // NOLINT
                           src + in_row_start + num),
                       static_cast<unsigned long long int>(  //  NOLINT
                           src[in_row_start + idx])));
      }
256 257 258
#ifdef PADDLE_WITH_CUDA
      __syncwarp();
#endif
S
Siming Dai 已提交
259
    }
260
    out_row += BLOCK_WARPS;
S
Siming Dai 已提交
261 262 263 264 265 266 267 268 269
  }
}

// Writes the sampled neighbors (and optionally edge ids) of each id in
// `in_rows` to `outputs` / `output_eids`, starting at output_ptr[i].
//
// In this file `src` receives the neighbor array (`row`) and `dst_count` the
// column pointer array.
// - Invalid ids (outside [0, len_col_ptr - 2]) own no output slots and are
//   skipped.
// - Nodes with deg <= k, or any node when k < 0 (the "keep everything" mode,
//   whose counts are uncapped), copy all their neighbors.
// - Otherwise the k positions that FisherYatesSampleKernel shuffled are read
//   from `perm_data` and used as indices into `src` / `eids`: the first k of
//   the slice when deg < 2k, the last k otherwise.
// NOTE(review): assumes `perm_data` holds edge indices into `row`/`eids`
// (e.g. initialized to 0..num_edges-1 by the caller). Verify at the call site.
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void GatherEdge(int k,
                           int64_t num_rows,
                           int64_t len_col_ptr,
                           const T* in_rows,
                           const T* src,
                           const T* dst_count,
                           const T* eids,
                           T* outputs,
                           T* output_eids,
                           int* output_ptr,
                           T* perm_data,
                           bool return_eids) {
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  int64_t out_row =
      static_cast<int64_t>(blockIdx.x) * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  while (out_row < last_row) {
    const T row = in_rows[out_row];
    // Without this guard an invalid id reads dst_count out of bounds and
    // writes a garbage-sized range into other rows' output slots.
    if (row < 0 || static_cast<int64_t>(row) >= len_col_ptr - 1) {
      out_row += BLOCK_WARPS;
      continue;
    }
    const T in_row_start = dst_count[row];
    const int deg = dst_count[row + 1] - in_row_start;
    const T out_row_start = output_ptr[out_row];

    if (k < 0 || deg <= k) {
      for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
        outputs[out_row_start + idx] = src[in_row_start + idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[in_row_start + idx];
        }
      }
    } else {
      // Start of the k-wide window that FisherYatesSampleKernel shuffled.
      const int begin = (deg < 2 * k) ? 0 : deg - k;
      for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
        const T edge = perm_data[in_row_start + begin + idx];
        outputs[out_row_start + idx] = src[edge];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[edge];
        }
      }
    }
    out_row += BLOCK_WARPS;
  }
}

// Perm-buffer variant of SampleNeighbors. It shuffles each node's slice of
// `perm_data` in place (FisherYatesSampleKernel), then gathers k sampled
// neighbors per node (GatherEdge).
//
// `output_count` must already hold the per-node sample counts; its exclusive
// prefix sum gives each node's write offset. `total_sample_num` is not used
// here. The scan and both kernels run on dev_ctx.stream(), in that order.
template <typename T, typename Context>
void FisherYatesSampleNeighbors(const Context& dev_ctx,
                                const T* row,
                                const T* col_ptr,
                                const T* eids,
                                T* perm_data,
                                const thrust::device_ptr<const T> input,
                                thrust::device_ptr<T> output,
                                thrust::device_ptr<int> output_count,
                                thrust::device_ptr<T> output_eids,
                                int sample_size,
                                int bs,
                                int total_sample_num,
                                int64_t len_col_ptr,
                                bool return_eids) {
  // Nothing to sample. This also avoids launching kernels with an empty grid.
  if (bs <= 0) {
    return;
  }

#ifdef PADDLE_WITH_HIP
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif

  // output_ptr[i] = output_count[0] + ... + output_count[i - 1].
  thrust::device_vector<int> output_ptr(bs);
  thrust::exclusive_scan(exec_policy,
                         output_count,
                         output_count + bs,
                         output_ptr.begin(),
                         0);

  constexpr int WARP_SIZE = 32;
  constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
  constexpr int TILE_SIZE = BLOCK_WARPS * 16;
  const dim3 block(WARP_SIZE, BLOCK_WARPS);
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);

  FisherYatesSampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(0,
                                             sample_size,
                                             bs,
                                             len_col_ptr,
                                             thrust::raw_pointer_cast(input),
                                             perm_data,
                                             col_ptr);

  GatherEdge<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(
          sample_size,
          bs,
          len_col_ptr,
          thrust::raw_pointer_cast(input),
          row,
          col_ptr,
          eids,
          thrust::raw_pointer_cast(output),
          thrust::raw_pointer_cast(output_eids),
          thrust::raw_pointer_cast(output_ptr.data()),
          perm_data,
          return_eids);
}

// GPU kernel for `graph_sample_neighbors`.
//
// For every id in `x`, samples neighbors from the graph described by
// `row` / `col_ptr`. Outputs:
//   out_count - per-id sample counts (see GetTotalSampleNum);
//   out       - all sampled neighbors, concatenated in input order;
//   out_eids  - the matching edge ids. Only produced when `return_eids` is
//               true, in which case `eids` must be provided.
// `flag_perm_buffer` selects the Fisher-Yates path. `perm_buffer` must then
// be provided, and its storage is shuffled in place.
template <typename T, typename Context>
void GraphSampleNeighborsKernel(
    const Context& dev_ctx,
    const DenseTensor& row,
    const DenseTensor& col_ptr,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& eids,
    const paddle::optional<DenseTensor>& perm_buffer,
    int sample_size,
    bool return_eids,
    bool flag_perm_buffer,
    DenseTensor* out,
    DenseTensor* out_count,
    DenseTensor* out_eids) {
  const T* row_data = row.data<T>();
  const T* col_ptr_data = col_ptr.data<T>();
  const T* x_data = x.data<T>();
  const int bs = x.dims()[0];
  const int64_t len_col_ptr = col_ptr.dims()[0];

  const thrust::device_ptr<const T> input(x_data);

  out_count->Resize({bs});
  thrust::device_ptr<int> output_count(dev_ctx.template Alloc<int>(out_count));

  const int total_sample_size = GetTotalSampleNum<T, Context>(
      input, col_ptr_data, len_col_ptr, output_count, sample_size, bs);

  out->Resize({static_cast<int>(total_sample_size)});
  thrust::device_ptr<T> output(dev_ctx.template Alloc<T>(out));

  // With return_eids false there is no eids buffer. `output` stands in for
  // output_eids so a valid pointer is always passed; the kernels only write
  // output_eids when return_eids is true.
  const T* eids_data = nullptr;
  thrust::device_ptr<T> output_eids = output;
  if (return_eids) {
    eids_data = eids.get_ptr()->data<T>();
    out_eids->Resize({static_cast<int>(total_sample_size)});
    output_eids = thrust::device_ptr<T>(dev_ctx.template Alloc<T>(out_eids));
  }

  if (flag_perm_buffer) {
    // Alias perm_buffer's storage to get a writable pointer to it.
    DenseTensor perm_buffer_out(perm_buffer->type());
    perm_buffer_out.ShareDataWith(*perm_buffer.get_ptr());
    T* perm_data = perm_buffer_out.template data<T>();
    FisherYatesSampleNeighbors<T, Context>(dev_ctx,
                                           row_data,
                                           col_ptr_data,
                                           eids_data,
                                           perm_data,
                                           input,
                                           output,
                                           output_count,
                                           output_eids,
                                           sample_size,
                                           bs,
                                           total_sample_size,
                                           len_col_ptr,
                                           return_eids);
  } else {
    SampleNeighbors<T, Context>(dev_ctx,
                                row_data,
                                col_ptr_data,
                                eids_data,
                                input,
                                output,
                                output_count,
                                output_eids,
                                sample_size,
                                bs,
                                total_sample_size,
                                len_col_ptr,
                                return_eids);
  }
}

}  // namespace phi

// Registers phi::GraphSampleNeighborsKernel as the GPU implementation of the
// `graph_sample_neighbors` op for all layouts, instantiated with T = int and
// T = int64_t.
PD_REGISTER_KERNEL(graph_sample_neighbors,
                   GPU,
                   ALL_LAYOUT,
                   phi::GraphSampleNeighborsKernel,
                   int,
                   int64_t) {}