graph_reindex_kernel.cu 13.3 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/graph_reindex_kernel.h"

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"

namespace phi {

// Width of the x-dimension of GetDstEdgeCUDAKernel's thread blocks: the
// WARP_SIZE threads of one threadIdx.y slice stride together over a row's
// output slots. The value is the same for CUDA and HIP builds; it is used only
// as a launch/stride width here, not with warp-level intrinsics.
constexpr int WARP_SIZE{32};

// Inserts every element of `input` into a hash table and collects the
// distinct items into `unique_items`.
//
// Steps (every kernel and the scan run on dev_ctx.stream()):
//   1. BuildHashTable inserts the inputs into `keys` / `key_index`.
//   2. GetItemIndexCount writes one count per input position into
//      `item_count` (presumably 1 for the occurrence that represents a unique
//      item and 0 otherwise -- TODO confirm in graph_reindex_funcs.h).
//   3. An exclusive scan turns those counts into output offsets; the extra
//      trailing slot then holds the total number of unique items.
//   4. FillUniqueItems writes the unique items and fills `values`.
//
// Parameters:
//   dev_ctx        GPU context whose stream is used for all device work.
//   input          device pointer to `num_input` items.
//   num_input      number of items in `input`; <= 0 yields no unique items.
//   len_hashtable  number of slots in `keys` / `values` / `key_index`.
//   unique_items   output; resized to the number of distinct items.
//   keys, values, key_index
//                  device hash-table buffers with `len_hashtable` slots. They
//                  must be initialized by the caller (Reindex fills them with
//                  0xFF bytes before calling this).
template <typename T, typename Context>
void FillHashTable(const Context& dev_ctx,
                   const T* input,
                   int num_input,
                   int64_t len_hashtable,
                   thrust::device_vector<T>* unique_items,
                   T* keys,
                   int* values,
                   int* key_index) {
  // Empty input has no unique items. Returning here also avoids launching
  // with a zero-sized grid, which is an invalid launch configuration.
  if (num_input <= 0) {
    unique_items->clear();
    return;
  }

#ifdef PADDLE_WITH_HIP
  int block = 256;
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  int block = 1024;
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
  int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
  int grid_tmp = (num_input + block - 1) / block;
  int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;

  // Insert data into keys and key_index.
  BuildHashTable<T><<<grid, block, 0, dev_ctx.stream()>>>(
      input, num_input, len_hashtable, keys, key_index);

  // One count per input position plus a trailing slot, so that after the
  // exclusive scan item_count[num_input] is the total count.
  thrust::device_vector<int> item_count(num_input + 1, 0);
  GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
      input,
      thrust::raw_pointer_cast(item_count.data()),
      num_input,
      len_hashtable,
      keys,
      key_index);

  // The scan consumes the kernel's output, so it must be ordered on the same
  // stream rather than on thrust's default stream.
  thrust::exclusive_scan(
      exec_policy, item_count.begin(), item_count.end(), item_count.begin());
  size_t total_unique_items = item_count[num_input];
  unique_items->resize(total_unique_items);

  // Scatter the unique items to their scanned offsets and fill `values`.
  FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
      input,
      num_input,
      len_hashtable,
      thrust::raw_pointer_cast(unique_items->data()),
      thrust::raw_pointer_cast(item_count.data()),
      keys,
      values,
      key_index);
}

// Buffer-backed variant of FillHashTable: inserts `input` into caller-provided
// buffers and collects the distinct items into `unique_items`.
//
// Unlike FillHashTable, the kernels receive no key array and no table length,
// only `values` / `key_index`. Those buffers are presumably indexed directly by
// item value, so they must cover every id that can appear in `input`
// (TODO confirm in graph_reindex_funcs.h).
//
// Parameters:
//   dev_ctx        GPU context whose stream is used for all device work.
//   input          device pointer to `num_input` items.
//   num_input      number of items in `input`; <= 0 yields no unique items.
//   unique_items   output; resized to the number of distinct items.
//   values, key_index
//                  device buffers acting as the hash table.
template <typename T, typename Context>
void FillBufferHashTable(const Context& dev_ctx,
                         const T* input,
                         int num_input,
                         thrust::device_vector<T>* unique_items,
                         int* values,
                         int* key_index) {
  // Empty input has no unique items. Returning here also avoids launching
  // with a zero-sized grid, which is an invalid launch configuration.
  if (num_input <= 0) {
    unique_items->clear();
    return;
  }

#ifdef PADDLE_WITH_HIP
  int block = 256;
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  int block = 1024;
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
  int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
  int grid_tmp = (num_input + block - 1) / block;
  int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;

  // Insert data.
  BuildHashTable<T>
      <<<grid, block, 0, dev_ctx.stream()>>>(input, num_input, key_index);

  // One count per input position plus a trailing slot, so that after the
  // exclusive scan item_count[num_input] is the total count.
  thrust::device_vector<int> item_count(num_input + 1, 0);
  GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
      input, thrust::raw_pointer_cast(item_count.data()), num_input, key_index);

  // The scan consumes the kernel's output, so it must be ordered on the same
  // stream rather than on thrust's default stream.
  thrust::exclusive_scan(
      exec_policy, item_count.begin(), item_count.end(), item_count.begin());
  size_t total_unique_items = item_count[num_input];
  unique_items->resize(total_unique_items);

  // Scatter the unique items to their scanned offsets and fill `values`.
  FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
      input,
      num_input,
      thrust::raw_pointer_cast(unique_items->data()),
      thrust::raw_pointer_cast(item_count.data()),
      values,
      key_index);
}

// Runs ResetHashTable over the entries of `unique_items` so the caller-owned
// `values` / `key_index` buffers can be reused (presumably restoring the slots
// touched by FillBufferHashTable -- TODO confirm in graph_reindex_funcs.h).
//
// `input` and `num_input` are accepted for interface symmetry with
// FillBufferHashTable but are not used; only `unique_items` is visited.
template <typename T, typename Context>
void ResetBufferHashTable(const Context& dev_ctx,
                          const T* input,
                          int num_input,
                          thrust::device_vector<T>* unique_items,
                          int* values,
                          int* key_index) {
  const size_t num_unique = unique_items->size();
  // With no unique items there is nothing to reset. Launching with a
  // zero-sized grid would be an invalid launch configuration.
  if (num_unique == 0) {
    return;
  }

#ifdef PADDLE_WITH_HIP
  int block = 256;
#else
  int block = 1024;
#endif
  int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
  int grid_tmp = (num_unique + block - 1) / block;
  int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
  ResetHashTable<T><<<grid, block, 0, dev_ctx.stream()>>>(
      thrust::raw_pointer_cast(unique_items->data()),
      num_unique,
      key_index,
      values);
}

// Number of slots for Reindex's temporary hash table: twice the smallest power
// of two strictly greater than num / 2, which is always larger than num.
// Computed with integer arithmetic (rather than std::log2) so that it is well
// defined for num < 2 and cannot overflow a 32-bit shift for large num.
inline int64_t ComputeReindexTableSize(int64_t num) {
  const int64_t half = num >> 1;
  int64_t log_num = 1;
  while (log_num <= half) {
    log_num <<= 1;
  }
  return log_num << 1;
}

// Reindexes a sampled subgraph using a temporary hash table.
//
// `out_nodes` is first filled with the `num_inputs` input nodes followed by
// the `num_edges` neighbor ids in `src_outputs`. It is then replaced by its
// distinct items as produced by FillHashTable. Finally, ReindexSrcOutput
// rewrites `src_outputs` in place using the hash table's keys/values.
//
// All device work is ordered on dev_ctx.stream(). `inputs` and `src_outputs`
// must point to device memory.
template <typename T, typename Context>
void Reindex(const Context& dev_ctx,
             const T* inputs,
             thrust::device_ptr<T> src_outputs,
             thrust::device_vector<T>* out_nodes,
             int num_inputs,
             int num_edges) {
#ifdef PADDLE_WITH_HIP
  int block = 256;
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  int block = 1024;
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif

  // Candidate list: input nodes first, then every neighbor. The explicit
  // device policy makes thrust treat the raw `inputs` pointer as device
  // memory; without it a raw pointer is dispatched as a host pointer.
  out_nodes->resize(num_inputs + num_edges);
  thrust::copy(exec_policy, inputs, inputs + num_inputs, out_nodes->begin());
  thrust::copy(exec_policy,
               src_outputs,
               src_outputs + num_edges,
               out_nodes->begin() + num_inputs);
  thrust::device_vector<T> unique_nodes;

  // Temporary hash table, allocated through the context allocator so that
  // failures are reported and the memory is released on every exit path.
  const int64_t table_size =
      ComputeReindexTableSize(static_cast<int64_t>(out_nodes->size()));
  DenseTensor keys_tensor;
  DenseTensor values_tensor;
  DenseTensor key_index_tensor;
  keys_tensor.Resize({table_size});
  values_tensor.Resize({table_size});
  key_index_tensor.Resize({table_size});
  T* keys = dev_ctx.template Alloc<T>(&keys_tensor);
  int* values = dev_ctx.template Alloc<int>(&values_tensor);
  int* key_index = dev_ctx.template Alloc<int>(&key_index_tensor);

  // Set every slot to all-0xFF bytes (-1 for the integer types registered for
  // this kernel), presumably the table's empty-slot marker (TODO confirm in
  // graph_reindex_funcs.h). This is issued on the stream the hash kernels use,
  // so it is ordered before them.
#ifdef PADDLE_WITH_HIP
  hipMemsetAsync(keys, -1, table_size * sizeof(T), dev_ctx.stream());
  hipMemsetAsync(values, -1, table_size * sizeof(int), dev_ctx.stream());
  hipMemsetAsync(key_index, -1, table_size * sizeof(int), dev_ctx.stream());
#else
  cudaMemsetAsync(keys, -1, table_size * sizeof(T), dev_ctx.stream());
  cudaMemsetAsync(values, -1, table_size * sizeof(int), dev_ctx.stream());
  cudaMemsetAsync(key_index, -1, table_size * sizeof(int), dev_ctx.stream());
#endif

  // Fill hash table and collect the distinct nodes.
  FillHashTable<T, Context>(dev_ctx,
                            thrust::raw_pointer_cast(out_nodes->data()),
                            out_nodes->size(),
                            table_size,
                            &unique_nodes,
                            keys,
                            values,
                            key_index);
  // unique_nodes is written by a kernel on dev_ctx.stream(), so the copy back
  // must be ordered on that stream as well.
  out_nodes->resize(unique_nodes.size());
  thrust::copy(exec_policy,
               unique_nodes.begin(),
               unique_nodes.end(),
               out_nodes->begin());

  // Fill outputs with reindex result (skipped when there are no edges, since a
  // zero-sized grid is an invalid launch configuration).
  if (num_edges > 0) {
    int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
    int grid_tmp = (num_edges + block - 1) / block;
    int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
    ReindexSrcOutput<T><<<grid, block, 0, dev_ctx.stream()>>>(
        thrust::raw_pointer_cast(src_outputs),
        num_edges,
        table_size,
        keys,
        values);
  }
  // keys / values / key_index are released when their tensors go out of scope.
}

// Reindexes a sampled subgraph using caller-provided hash-table buffers
// (`hashtable_value`, `hashtable_index`) instead of a temporary table.
//
// `out_nodes` is first filled with the `num_inputs` input nodes followed by
// the `num_edges` neighbor ids in `src_outputs`. It is then replaced by its
// distinct items. ReindexSrcOutput rewrites `src_outputs` in place, and
// ResetBufferHashTable is finally run over the unique nodes so the buffers can
// be reused.
//
// All device work is ordered on dev_ctx.stream(). `inputs` and `src_outputs`
// must point to device memory.
template <typename T, typename Context>
void BufferReindex(const Context& dev_ctx,
                   const T* inputs,
                   thrust::device_ptr<T> src_outputs,
                   thrust::device_vector<T>* out_nodes,
                   int num_inputs,
                   int* hashtable_value,
                   int* hashtable_index,
                   int num_edges) {
#ifdef PADDLE_WITH_HIP
  int block = 256;
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  int block = 1024;
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif

  // Candidate list: input nodes first, then every neighbor. The explicit
  // device policy makes thrust treat the raw `inputs` pointer as device
  // memory; without it a raw pointer is dispatched as a host pointer.
  out_nodes->resize(num_inputs + num_edges);
  thrust::copy(exec_policy, inputs, inputs + num_inputs, out_nodes->begin());
  thrust::copy(exec_policy,
               src_outputs,
               src_outputs + num_edges,
               out_nodes->begin() + num_inputs);
  thrust::device_vector<T> unique_nodes;

  // Fill hash table and collect the distinct nodes.
  FillBufferHashTable<T, Context>(dev_ctx,
                                  thrust::raw_pointer_cast(out_nodes->data()),
                                  out_nodes->size(),
                                  &unique_nodes,
                                  hashtable_value,
                                  hashtable_index);
  // unique_nodes is written by a kernel on dev_ctx.stream(), so the copy back
  // must be ordered on that stream as well.
  out_nodes->resize(unique_nodes.size());
  thrust::copy(exec_policy,
               unique_nodes.begin(),
               unique_nodes.end(),
               out_nodes->begin());

  // Fill outputs with reindex result (skipped when there are no edges, since a
  // zero-sized grid is an invalid launch configuration).
  if (num_edges > 0) {
    int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
    int grid_tmp = (num_edges + block - 1) / block;
    int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
    ReindexSrcOutput<T><<<grid, block, 0, dev_ctx.stream()>>>(
        thrust::raw_pointer_cast(src_outputs), num_edges, hashtable_value);
  }

  ResetBufferHashTable<T, Context>(dev_ctx,
                                   thrust::raw_pointer_cast(out_nodes->data()),
                                   out_nodes->size(),
                                   &unique_nodes,
                                   hashtable_value,
                                   hashtable_index);
}

// For every row r in [0, num_rows), writes in_rows[r] into the
// dst_counts[r] consecutive slots dst_outputs[dst_ptr[r] ...].
//
// Launch layout (checked by the device asserts): blockDim must be
// (WARP_SIZE, BLOCK_WARPS), and no shared memory is used. Block b covers rows
// [b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_rows)), so the grid needs at
// least ceil(num_rows / TILE_SIZE) blocks. Within a block, each threadIdx.y
// slice takes every BLOCK_WARPS-th row starting at offset threadIdx.y, and the
// WARP_SIZE threads of that slice stride over the row's output slots.
template <typename T, int BLOCK_WARPS, int TILE_SIZE>
__global__ void GetDstEdgeCUDAKernel(const int64_t num_rows,
                                     const int* in_rows,
                                     const int* dst_counts,
                                     const int* dst_ptr,
                                     T* dst_outputs) {
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  const int64_t tile_end =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  for (int64_t r = blockIdx.x * TILE_SIZE + threadIdx.y; r < tile_end;
       r += BLOCK_WARPS) {
    const int fill_value = in_rows[r];
    const int write_begin = dst_ptr[r];
    const int write_count = dst_counts[r];
    for (int k = threadIdx.x; k < write_count; k += WARP_SIZE) {
      dst_outputs[write_begin + k] = fill_value;
    }
  }
}

// GPU kernel for graph_reindex.
//
// Inputs:
//   x          the `bs` input nodes.
//   neighbors  `num_edges` sampled neighbor ids.
//   count      num_edge_types consecutive groups of `bs` per-node neighbor
//              counts (num_edge_types = count.dims()[0] / bs).
//   hashtable_value / hashtable_index
//              optional buffers; used as the hash table when
//              flag_buffer_hashtable is set. In that case they are dereferenced
//              unconditionally and must therefore be provided.
// Outputs:
//   reindex_src  neighbors rewritten to their reindexed ids.
//   reindex_dst  for each edge, the index of its input node, laid out per edge
//                type in the order of `count`.
//   out_nodes    the distinct nodes produced by the reindexing step.
//
// All device work, including thrust algorithms, is ordered on
// dev_ctx.stream().
template <typename T, typename Context>
void GraphReindexKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& neighbors,
                        const DenseTensor& count,
                        const paddle::optional<DenseTensor>& hashtable_value,
                        const paddle::optional<DenseTensor>& hashtable_index,
                        bool flag_buffer_hashtable,
                        DenseTensor* reindex_src,
                        DenseTensor* reindex_dst,
                        DenseTensor* out_nodes) {
#ifdef PADDLE_WITH_HIP
  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
  const T* x_data = x.data<T>();
  const T* neighbors_data = neighbors.data<T>();
  const int* count_data = count.data<int>();
  const int bs = x.dims()[0];
  const int num_edges = neighbors.dims()[0];
  reindex_src->Resize({num_edges});

  T* reindex_src_data = dev_ctx.template Alloc<T>(reindex_src);
  thrust::device_ptr<T> src_outputs(reindex_src_data);

  thrust::device_vector<T> unique_nodes;
  // neighbors_data is device memory. The explicit device policy makes thrust
  // treat the raw pointer as such; without it, it is dispatched as host memory.
  thrust::copy(
      exec_policy, neighbors_data, neighbors_data + num_edges, src_outputs);

  if (flag_buffer_hashtable) {
    // Here we directly use buffer tensor to act as a hash table.
    DenseTensor hashtable_value_out(hashtable_value->type());
    const auto* ph_value = hashtable_value.get_ptr();
    hashtable_value_out.ShareDataWith(*ph_value);
    DenseTensor hashtable_index_out(hashtable_index->type());
    const auto* ph_index = hashtable_index.get_ptr();
    hashtable_index_out.ShareDataWith(*ph_index);
    int* hashtable_value_data =
        dev_ctx.template Alloc<int>(&hashtable_value_out);
    int* hashtable_index_data =
        dev_ctx.template Alloc<int>(&hashtable_index_out);
    BufferReindex<T, Context>(dev_ctx,
                              x_data,
                              src_outputs,
                              &unique_nodes,
                              bs,
                              hashtable_value_data,
                              hashtable_index_data,
                              num_edges);
  } else {
    Reindex<T, Context>(
        dev_ctx, x_data, src_outputs, &unique_nodes, bs, num_edges);
  }

  // Get reindex dst edge, with support for multi-type edges: `count` holds
  // num_edge_types groups of bs counts. An empty batch has no groups, and
  // guarding here avoids dividing by zero.
  const int num_ac_count = count.dims()[0];
  const int num_edge_types = bs > 0 ? num_ac_count / bs : 0;

  // Row r of every edge type is written with the value r.
  thrust::device_vector<int> unique_dst_reindex(bs);
  thrust::sequence(
      exec_policy, unique_dst_reindex.begin(), unique_dst_reindex.end());
  constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
  constexpr int TILE_SIZE = BLOCK_WARPS * 16;
  const dim3 block(WARP_SIZE, BLOCK_WARPS);
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
  reindex_dst->Resize({num_edges});
  T* reindex_dst_data = dev_ctx.template Alloc<T>(reindex_dst);

  // Per-row write offsets, reused for every edge type. Each iteration's scan
  // is stream-ordered after the previous iteration's kernel.
  thrust::device_vector<int> dst_ptr(bs);
  int begin = 0;
  for (int i = 0; i < num_edge_types; i++) {
    const int* type_counts = count_data + i * bs;
    thrust::exclusive_scan(
        exec_policy, type_counts, type_counts + bs, dst_ptr.begin());

    GetDstEdgeCUDAKernel<T, BLOCK_WARPS, TILE_SIZE>
        <<<grid, block, 0, dev_ctx.stream()>>>(
            bs,
            thrust::raw_pointer_cast(unique_dst_reindex.data()),
            type_counts,
            thrust::raw_pointer_cast(dst_ptr.data()),
            reindex_dst_data + begin);

    // Advance past the entries written for this edge type.
    const int count_i =
        thrust::reduce(exec_policy, type_counts, type_counts + bs);
    begin += count_i;
  }

  out_nodes->Resize({static_cast<int>(unique_nodes.size())});
  T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
  // out_nodes_data is device memory; copy with the device policy.
  thrust::copy(
      exec_policy, unique_nodes.begin(), unique_nodes.end(), out_nodes_data);
}

}  // namespace phi

// Register the GPU graph_reindex kernel for int and int64_t element types.
PD_REGISTER_KERNEL(graph_reindex,
                   GPU,
                   ALL_LAYOUT,
                   phi::GraphReindexKernel,
                   int,
                   int64_t) {}