weighted_sample_neighbors_kernel.cu 20.0 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>

#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "cub/cub.cuh"
#endif

#include "math.h"  // NOLINT
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/block_radix_topk.cuh"
#include "paddle/phi/kernels/funcs/random.cuh"
#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
#define SAMPLE_SIZE_THRESHOLD 1024

namespace phi {

#ifdef PADDLE_WITH_CUDA
// Draws one random sampling key for an edge with the given weight.
//
// The returned key is log2(1 + u) * (1 / weight), where u is the negated
// result of rng.RandomUniformFloat(1.0f, 0.5f) scaled down by 2^-z and z is
// the number of leading zero bits seen across successive 64-bit draws.
// NOTE(review): presumably this yields the log of a uniform variate raised to
// 1 / weight, as used by weighted reservoir sampling; confirm against
// RandomNumGen in funcs/random.cuh.
//
// Args:
//   weight: edge weight; the key is divided by it.
//   rng:    per-thread random number generator, advanced by this call.
__device__ __forceinline__ float GenKeyFromWeight(
    const float weight,
    RandomNumGen& rng) {  // NOLINT
  rng.NextValue();
  float scaled_uniform = -rng.RandomUniformFloat(1.0f, 0.5f);

  // Keep drawing 64-bit words until a non-zero one shows up; every all-zero
  // word adds 64 more leading zero bits to the exponent.
  long long random_bits = rng.Random64();  // NOLINT
  int zero_words = 0;
  while (random_bits == 0) {
    random_bits = rng.Random64();
    ++zero_words;
  }

  const int leading_zeros = __clzll(random_bits) + zero_words * 64;
  scaled_uniform *= exp2f(-leading_zeros);
  return (log1pf(scaled_uniform) / logf(2.0)) * (1 / weight);
}
#endif

// Computes, for every input node, how many neighbors will be emitted and
// (optionally) how many neighbors need a scratch slot for sampling keys.
// One thread handles one input node.
//
// Args:
//   col_ptr:        offsets array; node nid owns the range
//                   [col_ptr[nid], col_ptr[nid + 1]).
//   input_nodes:    the n node ids to process.
//   actual_size:    out, actual_size[i] = number of neighbors emitted for
//                   input_nodes[i]. A negative sample_size means "all".
//   neighbor_count: out, written only when NeedNeighbor is true: the node's
//                   neighbor count if it exceeds sample_size, otherwise 0.
//   sample_size:    requested number of neighbors per node.
//   n:              number of input nodes.
template <typename T, bool NeedNeighbor = false>
__global__ void GetSampleCountAndNeighborCountKernel(const T* col_ptr,
                                                     const T* input_nodes,
                                                     int* actual_size,
                                                     int* neighbor_count,
                                                     int sample_size,
                                                     int n) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < n) {
    const T node = input_nodes[tid];
    const int degree = static_cast<int>(col_ptr[node + 1] - col_ptr[node]);

    // A negative sample_size keeps every neighbor; otherwise the output is
    // capped at sample_size.
    const bool keep_all = (sample_size < 0) || (degree < sample_size);
    actual_size[tid] = keep_all ? degree : sample_size;

    if (NeedNeighbor) {
      // Only nodes with more neighbors than requested need sampling space.
      neighbor_count[tid] = (degree > sample_size) ? degree : 0;
    }
  }
}

#ifdef PADDLE_WITH_CUDA
// Weighted neighbor sampling for large sample sizes. One thread block handles
// one input node (selected by blockIdx.x). All strided loops step by
// BLOCK_SIZE, so the kernel must be launched with BLOCK_SIZE threads per
// block.
//
// Nodes with at most max_sample_count neighbors are copied verbatim. For the
// remaining nodes a random key is generated per neighbor into global scratch
// memory, a block-wide radix select finds the threshold key of the top
// max_sample_count entries, and every neighbor at or above that threshold is
// written out (in no particular order).
//
// Args:
//   sample_output:          out, sampled neighbor ids; node i writes starting
//                           at sample_offset[i].
//   sample_offset:          per-node write offset into sample_output/out_eids.
//   target_neighbor_offset: per-node offset into weight_keys_buf.
//   weight_keys_buf:        scratch buffer for the per-neighbor keys.
//   input_nodes:            node ids to sample for.
//   input_node_count:       number of entries in input_nodes.
//   in_rows:                neighbor ids, indexed through col_ptr ranges.
//   col_ptr:                offsets; node nid owns
//                           [col_ptr[nid], col_ptr[nid + 1]).
//   edge_weight:            per-edge weights, indexed like in_rows.
//   eids:                   per-edge ids, indexed like in_rows (read only when
//                           return_eids is true).
//   max_sample_count:       maximum number of neighbors emitted per node.
//   random_seed:            seed shared by all threads; each thread derives
//                           its own generator from its global thread index.
//   out_eids:               out, edge ids of the sampled neighbors (written
//                           only when return_eids is true).
//   return_eids:            whether to fill out_eids.
template <typename T, unsigned int BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE) __global__
    void WeightedSampleLargeKernel(T* sample_output,
                                   const int* sample_offset,
                                   const int* target_neighbor_offset,
                                   float* weight_keys_buf,
                                   const T* input_nodes,
                                   int input_node_count,
                                   const T* in_rows,
                                   const T* col_ptr,
                                   const float* edge_weight,
                                   const T* eids,
                                   int max_sample_count,
                                   unsigned long long random_seed,  // NOLINT
                                   T* out_eids,
                                   bool return_eids) {
  int i = blockIdx.x;
  if (i >= input_node_count) return;
  int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
  T nid = input_nodes[i];
  // The adjacency range of nid is [col_ptr[nid], col_ptr[nid + 1]); this
  // matches how GetSampleCountAndNeighborCountKernel sizes the output.
  T start = col_ptr[nid];
  T end = col_ptr[nid + 1];
  int neighbor_count = static_cast<int>(end - start);

  float* weight_keys_local_buff = weight_keys_buf + target_neighbor_offset[i];
  int offset = sample_offset[i];
  if (neighbor_count <= max_sample_count) {
    // Few enough neighbors: emit all of them without sampling.
    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
      sample_output[offset + j] = in_rows[start + j];
      if (return_eids) {
        out_eids[offset + j] = eids[start + j];
      }
    }
  } else {
    // Generate one random key per neighbor into this node's scratch slice.
    RandomNumGen rng(gidx, random_seed);
    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
      float thread_weight = edge_weight[start + j];
      weight_keys_local_buff[j] =
          static_cast<float>(GenKeyFromWeight(thread_weight, rng));
    }
    // Keys written by other threads are read by the radix select below.
    __syncthreads();

    float topk_val;
    bool topk_is_unique;

    using BlockRadixSelectT =
        paddle::framework::BlockRadixTopKGlobalMemory<float, BLOCK_SIZE, true>;
    __shared__ typename BlockRadixSelectT::TempStorage share_storage;

    // NOTE(review): presumably returns the key of the max_sample_count-th
    // largest entry in topk_val and whether that key value occurs only once;
    // confirm against block_radix_topk.cuh.
    BlockRadixSelectT{share_storage}.radixTopKGetThreshold(
        weight_keys_local_buff,
        max_sample_count,
        neighbor_count,
        topk_val,
        topk_is_unique);
    __shared__ int cnt;

    if (threadIdx.x == 0) {
      cnt = 0;
    }
    __syncthreads();

    // We use atomicAdd 1 operations instead of binaryScan to calculate the
    // write index, since we do not need to keep the relative positions of
    // element.

    if (topk_is_unique) {
      // Every key >= threshold belongs to the result.
      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
        float key = weight_keys_local_buff[j];
        bool has_topk = (key >= topk_val);

        if (has_topk) {
          int write_index = atomicAdd(&cnt, 1);
          sample_output[offset + write_index] = in_rows[start + j];
          if (return_eids) {
            out_eids[offset + write_index] = eids[start + j];
          }
        }
      }
    } else {
      // First pass: keys strictly above the threshold are always selected.
      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
        float key = weight_keys_local_buff[j];
        bool has_topk = (key > topk_val);

        if (has_topk) {
          int write_index = atomicAdd(&cnt, 1);
          sample_output[offset + write_index] = in_rows[start + j];
          if (return_eids) {
            out_eids[offset + write_index] = eids[start + j];
          }
        }
      }
      __syncthreads();

      // Second pass: fill the remaining slots with keys equal to the
      // threshold, stopping once max_sample_count slots are claimed.
      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
        float key = weight_keys_local_buff[j];
        bool has_topk = (key == topk_val);
        if (has_topk) {
          int write_index = atomicAdd(&cnt, 1);
          if (write_index >= max_sample_count) {
            break;
          }
          sample_output[offset + write_index] = in_rows[start + j];
          if (return_eids) {
            out_eids[offset + write_index] = eids[start + j];
          }
        }
      }
    }
  }
}
#endif

// Copies every neighbor of each input node to the output ("sample all").
// One thread block handles one input node (selected by blockIdx.x); the
// threads of the block stride over that node's neighbors.
//
// Args:
//   sample_output:    out, neighbor ids; node i writes starting at
//                     sample_offset[i].
//   sample_offset:    per-node write offset into sample_output/out_eids.
//   input_nodes:      node ids to process.
//   input_node_count: number of entries in input_nodes.
//   in_rows:          neighbor ids, indexed through col_ptr ranges.
//   col_ptr:          offsets; node nid owns [col_ptr[nid], col_ptr[nid + 1]).
//   eids:             per-edge ids, indexed like in_rows (read only when
//                     return_eids is true).
//   out_eids:         out, edge ids of the copied neighbors (written only
//                     when return_eids is true).
//   return_eids:      whether to fill out_eids.
template <typename T>
__global__ void SampleAllKernel(T* sample_output,
                                const int* sample_offset,
                                const T* input_nodes,
                                int input_node_count,
                                const T* in_rows,
                                const T* col_ptr,
                                const T* eids,
                                T* out_eids,
                                bool return_eids) {
  int i = blockIdx.x;
  if (i >= input_node_count) return;
  T nid = input_nodes[i];
  // The adjacency range of nid is [col_ptr[nid], col_ptr[nid + 1]); this
  // matches how GetSampleCountAndNeighborCountKernel sizes the output.
  T start = col_ptr[nid];
  T end = col_ptr[nid + 1];
  int neighbor_count = static_cast<int>(end - start);
  if (neighbor_count <= 0) return;
  int offset = sample_offset[i];
  for (int j = threadIdx.x; j < neighbor_count; j += blockDim.x) {
    sample_output[offset + j] = in_rows[start + j];
    if (return_eids) {
      out_eids[offset + j] = eids[start + j];
    }
  }
}

// A-RES algorithm
#ifdef PADDLE_WITH_CUDA
// Weighted neighbor sampling for small/medium sample sizes. One thread block
// handles one input node (selected by blockIdx.x) and keeps its candidate
// (key, neighbor index) pairs in per-thread arrays of ITEMS_PER_THREAD
// entries, i.e. BLOCK_SIZE * ITEMS_PER_THREAD candidates per block. Indexing
// uses BLOCK_SIZE, so the kernel must be launched with BLOCK_SIZE threads per
// block.
//
// Nodes with at most max_sample_count neighbors are copied verbatim.
// Otherwise neighbors are processed in rounds: each round generates random
// keys for a batch of neighbors and runs a block-wide radix top-k; later
// rounds only overwrite candidate slots at positions >= max_sample_count, so
// the first max_sample_count slots carry over between rounds.
//
// Args:
//   sample_output:    out, sampled neighbor ids; node i writes starting at
//                     sample_offset[i].
//   sample_offset:    per-node write offset into sample_output/out_eids.
//   input_nodes:      node ids to sample for.
//   input_node_count: number of entries in input_nodes.
//   in_rows:          neighbor ids, indexed through col_ptr ranges.
//   col_ptr:          offsets; node nid owns [col_ptr[nid], col_ptr[nid + 1]).
//   edge_weight:      per-edge weights, indexed like in_rows.
//   eids:             per-edge ids, indexed like in_rows (read only when
//                     return_eids is true).
//   max_sample_count: maximum number of neighbors emitted per node.
//                     NOTE(review): the refill stride below is
//                     BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count, so
//                     this presumably must be smaller than
//                     BLOCK_SIZE * ITEMS_PER_THREAD; confirm at the call site.
//   random_seed:      seed shared by all threads; each thread derives its own
//                     generator from its global thread index.
//   out_eids:         out, edge ids of the sampled neighbors (written only
//                     when return_eids is true).
//   return_eids:      whether to fill out_eids.
template <typename T, unsigned int ITEMS_PER_THREAD, unsigned int BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE) __global__
    void WeightedSampleKernel(T* sample_output,
                              const int* sample_offset,
                              const T* input_nodes,
                              int input_node_count,
                              const T* in_rows,
                              const T* col_ptr,
                              const float* edge_weight,
                              const T* eids,
                              int max_sample_count,
                              unsigned long long random_seed,  // NOLINT
                              T* out_eids,
                              bool return_eids) {
  int i = blockIdx.x;
  if (i >= input_node_count) return;
  int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
  T nid = input_nodes[i];
  T start = col_ptr[nid];
  T end = col_ptr[nid + 1];
  int neighbor_count = static_cast<int>(end - start);
  int offset = sample_offset[i];

  if (neighbor_count <= max_sample_count) {
    // Few enough neighbors: emit all of them without sampling.
    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
      sample_output[offset + j] = in_rows[start + j];
      if (return_eids) {
        out_eids[offset + j] = eids[start + j];
      }
    }
  } else {
    RandomNumGen rng(gidx, random_seed);
    // Per-thread candidate slots: slot j of thread tx corresponds to block
    // position BLOCK_SIZE * j + tx.
    float weight_keys[ITEMS_PER_THREAD];
    int neighbor_idxs[ITEMS_PER_THREAD];
    using BlockRadixTopKT = paddle::framework::
        BlockRadixTopKRegister<float, BLOCK_SIZE, ITEMS_PER_THREAD, true, int>;
    __shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage;

    const int tx = threadIdx.x;
    // First round: load the first BLOCK_SIZE * ITEMS_PER_THREAD neighbors (or
    // fewer) and generate a key for each.
#pragma unroll
    for (int j = 0; j < ITEMS_PER_THREAD; j++) {
      int idx = BLOCK_SIZE * j + tx;
      if (idx < neighbor_count) {
        float thread_weight = edge_weight[start + idx];
        weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
        neighbor_idxs[j] = idx;
      }
    }
    // Number of slots that actually hold a candidate in this round.
    const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD))
                                ? neighbor_count
                                : (BLOCK_SIZE * ITEMS_PER_THREAD);
    // NOTE(review): presumably leaves the max_sample_count best pairs at block
    // positions [0, max_sample_count) in striped layout; confirm against
    // block_radix_topk.cuh.
    BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
        weight_keys, neighbor_idxs, max_sample_count, valid_count);
    __syncthreads();
    // Each later round refills only the slots past the first
    // max_sample_count positions, so it consumes `stride` new neighbors.
    const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count;

    for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE;
         idx_offset < neighbor_count;
         idx_offset += stride) {
#pragma unroll
      for (int j = 0; j < ITEMS_PER_THREAD; j++) {
        // local_idx < 0 marks a slot among the first max_sample_count
        // positions, which is kept from the previous round.
        int local_idx = BLOCK_SIZE * j + tx - max_sample_count;
        int target_idx = idx_offset + local_idx;
        if (local_idx >= 0 && target_idx < neighbor_count) {
          float thread_weight = edge_weight[start + target_idx];
          weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
          neighbor_idxs[j] = target_idx;
        }
      }
      // Full round if at least `stride` neighbors remain, otherwise the kept
      // candidates plus the remaining tail.
      const int iter_valid_count =
          ((neighbor_count - idx_offset) >= stride)
              ? (BLOCK_SIZE * ITEMS_PER_THREAD)
              : (max_sample_count + neighbor_count - idx_offset);
      BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
          weight_keys, neighbor_idxs, max_sample_count, iter_valid_count);
      __syncthreads();
    }
    // Write out the candidates held at block positions
    // [0, max_sample_count).
#pragma unroll
    for (int j = 0; j < ITEMS_PER_THREAD; j++) {
      int idx = j * BLOCK_SIZE + tx;
      if (idx < max_sample_count) {
        sample_output[offset + idx] = in_rows[start + neighbor_idxs[j]];
        if (return_eids) {
          out_eids[offset + idx] = eids[start + neighbor_idxs[j]];
        }
      }
    }
  }
}
#endif

// Host entry point of the GPU weighted_sample_neighbors kernel.
//
// For each node in `x`, selects up to `sample_size` neighbors (all neighbors
// when the node has no more than that, or when sample_size is negative) and
// writes the concatenated results to `out`, with the per-node counts in
// `out_count`.
//
// Steps:
//   1. Compute per-node output counts (and, for the large-sample path, the
//      per-node scratch sizes) on the device.
//   2. Exclusive-scan the counts into write offsets and read the total back
//      to the host to size `out` / `out_eids`.
//   3. Launch one of three sampling kernels depending on sample_size
//      (CUDA builds only): WeightedSampleLargeKernel above
//      SAMPLE_SIZE_THRESHOLD, SampleAllKernel for sample_size <= 0, otherwise
//      a WeightedSampleKernel instantiation chosen by sample_size.
//
// Args:
//   dev_ctx:     device context providing the stream and allocator.
//   row:         neighbor ids (type T), indexed through col_ptr ranges.
//   col_ptr:     offsets (type T); node nid owns
//                [col_ptr[nid], col_ptr[nid + 1]).
//   edge_weight: per-edge float weights, indexed like row.
//   x:           input node ids; its first dimension is the batch size.
//   eids:        optional per-edge ids, indexed like row.
//   sample_size: requested neighbors per node.
//   return_eids: whether to fill out_eids.
//   out:         out, sampled neighbor ids.
//   out_count:   out, number of sampled neighbors per input node.
//   out_eids:    out, edge ids of the sampled neighbors (only resized and
//                written when return_eids is true).
template <typename T, typename Context>
void WeightedSampleNeighborsKernel(const Context& dev_ctx,
                                   const DenseTensor& row,
                                   const DenseTensor& col_ptr,
                                   const DenseTensor& edge_weight,
                                   const DenseTensor& x,
                                   const paddle::optional<DenseTensor>& eids,
                                   int sample_size,
                                   bool return_eids,
                                   DenseTensor* out,
                                   DenseTensor* out_count,
                                   DenseTensor* out_eids) {
  auto* row_data = row.data<T>();
  auto* col_ptr_data = col_ptr.data<T>();
  auto* weights_data = edge_weight.data<float>();
  auto* x_data = x.data<T>();
  auto* eids_data =
      (eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
  int bs = x.dims()[0];

  // One fresh 64-bit seed per call, drawn from a per-host-thread generator.
  thread_local std::random_device rd;
  thread_local std::mt19937 gen(rd());
  thread_local std::uniform_int_distribution<unsigned long long>  // NOLINT
      distrib;
  unsigned long long random_seed = distrib(gen);  // NOLINT
  // Only the large-sample path needs per-node scratch sizes.
  const bool need_neighbor_count = sample_size > SAMPLE_SIZE_THRESHOLD;

  out_count->Resize({bs});
  int* out_count_data =
      dev_ctx.template Alloc<int>(out_count);  // finally copy sample_count
  int* neighbor_count_ptr = nullptr;
  std::shared_ptr<phi::Allocation> neighbor_count;
  // bs + 1 entries so that the exclusive scan below also yields the total.
  auto sample_count = phi::memory_utils::Alloc(
      dev_ctx.GetPlace(),
      (bs + 1) * sizeof(int),
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
  int* sample_count_ptr = reinterpret_cast<int*>(sample_count->ptr());

  int grid_size = (bs + 127) / 128;
  if (need_neighbor_count) {
    neighbor_count = phi::memory_utils::AllocShared(
        dev_ctx.GetPlace(),
        (bs + 1) * sizeof(int),
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
    neighbor_count_ptr = reinterpret_cast<int*>(neighbor_count->ptr());
    GetSampleCountAndNeighborCountKernel<T, true>
        <<<grid_size, 128, 0, dev_ctx.stream()>>>(col_ptr_data,
                                                  x_data,
                                                  sample_count_ptr,
                                                  neighbor_count_ptr,
                                                  sample_size,
                                                  bs);
  } else {
    GetSampleCountAndNeighborCountKernel<T, false>
        <<<grid_size, 128, 0, dev_ctx.stream()>>>(
            col_ptr_data, x_data, sample_count_ptr, nullptr, sample_size, bs);
  }

  auto sample_offset = phi::memory_utils::Alloc(
      dev_ctx.GetPlace(),
      (bs + 1) * sizeof(int),
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
  int* sample_offset_ptr = reinterpret_cast<int*>(sample_offset->ptr());

#ifdef PADDLE_WITH_CUDA
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
  // sample_offset[i] = sum of counts before node i; sample_offset[bs] is the
  // total number of sampled neighbors.
  thrust::exclusive_scan(exec_policy,
                         sample_count_ptr,
                         sample_count_ptr + bs + 1,
                         sample_offset_ptr);
  int total_sample_size = 0;
#ifdef PADDLE_WITH_CUDA
  cudaMemcpyAsync(&total_sample_size,
                  sample_offset_ptr + bs,
                  sizeof(int),
                  cudaMemcpyDeviceToHost,
                  dev_ctx.stream());
  cudaMemcpyAsync(out_count_data,
                  sample_count_ptr,
                  sizeof(int) * bs,
                  cudaMemcpyDeviceToDevice,
                  dev_ctx.stream());
  cudaStreamSynchronize(dev_ctx.stream());
#else
  hipMemcpyAsync(&total_sample_size,
                 sample_offset_ptr + bs,
                 sizeof(int),
                 hipMemcpyDeviceToHost,
                 dev_ctx.stream());
  hipMemcpyAsync(out_count_data,
                 sample_count_ptr,
                 sizeof(int) * bs,
                 hipMemcpyDeviceToDevice,
                 dev_ctx.stream());
  hipStreamSynchronize(dev_ctx.stream());
#endif

  out->Resize({static_cast<int>(total_sample_size)});
  T* out_data = dev_ctx.template Alloc<T>(out);
  T* out_eids_data = nullptr;
  if (return_eids) {
    out_eids->Resize({static_cast<int>(total_sample_size)});
    out_eids_data = dev_ctx.template Alloc<T>(out_eids);
  }

  // Every per-node count is zero (for example sample_size == 0, or all input
  // nodes are isolated): the outputs are empty, so any kernel write would land
  // outside them. The counts are already copied and the stream is synced.
  if (total_sample_size == 0) {
    return;
  }

  // large sample size
#ifdef PADDLE_WITH_CUDA
  if (need_neighbor_count) {
    // Turn the per-node scratch sizes into offsets (in place); the entry at
    // index bs becomes the total scratch size.
    thrust::exclusive_scan(exec_policy,
                           neighbor_count_ptr,
                           neighbor_count_ptr + bs + 1,
                           neighbor_count_ptr);
    int* neighbor_offset = neighbor_count_ptr;
    int target_neighbor_counts;
    cudaMemcpyAsync(&target_neighbor_counts,
                    neighbor_offset + bs,
                    sizeof(int),
                    cudaMemcpyDeviceToHost,
                    dev_ctx.stream());
    cudaStreamSynchronize(dev_ctx.stream());

    auto tmh_weights = phi::memory_utils::Alloc(
        dev_ctx.GetPlace(),
        target_neighbor_counts * sizeof(float),
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
    float* target_weights_keys_buf_ptr =
        reinterpret_cast<float*>(tmh_weights->ptr());
    constexpr int BLOCK_SIZE = 256;
    WeightedSampleLargeKernel<T, BLOCK_SIZE>
        <<<bs, BLOCK_SIZE, 0, dev_ctx.stream()>>>(out_data,
                                                  sample_offset_ptr,
                                                  neighbor_offset,
                                                  target_weights_keys_buf_ptr,
                                                  x_data,
                                                  bs,
                                                  row_data,
                                                  col_ptr_data,
                                                  weights_data,
                                                  eids_data,
                                                  sample_size,
                                                  random_seed,
                                                  out_eids_data,
                                                  return_eids);
    cudaStreamSynchronize(dev_ctx.stream());
  } else if (sample_size <= 0) {
    SampleAllKernel<T><<<bs, 64, 0, dev_ctx.stream()>>>(out_data,
                                                        sample_offset_ptr,
                                                        x_data,
                                                        bs,
                                                        row_data,
                                                        col_ptr_data,
                                                        eids_data,
                                                        out_eids_data,
                                                        return_eids);
    cudaStreamSynchronize(dev_ctx.stream());
  } else {  // 0 < sample_size <= SAMPLE_SIZE_THRESHOLD
    using WeightedSampleFuncType = void (*)(T*,
                                            const int*,
                                            const T*,
                                            int,
                                            const T*,
                                            const T*,
                                            const float*,
                                            const T*,
                                            int,
                                            unsigned long long,  // NOLINT
                                            T*,
                                            bool);
    static constexpr int kNumSampleFuncs = 7;
    // Instantiations ordered by candidate capacity
    // (ITEMS_PER_THREAD * BLOCK_SIZE); block_sizes must stay in sync.
    static const WeightedSampleFuncType func_array[kNumSampleFuncs] = {
        WeightedSampleKernel<T, 4, 128>,
        WeightedSampleKernel<T, 6, 128>,
        WeightedSampleKernel<T, 4, 256>,
        WeightedSampleKernel<T, 5, 256>,
        WeightedSampleKernel<T, 6, 256>,
        WeightedSampleKernel<T, 8, 256>,
        WeightedSampleKernel<T, 8, 512>,
    };
    const int block_sizes[kNumSampleFuncs] = {128, 128, 256, 256, 256, 256, 512};
    auto choose_func_idx = [](int sample_size) {
      if (sample_size <= 128) {
        return 0;
      }
      if (sample_size <= 384) {
        // The formula reaches 7 for sample_size in [321, 384], which is past
        // the end of the tables; clamp to the last (largest) instantiation.
        const int idx = (sample_size - 129) / 64 + 4;
        return idx < kNumSampleFuncs ? idx : kNumSampleFuncs - 1;
      }
      if (sample_size <= 512) {
        return 5;
      } else {
        return 6;
      }
    };
    int func_idx = choose_func_idx(sample_size);
    int block_size = block_sizes[func_idx];
    func_array[func_idx]<<<bs, block_size, 0, dev_ctx.stream()>>>(
        out_data,
        sample_offset_ptr,
        x_data,
        bs,
        row_data,
        col_ptr_data,
        weights_data,
        eids_data,
        sample_size,
        random_seed,
        out_eids_data,
        return_eids);
    cudaStreamSynchronize(dev_ctx.stream());
  }
#endif
}

}  // namespace phi

// Registers phi::WeightedSampleNeighborsKernel as the GPU implementation of
// the weighted_sample_neighbors op for int and int64_t, for all layouts.
PD_REGISTER_KERNEL(weighted_sample_neighbors,
                   GPU,
                   ALL_LAYOUT,
                   phi::WeightedSampleNeighborsKernel,
                   int,
                   int64_t) {}