graph_reindex_funcs.h 6.3 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/backends/gpu/gpu_context.h"
W
Wang Xin 已提交
18
#include "paddle/phi/backends/gpu/gpu_primitives.h"
S
Siming Dai 已提交
19
#include "paddle/phi/core/hostdevice.h"
20
#include "paddle/phi/kernels/graph_reindex_kernel.h"
S
Siming Dai 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202

namespace phi {

// Maps a key (or a probe position) onto a slot index of a table with `size`
// slots by taking the remainder of `id` modulo `size`.
// NOTE(review): a negative `id` yields a negative remainder that wraps when
// converted to size_t; callers presumably only pass non-negative ids.
template <typename T>
inline __device__ size_t Hash(T id, int64_t size) {
  const auto remainder = id % size;
  return static_cast<size_t>(remainder);
}

// Tries to claim slot `pos` of the open-addressing table `keys` for `id`.
//
// The slot is claimed with an atomicCAS that swaps the empty marker (-1, all
// bits set) for `id`. The outcome depends on what the slot held:
// - If it was empty or already held `id`, `index` is folded into
//   key_index[pos] with an unsigned atomicMin, and the function returns true.
// - Otherwise the slot is owned by a different key, and the function returns
//   false so the caller can probe the next slot.
//
// Only 4-byte and 8-byte key types are supported, because those are the
// widths atomicCAS provides. This is enforced at compile time so that every
// control path of this non-void function returns a value.
//
// NOTE(review): the unsigned atomicMin keeps the smallest index only if
// key_index was pre-filled with -1 (0xFFFFFFFF as unsigned), and `keys` must
// be pre-filled with -1 to mark empty slots — confirm at the caller.
template <typename T>
inline __device__ bool AttemptInsert(
    size_t pos, T id, int index, T* keys, int* key_index) {
  static_assert(sizeof(T) == 4 || sizeof(T) == 8,
                "AttemptInsert only supports 4-byte or 8-byte key types.");

  // Value that was stored in the slot before our CAS attempt.
  T key;
  if (sizeof(T) == 4) {
    key = atomicCAS(reinterpret_cast<unsigned int*>(&keys[pos]),
                    static_cast<unsigned int>(-1),
                    static_cast<unsigned int>(id));
  } else {
    key = atomicCAS(
        reinterpret_cast<unsigned long long int*>(&keys[pos]),  // NOLINT
        static_cast<unsigned long long int>(-1),                // NOLINT
        static_cast<unsigned long long int>(id));               // NOLINT
  }

  // Slot is owned by a different key: tell the caller to keep probing.
  if (key != static_cast<T>(-1) && key != id) {
    return false;
  }

  // The slot now holds `id`; record the smallest index that inserted it.
  atomicMin(reinterpret_cast<unsigned int*>(&key_index[pos]),  // NOLINT
            static_cast<unsigned int>(index));                 // NOLINT
  return true;
}

// Inserts `id` into the open-addressing table (`keys`, `key_index`) of `size`
// slots, recording `index` through AttemptInsert.
//
// Probing starts at Hash(id, size). On each failed attempt the next slot is
// Hash(previous_slot + step, size), where `step` starts at 1 and grows by one
// per attempt.
//
// NOTE(review): the loop is unbounded. It assumes a free slot (or one already
// holding `id`) exists, so presumably size >= the number of distinct keys.
template <typename T>
inline __device__ void Insert(
    T id, int index, int64_t size, T* keys, int* key_index) {
  for (size_t slot = Hash(id, size), step = 1;
       !AttemptInsert(slot, id, index, keys, key_index);
       ++step) {
    slot = Hash(slot + step, size);
  }
}

// Returns the slot of `keys` (a table of `size` slots) that holds `id`. It
// walks the same probe sequence as Insert: start at Hash(id, size), then
// Hash(previous_slot + step, size) with `step` = 1, 2, 3, ...
//
// NOTE(review): there is no check for an absent id, so the loop only
// terminates if `id` was previously inserted. Callers must only search for
// inserted ids.
template <typename T>
inline __device__ int64_t Search(T id, const T* keys, int64_t size) {
  int64_t slot = Hash(id, size);
  for (int64_t step = 1; keys[slot] != id; ++step) {
    slot = Hash(slot + step, size);
  }
  return slot;
}

// Kernel (hashed variant): inserts every element of `items` into the
// open-addressing table (`keys`, `key_index`) of `size` slots. Each element's
// position in `items` is passed as its index; for repeated values,
// AttemptInsert's atomicMin keeps the smallest position in key_index.
//
// NOTE(review): assumes `keys` and `key_index` were pre-filled with -1 by the
// caller.
template <typename T>
__global__ void BuildHashTable(
    const T* items, int num_items, int64_t size, T* keys, int* key_index) {
  CUDA_KERNEL_LOOP(item_pos, num_items) {
    const T item = items[item_pos];
    Insert(item, item_pos, size, keys, key_index);
  }
}

// Kernel (direct-indexed variant): uses each item value directly as an index
// into `key_index` and folds the item's position into that entry with an
// unsigned atomicMin. Only the smallest position per value survives.
//
// NOTE(review): assumes every item is a valid index into key_index and that
// key_index was pre-filled with -1 (0xFFFFFFFF as unsigned).
template <typename T>
__global__ void BuildHashTable(const T* items, int num_items, int* key_index) {
  CUDA_KERNEL_LOOP(item_pos, num_items) {
    unsigned int* slot =
        reinterpret_cast<unsigned int*>(key_index + items[item_pos]);  // NOLINT
    atomicMin(slot, static_cast<unsigned int>(item_pos));              // NOLINT
  }
}

// Kernel (direct-indexed variant): writes -1 into the `key_index` and
// `values` entries addressed by each item value. This presumably returns the
// tables to their empty state so they can be reused.
template <typename T>
__global__ void ResetHashTable(const T* items,
                               int num_items,
                               int* key_index,
                               int* values) {
  CUDA_KERNEL_LOOP(item_pos, num_items) {
    const T slot = items[item_pos];
    values[slot] = -1;
    key_index[slot] = -1;
  }
}

// Kernel (hashed variant): for each position `pos`, locates the item's slot
// in the table (`keys`, `size`). It sets item_count[pos] = 1 when key_index
// records `pos` as that item's position. Entries for other positions are not
// written.
//
// NOTE(review): item_count is presumably zero-filled by the caller.
template <typename T>
__global__ void GetItemIndexCount(const T* items,
                                  int* item_count,
                                  int num_items,
                                  int64_t size,
                                  const T* keys,
                                  int* key_index) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    const int64_t slot = Search(items[pos], keys, size);
    const bool owns_slot = (key_index[slot] == pos);
    if (owns_slot) {
      item_count[pos] = 1;
    }
  }
}

// Kernel (direct-indexed variant): sets item_count[pos] = 1 when
// key_index[items[pos]] equals `pos`. Entries for other positions are not
// written.
//
// NOTE(review): item_count is presumably zero-filled by the caller.
template <typename T>
__global__ void GetItemIndexCount(const T* items,
                                  int* item_count,
                                  int num_items,
                                  int* key_index) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    const bool owns_entry = (key_index[items[pos]] == pos);
    if (owns_entry) {
      item_count[pos] = 1;
    }
  }
}

// Kernel (hashed variant): for each position `pos` that key_index records for
// its item, item_count[pos] is taken as the item's output offset. That offset
// is stored in values[slot], and the item is written to
// unique_items[offset].
//
// NOTE(review): item_count is presumably an exclusive prefix sum of the flags
// from GetItemIndexCount — confirm at the caller.
template <typename T>
__global__ void FillUniqueItems(const T* items,
                                int num_items,
                                int64_t size,
                                T* unique_items,
                                const int* item_count,
                                const T* keys,
                                int* values,
                                int* key_index) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    const T item = items[pos];
    const int64_t slot = Search(item, keys, size);
    if (key_index[slot] == pos) {
      const int offset = item_count[pos];
      unique_items[offset] = item;
      values[slot] = offset;
    }
  }
}

// Kernel (direct-indexed variant): when key_index[items[pos]] equals `pos`,
// item_count[pos] is taken as the item's output offset. That offset is stored
// in values[items[pos]], and the item is written to unique_items[offset].
//
// NOTE(review): item_count is presumably an exclusive prefix sum of the flags
// from GetItemIndexCount — confirm at the caller.
template <typename T>
__global__ void FillUniqueItems(const T* items,
                                int num_items,
                                T* unique_items,
                                const int* item_count,
                                int* values,
                                int* key_index) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    const T item = items[pos];
    if (key_index[item] == pos) {
      const int offset = item_count[pos];
      unique_items[offset] = item;
      values[item] = offset;
    }
  }
}

// Kernel (hashed variant): rewrites src_output in place. Each element is
// replaced by the value stored in `values` at the slot where that element
// sits in the table (`keys`, `size`).
//
// NOTE(review): every element must have been inserted into the table,
// otherwise Search does not terminate.
template <typename T>
__global__ void ReindexSrcOutput(T* src_output,
                                 int num_items,
                                 int64_t size,
                                 const T* keys,
                                 const int* values) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    T& node = src_output[pos];
    node = values[Search(node, keys, size)];
  }
}

// Kernel (direct-indexed variant): rewrites src_output in place by replacing
// each element with values[element].
template <typename T>
__global__ void ReindexSrcOutput(T* src_output,
                                 int num_items,
                                 const int* values) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    const T old_id = src_output[pos];
    src_output[pos] = values[old_id];
  }
}

// Kernel (hashed variant): writes reindex_nodes[pos] = values[slot of
// nodes[pos]], where the slot is found by searching the table (`keys`,
// `size`). `nodes` itself is not modified.
//
// NOTE(review): every node must have been inserted into the table, otherwise
// Search does not terminate.
template <typename T>
__global__ void ReindexInputNodes(const T* nodes,
                                  int num_items,
                                  T* reindex_nodes,
                                  int64_t size,
                                  const T* keys,
                                  const int* values) {
  CUDA_KERNEL_LOOP(pos, num_items) {
    reindex_nodes[pos] = values[Search(nodes[pos], keys, size)];
  }
}

}  // namespace phi