// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#else
#include <cuda_runtime.h>
#include <curand_kernel.h>
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h"

namespace phi {

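// Returns the neighbor count of node i from the CSC column pointer; ids beyond
// the column-pointer range contribute a count of 0.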
template <typename T>
struct DegreeFunctor {
  const T* col_ptr;
  int64_t len_col_ptr;
  HOSTDEVICE explicit inline DegreeFunctor(const T* x, int64_t len_col_ptr) {
    this->col_ptr = x;
    this->len_col_ptr = len_col_ptr;
  }
  HOSTDEVICE inline int operator()(T i) const {
    return i > len_col_ptr - 1 ? 0 : col_ptr[i + 1] - col_ptr[i];
  }
};

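// Clamps a neighbor count to the sampling cap and floors negative values at 0.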
struct MaxFunctor {
  int cap;
  HOSTDEVICE explicit inline MaxFunctor(int cap) { this->cap = cap; }
  HOSTDEVICE inline int operator()(int x) const {
    if (x > cap) {
      return cap;
    }
    return x >= 0 ? x : 0;
  }
};

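// Each CTA row (blockDim.x == CTA_SIZE threads) samples up to k neighbors for
// one node per iteration. Rows with degree <= k are copied verbatim; larger
// rows use a parallel reservoir-style selection resolved with atomicMax.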
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void SampleKernel(const uint64_t rand_seed,
                             int k,
                             const int64_t num_nodes,
                             const int64_t len_col_ptr,
                             const T* nodes,
                             const T* row,
                             const T* col_ptr,
                             const T* eids,
                             T* output,
                             T* output_eids,
                             int* output_ptr,
                             bool return_eids) {
  assert(blockDim.x == CTA_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_nodes);
#ifdef PADDLE_WITH_HIP
  hiprandState rng;
  hiprand_init(rand_seed * gridDim.x + blockIdx.x,
               threadIdx.y * CTA_SIZE + threadIdx.x,
               0,
               &rng);
#else
  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x,
              threadIdx.y * CTA_SIZE + threadIdx.x,
              0,
              &rng);
#endif

  while (out_row < last_row) {
    T node = nodes[out_row];
    if (node > len_col_ptr - 1) {
      out_row += BLOCK_CTAS;
      continue;
    }
    T in_row_start = col_ptr[node];
    int deg = col_ptr[node + 1] - in_row_start;
    int out_row_start = output_ptr[out_row];

    if (deg <= k) {
      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
        output[out_row_start + idx] = row[in_row_start + idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[in_row_start + idx];
        }
      }
    } else {
      for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
        output[out_row_start + idx] = idx;
      }
#ifdef PADDLE_WITH_CUDA
      __syncthreads();
#endif

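      // Reservoir-style selection: each candidate index idx in [k, deg) is
      // written into a random slot of [0, k); atomicMax arbitrates concurrent
      // writers, and the chosen indices are mapped back to neighbors below.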
      for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) {
#ifdef PADDLE_WITH_HIP
        const int num = hiprand(&rng) % (idx + 1);
#else
        const int num = curand(&rng) % (idx + 1);
#endif
        if (num < k) {
          atomicMax(reinterpret_cast<unsigned int*>(  // NOLINT
                        output + out_row_start + num),
                    static_cast<unsigned int>(idx));  // NOLINT
        }
      }
#ifdef PADDLE_WITH_CUDA
      __syncthreads();
#endif

      for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
        T perm_idx = output[out_row_start + idx] + in_row_start;
        output[out_row_start + idx] = row[perm_idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[perm_idx];
        }
      }
    }

    out_row += BLOCK_CTAS;
  }
}

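// Computes per-node sample counts (degrees clamped by sample_size when it is
// non-negative) into output_count and returns their sum, the total output size.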
template <typename T, typename Context>
int GetTotalSampleNum(const thrust::device_ptr<const T> input,
                      const T* col_ptr,
                      int64_t len_col_ptr,
                      thrust::device_ptr<int> output_count,
                      int sample_size,
                      int bs) {
  thrust::transform(
      input, input + bs, output_count, DegreeFunctor<T>(col_ptr, len_col_ptr));
  if (sample_size >= 0) {
    thrust::transform(
        output_count, output_count + bs, output_count, MaxFunctor(sample_size));
  }
  int total_sample_num = thrust::reduce(output_count, output_count + bs);
  return total_sample_num;
}

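// Scans output_count into per-row output offsets and launches SampleKernel to
// sample up to sample_size neighbors (and optionally edge ids) per input node.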
template <typename T, typename Context>
void SampleNeighbors(const Context& dev_ctx,
                     const T* row,
                     const T* col_ptr,
                     const T* eids,
                     const thrust::device_ptr<const T> input,
                     thrust::device_ptr<T> output,
                     thrust::device_ptr<int> output_count,
                     thrust::device_ptr<T> output_eids,
                     int sample_size,
                     int bs,
                     int total_sample_num,
                     int64_t len_col_ptr,
                     bool return_eids) {
  thrust::device_vector<int> output_ptr;
  output_ptr.resize(bs);
  thrust::exclusive_scan(
      output_count, output_count + bs, output_ptr.begin(), 0);

  constexpr int CTA_SIZE = 128;
  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
  constexpr int TILE_SIZE = BLOCK_CTAS;
  const dim3 block(CTA_SIZE, BLOCK_CTAS);
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
  SampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(
          0,
          sample_size,
          bs,
          len_col_ptr,
          thrust::raw_pointer_cast(input),
          row,
          col_ptr,
          eids,
          thrust::raw_pointer_cast(output),
          thrust::raw_pointer_cast(output_eids),
          thrust::raw_pointer_cast(output_ptr.data()),
          return_eids);
}

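// Partially Fisher-Yates shuffles each row's segment of the permutation buffer
// `src` so that a contiguous slice of k entries holds a uniform random sample
// of edge positions; GatherEdge reads that slice afterwards.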
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                        int k,
                                        const int64_t num_rows,
                                        const int64_t len_col_ptr,
                                        const T* in_rows,
                                        T* src,
                                        const T* dst_count) {
  assert(blockDim.x == CTA_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
#ifdef PADDLE_WITH_HIP
  hiprandState rng;
  hiprand_init(
      rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#else
  curandStatePhilox4_32_10_t rng;
  curand_init(
      rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif

  while (out_row < last_row) {
    const T row = in_rows[out_row];
    if (row > len_col_ptr - 1) {
      out_row += BLOCK_CTAS;
      continue;
    }
    const T in_row_start = dst_count[row];
    const int deg = dst_count[row + 1] - in_row_start;
    int split;
    if (k < deg) {
      if (deg < 2 * k) {
        split = k;
      } else {
        split = deg - k;
      }
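      // Shuffle only [split, deg): swapping each position with a random earlier
      // one leaves a uniformly sampled slice of length k for GatherEdge to read.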
      for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) {
#ifdef PADDLE_WITH_HIP
        const int num = hiprand(&rng) % (idx + 1);
#else
        const int num = curand(&rng) % (idx + 1);
#endif
        src[in_row_start + idx] = static_cast<T>(
            atomicExch(reinterpret_cast<unsigned long long int*>(  // NOLINT
                           src + in_row_start + num),
                       static_cast<unsigned long long int>(  //  NOLINT
                           src[in_row_start + idx])));
      }
#ifdef PADDLE_WITH_CUDA
      __syncthreads();
#endif
    }
    out_row += BLOCK_CTAS;
  }
}

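// Gathers the sampled neighbor ids (and optionally edge ids) through the
// permutation produced by FisherYatesSampleKernel.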
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void GatherEdge(int k,
                           int64_t num_rows,
                           const T* in_rows,
                           const T* src,
                           const T* dst_count,
                           const T* eids,
                           T* outputs,
                           T* output_eids,
                           int* output_ptr,
                           T* perm_data,
                           bool return_eids) {
  assert(blockDim.x == CTA_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const int64_t last_row =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  while (out_row < last_row) {
    const T row = in_rows[out_row];
    const T in_row_start = dst_count[row];
    const int deg = dst_count[row + 1] - in_row_start;
    const T out_row_start = output_ptr[out_row];

    if (deg <= k) {
      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
        outputs[out_row_start + idx] = src[in_row_start + idx];
        if (return_eids) {
          output_eids[out_row_start + idx] = eids[in_row_start + idx];
        }
      }
    } else {
      int split = k;
      int begin, end;
      if (deg < 2 * k) {
        begin = 0;
        end = k;
      } else {
        begin = deg - k;
        end = deg;
      }

      for (int idx = begin + threadIdx.x; idx < end; idx += CTA_SIZE) {
        outputs[out_row_start + idx - begin] =
            src[perm_data[in_row_start + idx]];
        if (return_eids) {
          output_eids[out_row_start + idx - begin] =
              eids[perm_data[in_row_start + idx]];
        }
      }
    }
    out_row += BLOCK_CTAS;
  }
}

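// Launches the Fisher-Yates shuffle kernel followed by GatherEdge; used when a
// pre-initialized permutation buffer is supplied by the caller.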
template <typename T, typename Context>
void FisherYatesSampleNeighbors(const Context& dev_ctx,
                                const T* row,
                                const T* col_ptr,
                                const T* eids,
                                T* perm_data,
                                const thrust::device_ptr<const T> input,
                                thrust::device_ptr<T> output,
                                thrust::device_ptr<int> output_count,
                                thrust::device_ptr<T> output_eids,
                                int sample_size,
                                int bs,
                                int total_sample_num,
                                int64_t len_col_ptr,
                                bool return_eids) {
  thrust::device_vector<int> output_ptr;
  output_ptr.resize(bs);
  thrust::exclusive_scan(
      output_count, output_count + bs, output_ptr.begin(), 0);

  constexpr int CTA_SIZE = 128;
  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
  constexpr int TILE_SIZE = BLOCK_CTAS;
  const dim3 block(CTA_SIZE, BLOCK_CTAS);
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);

  FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(0,
                                             sample_size,
                                             bs,
                                             len_col_ptr,
                                             thrust::raw_pointer_cast(input),
                                             perm_data,
                                             col_ptr);

  GatherEdge<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
      <<<grid, block, 0, dev_ctx.stream()>>>(
          sample_size,
          bs,
          thrust::raw_pointer_cast(input),
          row,
          col_ptr,
          eids,
          thrust::raw_pointer_cast(output),
          thrust::raw_pointer_cast(output_eids),
          thrust::raw_pointer_cast(output_ptr.data()),
          perm_data,
          return_eids);
}

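// Kernel entry point: computes per-node sample counts, allocates the outputs,
// and dispatches to either the reservoir-based SampleNeighbors path or the
// Fisher-Yates path, depending on whether a permutation buffer is provided.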
template <typename T, typename Context>
void GraphSampleNeighborsKernel(
    const Context& dev_ctx,
    const DenseTensor& row,
    const DenseTensor& col_ptr,
    const DenseTensor& x,
    const paddle::optional<DenseTensor>& eids,
    const paddle::optional<DenseTensor>& perm_buffer,
    int sample_size,
    bool return_eids,
    bool flag_perm_buffer,
    DenseTensor* out,
    DenseTensor* out_count,
    DenseTensor* out_eids) {
  auto* row_data = row.data<T>();
  auto* col_ptr_data = col_ptr.data<T>();
  auto* x_data = x.data<T>();
  int bs = x.dims()[0];
  int64_t len_col_ptr = col_ptr.dims()[0];

  const thrust::device_ptr<const T> input(x_data);

  out_count->Resize({bs});
  int* out_count_data = dev_ctx.template Alloc<int>(out_count);
  thrust::device_ptr<int> output_count(out_count_data);

  int total_sample_size = GetTotalSampleNum<T, Context>(
      input, col_ptr_data, len_col_ptr, output_count, sample_size, bs);

  out->Resize({static_cast<int>(total_sample_size)});
  T* out_data = dev_ctx.template Alloc<T>(out);
  thrust::device_ptr<T> output(out_data);

  if (return_eids) {
    auto* eids_data = eids.get_ptr()->data<T>();
    out_eids->Resize({static_cast<int>(total_sample_size)});
    T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
    thrust::device_ptr<T> output_eids(out_eids_data);
    if (!flag_perm_buffer) {
      SampleNeighbors<T, Context>(dev_ctx,
                                  row_data,
                                  col_ptr_data,
                                  eids_data,
                                  input,
                                  output,
                                  output_count,
                                  output_eids,
                                  sample_size,
                                  bs,
                                  total_sample_size,
                                  len_col_ptr,
                                  return_eids);
    } else {
      DenseTensor perm_buffer_out(perm_buffer->type());
      const auto* p_perm_buffer = perm_buffer.get_ptr();
      perm_buffer_out.ShareDataWith(*p_perm_buffer);
      T* perm_buffer_out_data = perm_buffer_out.template data<T>();
      FisherYatesSampleNeighbors<T, Context>(dev_ctx,
                                             row_data,
                                             col_ptr_data,
                                             eids_data,
                                             perm_buffer_out_data,
                                             input,
                                             output,
                                             output_count,
                                             output_eids,
                                             sample_size,
                                             bs,
                                             total_sample_size,
436
                                             len_col_ptr,
437 438
                                             return_eids);
    }
  } else {
    // There is no natural null value for output_eids (thrust::device_ptr<T>),
    // so `output` is passed again to fill the unused output_eids argument.
    if (!flag_perm_buffer) {
      SampleNeighbors<T, Context>(dev_ctx,
                                  row_data,
                                  col_ptr_data,
                                  nullptr,
                                  input,
                                  output,
                                  output_count,
                                  output,
                                  sample_size,
                                  bs,
                                  total_sample_size,
                                  len_col_ptr,
                                  return_eids);
    } else {
      DenseTensor perm_buffer_out(perm_buffer->type());
      const auto* p_perm_buffer = perm_buffer.get_ptr();
      perm_buffer_out.ShareDataWith(*p_perm_buffer);
      T* perm_buffer_out_data = perm_buffer_out.template data<T>();
      FisherYatesSampleNeighbors<T, Context>(dev_ctx,
                                             row_data,
                                             col_ptr_data,
                                             nullptr,
                                             perm_buffer_out_data,
                                             input,
                                             output,
                                             output_count,
                                             output,
                                             sample_size,
                                             bs,
                                             total_sample_size,
                                             len_col_ptr,
                                             return_eids);
    }
  }
}

}  // namespace phi

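// Registers the GPU kernel for int32 and int64 graphs; the neighbor-count
// output (index 1) is always int32.
//
// Rough usage sketch from the Python side (API name and signature assumed, not
// verified against this branch):
//   out_neighbors, out_count = paddle.geometric.sample_neighbors(
//       row, colptr, input_nodes, sample_size=16)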
PD_REGISTER_KERNEL(graph_sample_neighbors,
                   GPU,
                   ALL_LAYOUT,
                   phi::GraphSampleNeighborsKernel,
                   int,
                   int64_t) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}