diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index b07bf0ecb99a57447e34b6d8e6340bb4ead133d8..1541d1890a07a7139d86ebe7e9e97088f8e1685d 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2144,6 +2144,15 @@
     intermediate: warprnntgrad
   backward : warprnnt_grad
 
+- op : weighted_sample_neighbors
+  args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
+  output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids)
+  infer_meta :
+    func : WeightedSampleNeighborsInferMeta
+  kernel :
+    func : weighted_sample_neighbors
+  optional : eids
+
 - op : where
   args : (Tensor condition, Tensor x, Tensor y)
   output : Tensor
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 45769cdcb591fd25600289ccbce7e2e4501dcbed..5a8e38e21fd7287c1449919f51e400692a1a6d2e 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -3249,5 +3249,52 @@ void MoeInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
+                                      const MetaTensor& col_ptr,
+                                      const MetaTensor& edge_weight,
+                                      const MetaTensor& x,
+                                      const MetaTensor& eids,
+                                      int sample_size,
+                                      bool return_eids,
+                                      MetaTensor* out,
+                                      MetaTensor* out_count,
+                                      MetaTensor* out_eids) {
+  // GSN: GraphSampleNeighbors
+  auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
+    if (dims.size() == 2) {
+      PADDLE_ENFORCE_EQ(
+          dims[1],
+          1,
+          phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
+                                       "is 2D, but we get %d",
+                                       tensor_name,
+                                       dims[1]));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          dims.size(),
+          1,
+          phi::errors::InvalidArgument(
+              "The %s should be 1D when it is not 2D, but we get %d",
+              tensor_name,
+              dims.size()));
+    }
+  };
+
+  GSNShapeCheck(row.dims(), "row");
+  GSNShapeCheck(col_ptr.dims(), "colptr");
+  GSNShapeCheck(edge_weight.dims(), "edge_weight");
+  GSNShapeCheck(x.dims(), "input_nodes");
+  if (return_eids) {
+    GSNShapeCheck(eids.dims(), "eids");
+    out_eids->set_dims({-1});
+    out_eids->set_dtype(row.dtype());
+  }
+
+  out->set_dims({-1});
+  out->set_dtype(row.dtype());
+  out_count->set_dims({-1});
+  out_count->set_dtype(DataType::INT32);
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index f094ea90d9a9d4f703231f61de8143d065c17d0b..993e6c21ff6ff6f1c6039ca71fa04cd865d4d0e6 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -550,6 +550,17 @@ void WarprnntInferMeta(const MetaTensor& input,
                        MetaTensor* loss,
                        MetaTensor* warpctcgrad);
 
+void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
+                                      const MetaTensor& col_ptr,
+                                      const MetaTensor& edge_weight,
+                                      const MetaTensor& x,
+                                      const MetaTensor& eids,
+                                      int sample_size,
+                                      bool return_eids,
+                                      MetaTensor* out,
+                                      MetaTensor* out_count,
+                                      MetaTensor* out_eids);
+
 void WhereInferMeta(const MetaTensor& condition,
                     const MetaTensor& x,
                     const MetaTensor& y,
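Note on the CSC layout these shape checks assume: `row` holds the source node of every edge, grouped by destination, and `colptr[v]:colptr[v + 1]` delimits the neighbors of node `v`. A minimal NumPy sketch (names are illustrative, not part of the patch):

import numpy as np

# Edges as (src, dst) pairs; CSC groups them by destination node.
edges = np.array([[3, 0], [7, 0], [0, 1], [9, 1], [1, 2]])
num_nodes = 10

order = np.argsort(edges[:, 1])           # sort edges by dst
row = edges[order, 0]                     # srcs, grouped by dst
dst_count = np.bincount(edges[:, 1], minlength=num_nodes)
colptr = np.concatenate([[0], np.cumsum(dst_count)])

# Neighbors of node v are row[colptr[v]:colptr[v + 1]].
assert set(row[colptr[0]:colptr[1]]) == {3, 7}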
diff --git a/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cc2b6cdbdf2faa69d7681388da48577704cd827e
--- /dev/null
+++ b/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc
@@ -0,0 +1,255 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
+
+#include <cmath>
+#include <queue>
+#include <random>
+#include <vector>
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+struct GraphWeightedNode {
+  T node_id;
+  float weight_key;
+  T eid;
+  GraphWeightedNode() {
+    node_id = 0;
+    weight_key = 0;
+    eid = 0;
+  }
+  GraphWeightedNode(T node_id, float weight_key, T eid = 0)
+      : node_id(node_id), weight_key(weight_key), eid(eid) {}
+  void operator=(const GraphWeightedNode& other) {
+    node_id = other.node_id;
+    weight_key = other.weight_key;
+    eid = other.eid;
+  }
+  friend bool operator>(const GraphWeightedNode& n1,
+                        const GraphWeightedNode& n2) {
+    return n1.weight_key > n2.weight_key;
+  }
+};
+
+template <typename T>
+void SampleWeightedNeighbors(
+    std::vector<T>& out_src,  // NOLINT
+    const std::vector<float>& out_weight,
+    std::vector<T>& out_eids,  // NOLINT
+    int sample_size,
+    std::mt19937& rng,  // NOLINT
+    std::uniform_real_distribution<float>& dice_distribution,  // NOLINT
+    bool return_eids) {
+  std::priority_queue<phi::GraphWeightedNode<T>,
+                      std::vector<phi::GraphWeightedNode<T>>,
+                      std::greater<phi::GraphWeightedNode<T>>>
+      min_heap;
+  for (size_t i = 0; i < out_src.size(); i++) {
+    float weight_key = log2(dice_distribution(rng)) * (1 / out_weight[i]);
+    if (static_cast<int>(i) < sample_size) {
+      if (!return_eids) {
+        min_heap.push(phi::GraphWeightedNode<T>(out_src[i], weight_key));
+      } else {
+        min_heap.push(
+            phi::GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]));
+      }
+    } else {
+      const phi::GraphWeightedNode<T>& small = min_heap.top();
+      phi::GraphWeightedNode<T> cmp;
+      if (!return_eids) {
+        cmp = GraphWeightedNode<T>(out_src[i], weight_key);
+      } else {
+        cmp = GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]);
+      }
+      bool flag = cmp > small;
+      if (flag) {
+        min_heap.pop();
+        min_heap.push(cmp);
+      }
+    }
+  }
+
+  int cnt = 0;
+  while (!min_heap.empty()) {
+    const phi::GraphWeightedNode<T>& tmp = min_heap.top();
+    out_src[cnt] = tmp.node_id;
+    if (return_eids) {
+      out_eids[cnt] = tmp.eid;
+    }
+    cnt++;
+    min_heap.pop();
+  }
+}
+
+template <typename T>
+void SampleNeighbors(const T* row,
+                     const T* col_ptr,
+                     const float* edge_weight,
+                     const T* eids,
+                     const T* input,
+                     std::vector<T>* output,
+                     std::vector<int>* output_count,
+                     std::vector<T>* output_eids,
+                     int sample_size,
+                     int bs,
+                     bool return_eids) {
+  std::vector<std::vector<T>> out_src_vec;
+  std::vector<std::vector<float>> out_weight_vec;
+  std::vector<std::vector<T>> out_eids_vec;
+  // `sample_cumsum_sizes` records the start and end positions after sampling.
+  std::vector<int> sample_cumsum_sizes(bs + 1);
+  // `total_neighbors` is the size of the output after sampling.
+  int total_neighbors = 0;
+  sample_cumsum_sizes[0] = total_neighbors;
+  for (int i = 0; i < bs; i++) {
+    T node = input[i];
+    int cap = col_ptr[node + 1] - col_ptr[node];
+    // sample_size < 0 means sample all neighbors.
+    int k = (sample_size >= 0 && cap > sample_size) ? sample_size : cap;
+    total_neighbors += k;
+    sample_cumsum_sizes[i + 1] = total_neighbors;
+    std::vector<T> out_src;
+    out_src.resize(cap);
+    out_src_vec.emplace_back(out_src);
+    std::vector<float> out_weight;
+    out_weight.resize(cap);
+    out_weight_vec.emplace_back(out_weight);
+    if (return_eids) {
+      std::vector<T> out_eids;
+      out_eids.resize(cap);
+      out_eids_vec.emplace_back(out_eids);
+    }
+  }
+
+  output_count->resize(bs);
+  output->resize(total_neighbors);
+  if (return_eids) {
+    output_eids->resize(total_neighbors);
+  }
+
+  std::random_device rd;
+  std::mt19937 rng{rd()};
+  std::uniform_real_distribution<float> dice_distribution(0, 1);
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  // Sample the neighbors in parallel.
+  for (int i = 0; i < bs; i++) {
+    T node = input[i];
+    T begin = col_ptr[node], end = col_ptr[node + 1];
+    int cap = end - begin;
+    if (sample_size >= 0 && sample_size < cap) {  // sample_size < neighbor_len
+      std::copy(row + begin, row + end, out_src_vec[i].begin());
+      std::copy(
+          edge_weight + begin, edge_weight + end, out_weight_vec[i].begin());
+      if (return_eids) {
+        std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
+      }
+      SampleWeightedNeighbors(out_src_vec[i],
+                              out_weight_vec[i],
+                              out_eids_vec[i],
+                              sample_size,
+                              rng,
+                              dice_distribution,
+                              return_eids);
+      *(output_count->data() + i) = sample_size;
+    } else {  // sample_size >= neighbor_len, directly copy
+      std::copy(row + begin, row + end, out_src_vec[i].begin());
+      if (return_eids) {
+        std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
+      }
+      *(output_count->data() + i) = cap;
+    }
+  }
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  // Copy the results in parallel.
+  for (int i = 0; i < bs; i++) {
+    int k = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i];
+    std::copy(out_src_vec[i].begin(),
+              out_src_vec[i].begin() + k,
+              output->data() + sample_cumsum_sizes[i]);
+    if (return_eids) {
+      std::copy(out_eids_vec[i].begin(),
+                out_eids_vec[i].begin() + k,
+                output_eids->data() + sample_cumsum_sizes[i]);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void WeightedSampleNeighborsKernel(const Context& dev_ctx,
+                                   const DenseTensor& row,
+                                   const DenseTensor& col_ptr,
+                                   const DenseTensor& edge_weight,
+                                   const DenseTensor& x,
+                                   const paddle::optional<DenseTensor>& eids,
+                                   int sample_size,
+                                   bool return_eids,
+                                   DenseTensor* out,
+                                   DenseTensor* out_count,
+                                   DenseTensor* out_eids) {
+  const T* row_data = row.data<T>();
+  const T* col_ptr_data = col_ptr.data<T>();
+  const float* weights_data = edge_weight.data<float>();
+  const T* x_data = x.data<T>();
+  const T* eids_data =
+      (eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
+  int bs = x.dims()[0];
+
+  std::vector<T> output;
+  std::vector<int> output_count;
+  std::vector<T> output_eids;
+
+  SampleNeighbors(row_data,
+                  col_ptr_data,
+                  weights_data,
+                  eids_data,
+                  x_data,
+                  &output,
+                  &output_count,
+                  &output_eids,
+                  sample_size,
+                  bs,
+                  return_eids);
+
+  if (return_eids) {
+    out_eids->Resize({static_cast<int>(output_eids.size())});
+    T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
+    std::copy(output_eids.begin(), output_eids.end(), out_eids_data);
+  }
+
+  out->Resize({static_cast<int>(output.size())});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  std::copy(output.begin(), output.end(), out_data);
+  out_count->Resize({bs});
+  int* out_count_data = dev_ctx.template Alloc<int>(out_count);
+  std::copy(output_count.begin(), output_count.end(), out_count_data);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(weighted_sample_neighbors,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::WeightedSampleNeighborsKernel,
+                   int,
+                   int64_t) {}
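The heap loop above is Efraimidis-Spirakis style weighted reservoir sampling (A-ES): each neighbor draws u ~ U(0, 1) and gets key log(u) / w, and keeping the k largest keys samples k neighbors without replacement with probability proportional to weight. A minimal NumPy sketch of the same idea (illustrative only, not part of the patch):

import numpy as np

def weighted_sample_without_replacement(neighbors, weights, k, rng):
    # A-ES: key = log(u) / w; the k largest keys win.
    u = rng.random(len(neighbors))
    keys = np.log2(u) / weights
    top_k = np.argsort(keys)[-k:]  # indices of the k largest keys
    return neighbors[top_k]

rng = np.random.default_rng(0)
picked = weighted_sample_without_replacement(
    np.array([3, 7, 0, 9]), np.array([0.1, 0.5, 2.0, 0.9]), 2, rng)

The CPU kernel keeps a size-k min-heap of keys instead of sorting, which is O(n log k) and lets the heap contents double as the output buffer.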
+ */ +#pragma once + +#ifdef PADDLE_WITH_CUDA +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { + +template< + typename KeyT, + int BLOCK_SIZE, + bool GREATER = true, + int RADIX_BITS = 8> +class BlockRadixTopKGlobalMemory { + static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), + "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); + static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); + using KeyTraits = cub::Traits; + using UnsignedBits = typename KeyTraits::UnsignedBits; + using BlockScanT = cub::BlockScan; + static constexpr int RADIX_SIZE = (1 << RADIX_BITS); + static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; + using BinBlockLoad = cub::BlockLoad; + using BinBlockStore = cub::BlockStore; + struct _TempStorage { + typename BlockScanT::TempStorage scan_storage; + union { + typename BinBlockLoad::TempStorage load_storage; + typename BinBlockStore::TempStorage store_storage; + } load_store; + union { + int shared_bins[RADIX_SIZE]; + }; + int share_target_k; + int share_bucket_id; + }; + + public: + struct TempStorage : cub::Uninitialized<_TempStorage> { + }; + __device__ __forceinline__ BlockRadixTopKGlobalMemory(TempStorage &temp_storage) + : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; + __device__ __forceinline__ void radixTopKGetThreshold(const KeyT *data, int k, int size, KeyT &topK, bool &topk_is_unique) { + assert(k < size && k > 0); + int target_k = k; + UnsignedBits key_pattern = 0; + int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; + for (; digit_pos >= 0; digit_pos -= RADIX_BITS) { + UpdateSharedBins(data, size, digit_pos, key_pattern); + InclusiveScanBins(); + UpdateTopK(digit_pos, target_k, key_pattern); + if (target_k == 0) break; + } + if (target_k == 0) { + key_pattern -= 1; + topk_is_unique = true; + } else { + topk_is_unique = false; + } + if (GREATER) key_pattern = ~key_pattern; + UnsignedBits topK_unsigned = KeyTraits::TwiddleOut(key_pattern); + topK = reinterpret_cast(topK_unsigned); + } + + private: + __device__ __forceinline__ void UpdateSharedBins(const KeyT *key, int size, int digit_pos, UnsignedBits key_pattern) { + for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { + temp_storage_.shared_bins[id] = 0; + } + cub::CTA_SYNC(); + UnsignedBits key_mask = ((UnsignedBits)(-1)) << ((UnsignedBits)(digit_pos + RADIX_BITS)); +#pragma unroll + for (int idx = tid_; idx < size; idx += BLOCK_SIZE) { + KeyT key_data = key[idx]; + UnsignedBits twiddled_data = KeyTraits::TwiddleIn(reinterpret_cast(key_data)); + if (GREATER) twiddled_data = ~twiddled_data; + UnsignedBits digit_in_radix = cub::BFE(twiddled_data, digit_pos, RADIX_BITS); + if ((twiddled_data & key_mask) == (key_pattern & key_mask)) { + atomicAdd(&temp_storage_.shared_bins[digit_in_radix], 1); + } + } + cub::CTA_SYNC(); + } + __device__ __forceinline__ void InclusiveScanBins() { + int items[SCAN_ITEMS_PER_THREAD]; + BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); + cub::CTA_SYNC(); + BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); + cub::CTA_SYNC(); + BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE); + cub::CTA_SYNC(); + } + __device__ __forceinline__ void UpdateTopK(int digit_pos, + int &target_k, + UnsignedBits &target_pattern) { + for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) { + int prev_count = (idx == 0) ? 
0 : temp_storage_.shared_bins[idx - 1]; + int cur_count = temp_storage_.shared_bins[idx]; + if (prev_count <= target_k && cur_count > target_k) { + temp_storage_.share_target_k = target_k - prev_count; + temp_storage_.share_bucket_id = idx; + } + } + cub::CTA_SYNC(); + target_k = temp_storage_.share_target_k; + int target_bucket_id = temp_storage_.share_bucket_id; + UnsignedBits key_segment = ((UnsignedBits) target_bucket_id) << ((UnsignedBits) digit_pos); + target_pattern |= key_segment; + } + _TempStorage &temp_storage_; + int tid_; +}; + +template< + typename KeyT, + int BLOCK_SIZE, + int ITEMS_PER_THREAD, + bool GREATER = true, + typename ValueT = cub::NullType, + int RADIX_BITS = 8> +class BlockRadixTopKRegister { + static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), + "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); + static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); + using KeyTraits = cub::Traits; + using UnsignedBits = typename KeyTraits::UnsignedBits; + using BlockScanT = cub::BlockScan; + static constexpr int RADIX_SIZE = (1 << RADIX_BITS); + static constexpr bool KEYS_ONLY = std::is_same::value; + static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; + using BinBlockLoad = cub::BlockLoad; + using BinBlockStore = cub::BlockStore; + using BlockExchangeKey = cub::BlockExchange; + using BlockExchangeValue = cub::BlockExchange; + + using _ExchangeKeyTempStorage = typename BlockExchangeKey::TempStorage; + using _ExchangeValueTempStorage = typename BlockExchangeValue::TempStorage; + typedef union ExchangeKeyTempStorageType { + _ExchangeKeyTempStorage key_storage; + } ExchKeyTempStorageType; + typedef union ExchangeKeyValueTempStorageType { + _ExchangeKeyTempStorage key_storage; + _ExchangeValueTempStorage value_storage; + } ExchKeyValueTempStorageType; + using _ExchangeType = typename std::conditional::type; + + struct _TempStorage { + typename BlockScanT::TempStorage scan_storage; + union { + typename BinBlockLoad::TempStorage load_storage; + typename BinBlockStore::TempStorage store_storage; + } load_store; + union { + int shared_bins[RADIX_SIZE]; + _ExchangeType exchange_storage; + }; + int share_target_k; + int share_bucket_id; + int share_prev_count; + }; + + public: + struct TempStorage : cub::Uninitialized<_TempStorage> { + }; + __device__ __forceinline__ BlockRadixTopKRegister(TempStorage &temp_storage) + : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; + __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], + const int k, const int valid_count) { + TopKGenRank(keys, k, valid_count); + int is_valid[ITEMS_PER_THREAD]; + GenValidArray(is_valid, k); + BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid); + cub::CTA_SYNC(); + } + __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], ValueT (&values)[ITEMS_PER_THREAD], + const int k, const int valid_count) { + TopKGenRank(keys, k, valid_count); + int is_valid[ITEMS_PER_THREAD]; + GenValidArray(is_valid, k); + BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid); + cub::CTA_SYNC(); + BlockExchangeValue{temp_storage_.exchange_storage.value_storage}.ScatterToStripedFlagged(values, values, ranks_, is_valid); + cub::CTA_SYNC(); + } + + private: + __device__ __forceinline__ void TopKGenRank(KeyT (&keys)[ITEMS_PER_THREAD], const int k, const int 
valid_count) { + assert(k <= BLOCK_SIZE * ITEMS_PER_THREAD); + assert(k <= valid_count); + if (k == valid_count) return; + UnsignedBits(&unsigned_keys)[ITEMS_PER_THREAD] = reinterpret_cast(keys); + search_mask_ = 0; + top_k_mask_ = 0; + +#pragma unroll + for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + int idx = KEY * BLOCK_SIZE + tid_; + unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]); + if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; + if (idx < valid_count) search_mask_ |= (1U << KEY); + } + + int target_k = k; + int prefix_k = 0; + + for (int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) { + UpdateSharedBins(unsigned_keys, digit_pos, prefix_k); + InclusiveScanBins(); + UpdateTopK(unsigned_keys, digit_pos, target_k, prefix_k, digit_pos == 0); + if (target_k == 0) break; + } + +#pragma unroll + for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; + unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]); + } + } + __device__ __forceinline__ void GenValidArray(int (&is_valid)[ITEMS_PER_THREAD], int k) { +#pragma unroll + for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + if ((top_k_mask_ & (1U << KEY)) && ranks_[KEY] < k) { + is_valid[KEY] = 1; + } else { + is_valid[KEY] = 0; + } + } + } + __device__ __forceinline__ void UpdateSharedBins(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], + int digit_pos, int prefix_k) { + for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { + temp_storage_.shared_bins[id] = 0; + } + cub::CTA_SYNC(); +//#define USE_MATCH +#ifdef USE_MATCH + int lane_mask = cub::LaneMaskLt(); +#pragma unroll + for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + bool is_search = search_mask_ & (1U << KEY); + int bucket_idx = -1; + if (is_search) { + UnsignedBits digit_in_radix = cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); + bucket_idx = (int) digit_in_radix; + } + int warp_match_mask = __match_any_sync(0xffffffff, bucket_idx); + int same_count = __popc(warp_match_mask); + int idx_in_same_bucket = __popc(warp_match_mask & lane_mask); + int same_bucket_root_lane = __ffs(warp_match_mask) - 1; + int same_bucket_start_idx; + if (idx_in_same_bucket == 0 && is_search) { + same_bucket_start_idx = atomicAdd(&temp_storage_.shared_bins[bucket_idx], same_count); + } + same_bucket_start_idx = __shfl_sync(0xffffffff, same_bucket_start_idx, same_bucket_root_lane, 32); + if (is_search) { + ranks_[KEY] = same_bucket_start_idx + idx_in_same_bucket + prefix_k; + } + } +#else +#pragma unroll + for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + bool is_search = search_mask_ & (1U << KEY); + int bucket_idx = -1; + if (is_search) { + UnsignedBits digit_in_radix = cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); + bucket_idx = (int) digit_in_radix; + ranks_[KEY] = atomicAdd(&temp_storage_.shared_bins[bucket_idx], 1) + prefix_k; + } + } +#endif + cub::CTA_SYNC(); + } + __device__ __forceinline__ void InclusiveScanBins() { + int items[SCAN_ITEMS_PER_THREAD]; + BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); + cub::CTA_SYNC(); + BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); + cub::CTA_SYNC(); + BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE); + cub::CTA_SYNC(); + } + __device__ __forceinline__ void UpdateTopK(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], + int digit_pos, + int 
&target_k, + int &prefix_k, + bool mark_equal) { + for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) { + int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1]; + int cur_count = temp_storage_.shared_bins[idx]; + if (prev_count <= target_k && cur_count > target_k) { + temp_storage_.share_target_k = target_k - prev_count; + temp_storage_.share_bucket_id = idx; + temp_storage_.share_prev_count = prev_count; + } + } + cub::CTA_SYNC(); + target_k = temp_storage_.share_target_k; + prefix_k += temp_storage_.share_prev_count; + int target_bucket_id = temp_storage_.share_bucket_id; +#pragma unroll + for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { + if (search_mask_ & (1U << KEY)) { + UnsignedBits digit_in_radix = cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); + if (digit_in_radix < target_bucket_id) { + top_k_mask_ |= (1U << KEY); + search_mask_ &= ~(1U << KEY); + } else if (digit_in_radix > target_bucket_id) { + search_mask_ &= ~(1U << KEY); + } else { + if (mark_equal) top_k_mask_ |= (1U << KEY); + } + if (digit_in_radix <= target_bucket_id) { + int prev_count = (digit_in_radix == 0) ? 0 : temp_storage_.shared_bins[digit_in_radix - 1]; + ranks_[KEY] += prev_count; + } + } + } + cub::CTA_SYNC(); + } + + _TempStorage &temp_storage_; + int tid_; + int ranks_[ITEMS_PER_THREAD]; + unsigned int search_mask_; + unsigned int top_k_mask_; +}; + +}; // end namespace framework +}; // end namespace paddle +#endif diff --git a/paddle/phi/kernels/funcs/random.cuh b/paddle/phi/kernels/funcs/random.cuh new file mode 100644 index 0000000000000000000000000000000000000000..502b7e85ee97f3097ac736c4c56b0624ee423fbc --- /dev/null +++ b/paddle/phi/kernels/funcs/random.cuh @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
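Both classes above locate the k-th largest key by MSB-first radix refinement: histogram the current digit over the surviving candidates, scan the histogram, pick the bucket whose cumulative count straddles rank k, then recurse into that bucket on the next digit. A small Python sketch of the threshold search on unsigned integers (illustrative only; 0-indexed descending rank, single-threaded, no twiddling for floats):

def radix_kth_largest(keys, k, key_bits=32, radix_bits=8):
    # Invert so "largest" becomes "smallest pattern", mirroring the
    # GREATER=true inversion in the CUDA classes.
    keys = [(~x) & ((1 << key_bits) - 1) for x in keys]
    pattern, target_k = 0, k
    for pos in range(key_bits - radix_bits, -1, -radix_bits):
        mask = ~((1 << (pos + radix_bits)) - 1)
        bins = [0] * (1 << radix_bits)
        for x in keys:
            if (x & mask) == (pattern & mask):  # still a candidate
                bins[(x >> pos) & ((1 << radix_bits) - 1)] += 1
        cum = 0
        for digit, count in enumerate(bins):  # pick the straddling bucket
            if cum + count > target_k:
                pattern |= digit << pos
                target_k -= cum
                break
            cum += count
    return (~pattern) & ((1 << key_bits) - 1)

assert radix_kth_largest([5, 1, 7, 3], k=1, key_bits=4, radix_bits=4) == 5

The GlobalMemory variant only returns this threshold (candidates stream from global memory each pass); the Register variant additionally tracks per-item ranks so it can scatter the winning keys and values into striped order within the block.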
+ */ +#pragma once + +#ifdef __NVCC__ +#include // NOLINT +#endif + +class RandomNumGen { + public: + __host__ __device__ __forceinline__ RandomNumGen(int gid, unsigned long long seed) { + next_random = seed + gid; + next_random ^= next_random >> 33U; + next_random *= 0xff51afd7ed558ccdUL; + next_random ^= next_random >> 33U; + next_random *= 0xc4ceb9fe1a85ec53UL; + next_random ^= next_random >> 33U; + } + __host__ __device__ __forceinline__ ~RandomNumGen() = default; + __host__ __device__ __forceinline__ void SetSeed(int seed) { + next_random = seed; + NextValue(); + } + __host__ __device__ __forceinline__ unsigned long long SaveState() const { + return next_random; + } + __host__ __device__ __forceinline__ void LoadState(unsigned long long state) { + next_random = state; + } + __host__ __device__ __forceinline__ int Random() { + int ret_value = (int) (next_random & 0x7fffffffULL); + NextValue(); + return ret_value; + } + __host__ __device__ __forceinline__ int RandomMod(int mod) { + return Random() % mod; + } + __host__ __device__ __forceinline__ int64_t Random64() { + int64_t ret_value = (next_random & 0x7FFFFFFFFFFFFFFFLL); + NextValue(); + return ret_value; + } + __host__ __device__ __forceinline__ int64_t RandomMod64(int64_t mod) { + return Random64() % mod; + } + __host__ __device__ __forceinline__ float RandomUniformFloat(float max = 1.0f, float min = 0.0f) { + int value = (int) (next_random & 0xffffff); + auto ret_value = (float) value; + ret_value /= 0xffffffL; + ret_value *= (max - min); + ret_value += min; + NextValue(); + return ret_value; + } + __host__ __device__ __forceinline__ bool RandomBool(float true_prob) { + float value = RandomUniformFloat(); + return value <= true_prob; + } + __host__ __device__ __forceinline__ void NextValue() { + //next_random = next_random * (unsigned long long)0xc4ceb9fe1a85ec53UL + generator_id; + //next_random = next_random * (unsigned long long)25214903917ULL + 11; + next_random = next_random * (unsigned long long) 13173779397737131ULL + 1023456798976543201ULL; + } + + private: + unsigned long long next_random = 1; +}; diff --git a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d4e0ca632e04de681d6a5e6b1513ebfe424a2e62 --- /dev/null +++ b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu @@ -0,0 +1,535 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d4e0ca632e04de681d6a5e6b1513ebfe424a2e62
--- /dev/null
+++ b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu
@@ -0,0 +1,535 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <thrust/copy.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/reduce.h>
+#include <thrust/scan.h>
+#include <thrust/transform.h>
+
+#ifdef PADDLE_WITH_CUDA
+#include <cuda_runtime.h>
+#include <curand_kernel.h>
+#include "cub/cub.cuh"
+#endif
+
+#include "math.h"  // NOLINT
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/block_radix_topk.cuh"
+#include "paddle/phi/kernels/funcs/random.cuh"
+#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
+#define SAMPLE_SIZE_THRESHOLD 1024
+
+namespace phi {
+
+#ifdef PADDLE_WITH_CUDA
+__device__ __forceinline__ float GenKeyFromWeight(
+    const float weight,
+    RandomNumGen& rng) {  // NOLINT
+  rng.NextValue();
+  float u = -rng.RandomUniformFloat(1.0f, 0.5f);
+  long long random_num2 = 0;  // NOLINT
+  int seed_count = -1;
+  do {
+    random_num2 = rng.Random64();
+    seed_count++;
+  } while (!random_num2);
+  int one_bit = __clzll(random_num2) + seed_count * 64;
+  u *= exp2f(-one_bit);
+  float logk = (log1pf(u) / logf(2.0)) * (1 / weight);
+  return logk;
+}
+#endif
+
+template <typename T, bool NeedNeighbor = false>
+__global__ void GetSampleCountAndNeighborCountKernel(const T* col_ptr,
+                                                     const T* input_nodes,
+                                                     int* actual_size,
+                                                     int* neighbor_count,
+                                                     int sample_size,
+                                                     int n) {
+  int i = threadIdx.x + blockIdx.x * blockDim.x;
+  if (i >= n) return;
+  T nid = input_nodes[i];
+  int neighbor_size = static_cast<int>(col_ptr[nid + 1] - col_ptr[nid]);
+  // sample_size < 0 means sample all.
+  int k = neighbor_size;
+  if (sample_size >= 0) {
+    k = min(neighbor_size, sample_size);
+  }
+  actual_size[i] = k;
+  if (NeedNeighbor) {
+    neighbor_count[i] = (neighbor_size <= sample_size) ? 0 : neighbor_size;
+  }
+}
+
+#ifdef PADDLE_WITH_CUDA
+template <typename T, unsigned int BLOCK_SIZE>
+__launch_bounds__(BLOCK_SIZE) __global__
+    void WeightedSampleLargeKernel(T* sample_output,
+                                   const int* sample_offset,
+                                   const int* target_neighbor_offset,
+                                   float* weight_keys_buf,
+                                   const T* input_nodes,
+                                   int input_node_count,
+                                   const T* in_rows,
+                                   const T* col_ptr,
+                                   const float* edge_weight,
+                                   const T* eids,
+                                   int max_sample_count,
+                                   unsigned long long random_seed,  // NOLINT
+                                   T* out_eids,
+                                   bool return_eids) {
+  int i = blockIdx.x;
+  if (i >= input_node_count) return;
+  int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
+  T nid = input_nodes[i];
+  T start = col_ptr[nid];
+  T end = col_ptr[nid + 1];
+  int neighbor_count = static_cast<int>(end - start);
+
+  float* weight_keys_local_buff = weight_keys_buf + target_neighbor_offset[i];
+  int offset = sample_offset[i];
+  if (neighbor_count <= max_sample_count) {
+    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+      sample_output[offset + j] = in_rows[start + j];
+      if (return_eids) {
+        out_eids[offset + j] = eids[start + j];
+      }
+    }
+  } else {
+    RandomNumGen rng(gidx, random_seed);
+    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+      float thread_weight = edge_weight[start + j];
+      weight_keys_local_buff[j] =
+          static_cast<float>(GenKeyFromWeight(thread_weight, rng));
+    }
+    __syncthreads();
+
+    float topk_val;
+    bool topk_is_unique;
+
+    using BlockRadixSelectT =
+        paddle::framework::BlockRadixTopKGlobalMemory<float, BLOCK_SIZE, true>;
+    __shared__ typename BlockRadixSelectT::TempStorage share_storage;
+
+    BlockRadixSelectT{share_storage}.radixTopKGetThreshold(
+        weight_keys_local_buff,
+        max_sample_count,
+        neighbor_count,
+        topk_val,
+        topk_is_unique);
+    __shared__ int cnt;
+
+    if (threadIdx.x == 0) {
+      cnt = 0;
+    }
+    __syncthreads();
+
+    // We use atomicAdd operations instead of a binary scan to compute the
+    // write index, since we do not need to keep the relative order of
+    // elements.
+
+    if (topk_is_unique) {
+      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+        float key = weight_keys_local_buff[j];
+        bool has_topk = (key >= topk_val);
+
+        if (has_topk) {
+          int write_index = atomicAdd(&cnt, 1);
+          sample_output[offset + write_index] = in_rows[start + j];
+          if (return_eids) {
+            out_eids[offset + write_index] = eids[start + j];
+          }
+        }
+      }
+    } else {
+      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+        float key = weight_keys_local_buff[j];
+        bool has_topk = (key > topk_val);
+
+        if (has_topk) {
+          int write_index = atomicAdd(&cnt, 1);
+          sample_output[offset + write_index] = in_rows[start + j];
+          if (return_eids) {
+            out_eids[offset + write_index] = eids[start + j];
+          }
+        }
+      }
+      __syncthreads();
+
+      for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+        float key = weight_keys_local_buff[j];
+        bool has_topk = (key == topk_val);
+        if (has_topk) {
+          int write_index = atomicAdd(&cnt, 1);
+          if (write_index >= max_sample_count) {
+            break;
+          }
+          sample_output[offset + write_index] = in_rows[start + j];
+          if (return_eids) {
+            out_eids[offset + write_index] = eids[start + j];
+          }
+        }
+      }
+    }
+  }
+}
+#endif
+
+template <typename T>
+__global__ void SampleAllKernel(T* sample_output,
+                                const int* sample_offset,
+                                const T* input_nodes,
+                                int input_node_count,
+                                const T* in_rows,
+                                const T* col_ptr,
+                                const T* eids,
+                                T* out_eids,
+                                bool return_eids) {
+  int i = blockIdx.x;
+  if (i >= input_node_count) return;
+  T nid = input_nodes[i];
+  T start = col_ptr[nid];
+  T end = col_ptr[nid + 1];
+  int neighbor_count = static_cast<int>(end - start);
+  if (neighbor_count <= 0) return;
+  int offset = sample_offset[i];
+  for (int j = threadIdx.x; j < neighbor_count; j += blockDim.x) {
+    sample_output[offset + j] = in_rows[start + j];
+    if (return_eids) {
+      out_eids[offset + j] = eids[start + j];
+    }
+  }
+}
+
+// A-RES algorithm
+#ifdef PADDLE_WITH_CUDA
+template <typename T, unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD>
+__launch_bounds__(BLOCK_SIZE) __global__
+    void WeightedSampleKernel(T* sample_output,
+                              const int* sample_offset,
+                              const T* input_nodes,
+                              int input_node_count,
+                              const T* in_rows,
+                              const T* col_ptr,
+                              const float* edge_weight,
+                              const T* eids,
+                              int max_sample_count,
+                              unsigned long long random_seed,  // NOLINT
+                              T* out_eids,
+                              bool return_eids) {
+  int i = blockIdx.x;
+  if (i >= input_node_count) return;
+  int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
+  T nid = input_nodes[i];
+  T start = col_ptr[nid];
+  T end = col_ptr[nid + 1];
+  int neighbor_count = static_cast<int>(end - start);
+  int offset = sample_offset[i];
+
+  if (neighbor_count <= max_sample_count) {
+    for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
+      sample_output[offset + j] = in_rows[start + j];
+      if (return_eids) {
+        out_eids[offset + j] = eids[start + j];
+      }
+    }
+  } else {
+    RandomNumGen rng(gidx, random_seed);
+    float weight_keys[ITEMS_PER_THREAD];
+    int neighbor_idxs[ITEMS_PER_THREAD];
+    using BlockRadixTopKT = paddle::framework::
+        BlockRadixTopKRegister<float, BLOCK_SIZE, ITEMS_PER_THREAD, true, int>;
+    __shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage;
+
+    const int tx = threadIdx.x;
+#pragma unroll
+    for (int j = 0; j < ITEMS_PER_THREAD; j++) {
+      int idx = BLOCK_SIZE * j + tx;
+      if (idx < neighbor_count) {
+        float thread_weight = edge_weight[start + idx];
+        weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
+        neighbor_idxs[j] = idx;
+      }
+    }
+    const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD))
+                                ? neighbor_count
+                                : (BLOCK_SIZE * ITEMS_PER_THREAD);
+    BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
+        weight_keys, neighbor_idxs, max_sample_count, valid_count);
+    __syncthreads();
+    const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count;
+
+    for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE;
+         idx_offset < neighbor_count;
+         idx_offset += stride) {
+#pragma unroll
+      for (int j = 0; j < ITEMS_PER_THREAD; j++) {
+        int local_idx = BLOCK_SIZE * j + tx - max_sample_count;
+        int target_idx = idx_offset + local_idx;
+        if (local_idx >= 0 && target_idx < neighbor_count) {
+          float thread_weight = edge_weight[start + target_idx];
+          weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
+          neighbor_idxs[j] = target_idx;
+        }
+      }
+      const int iter_valid_count =
+          ((neighbor_count - idx_offset) >= stride)
+              ? (BLOCK_SIZE * ITEMS_PER_THREAD)
+              : (max_sample_count + neighbor_count - idx_offset);
+      BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
+          weight_keys, neighbor_idxs, max_sample_count, iter_valid_count);
+      __syncthreads();
+    }
+#pragma unroll
+    for (int j = 0; j < ITEMS_PER_THREAD; j++) {
+      int idx = j * BLOCK_SIZE + tx;
+      if (idx < max_sample_count) {
+        sample_output[offset + idx] = in_rows[start + neighbor_idxs[j]];
+        if (return_eids) {
+          out_eids[offset + idx] = eids[start + neighbor_idxs[j]];
+        }
+      }
+    }
+  }
+}
+#endif
+
+template <typename T, typename Context>
+void WeightedSampleNeighborsKernel(const Context& dev_ctx,
+                                   const DenseTensor& row,
+                                   const DenseTensor& col_ptr,
+                                   const DenseTensor& edge_weight,
+                                   const DenseTensor& x,
+                                   const paddle::optional<DenseTensor>& eids,
+                                   int sample_size,
+                                   bool return_eids,
+                                   DenseTensor* out,
+                                   DenseTensor* out_count,
+                                   DenseTensor* out_eids) {
+  auto* row_data = row.data<T>();
+  auto* col_ptr_data = col_ptr.data<T>();
+  auto* weights_data = edge_weight.data<float>();
+  auto* x_data = x.data<T>();
+  auto* eids_data =
+      (eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
+  int bs = x.dims()[0];
+
+  thread_local std::random_device rd;
+  thread_local std::mt19937 gen(rd());
+  thread_local std::uniform_int_distribution<unsigned long long>  // NOLINT
+      distrib;
+  unsigned long long random_seed = distrib(gen);  // NOLINT
+  const bool need_neighbor_count = sample_size > SAMPLE_SIZE_THRESHOLD;
+
+  out_count->Resize({bs});
+  int* out_count_data =
+      dev_ctx.template Alloc<int>(out_count);  // finally copy sample_count
+  int* neighbor_count_ptr = nullptr;
+  std::shared_ptr<phi::Allocation> neighbor_count;
+  auto sample_count = phi::memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      (bs + 1) * sizeof(int),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  int* sample_count_ptr = reinterpret_cast<int*>(sample_count->ptr());
+
+  int grid_size = (bs + 127) / 128;
+  if (need_neighbor_count) {
+    neighbor_count = phi::memory_utils::AllocShared(
+        dev_ctx.GetPlace(),
+        (bs + 1) * sizeof(int),
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+    neighbor_count_ptr = reinterpret_cast<int*>(neighbor_count->ptr());
+    GetSampleCountAndNeighborCountKernel<T, true>
+        <<<grid_size, 128, 0, dev_ctx.stream()>>>(col_ptr_data,
+                                                  x_data,
+                                                  sample_count_ptr,
+                                                  neighbor_count_ptr,
+                                                  sample_size,
+                                                  bs);
+  } else {
+    GetSampleCountAndNeighborCountKernel<T, false>
+        <<<grid_size, 128, 0, dev_ctx.stream()>>>(
+            col_ptr_data, x_data, sample_count_ptr, nullptr, sample_size, bs);
+  }
+
+  auto sample_offset = phi::memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      (bs + 1) * sizeof(int),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  int* sample_offset_ptr = reinterpret_cast<int*>(sample_offset->ptr());
+
+#ifdef PADDLE_WITH_CUDA
+  const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
+#else
+  const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
+#endif
+  thrust::exclusive_scan(exec_policy,
+                         sample_count_ptr,
+                         sample_count_ptr + bs + 1,
+                         sample_offset_ptr);
+  int total_sample_size = 0;
+#ifdef PADDLE_WITH_CUDA
+  cudaMemcpyAsync(&total_sample_size,
+                  sample_offset_ptr + bs,
+                  sizeof(int),
+                  cudaMemcpyDeviceToHost,
+                  dev_ctx.stream());
+  cudaMemcpyAsync(out_count_data,
+                  sample_count_ptr,
+                  sizeof(int) * bs,
+                  cudaMemcpyDeviceToDevice,
+                  dev_ctx.stream());
+  cudaStreamSynchronize(dev_ctx.stream());
+#else
+  hipMemcpyAsync(&total_sample_size,
+                 sample_offset_ptr + bs,
+                 sizeof(int),
+                 hipMemcpyDeviceToHost,
+                 dev_ctx.stream());
+  hipMemcpyAsync(out_count_data,
+                 sample_count_ptr,
+                 sizeof(int) * bs,
+                 hipMemcpyDeviceToDevice,
+                 dev_ctx.stream());
+  hipStreamSynchronize(dev_ctx.stream());
+#endif
+
+  out->Resize({static_cast<int>(total_sample_size)});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  T* out_eids_data = nullptr;
+  if (return_eids) {
+    out_eids->Resize({static_cast<int>(total_sample_size)});
+    out_eids_data = dev_ctx.template Alloc<T>(out_eids);
+  }
+
+  // large sample size
+#ifdef PADDLE_WITH_CUDA
+  if (sample_size > SAMPLE_SIZE_THRESHOLD) {
+    thrust::exclusive_scan(exec_policy,
+                           neighbor_count_ptr,
+                           neighbor_count_ptr + bs + 1,
+                           neighbor_count_ptr);
+    int* neighbor_offset = neighbor_count_ptr;
+    int target_neighbor_counts;
+    cudaMemcpyAsync(&target_neighbor_counts,
+                    neighbor_offset + bs,
+                    sizeof(int),
+                    cudaMemcpyDeviceToHost,
+                    dev_ctx.stream());
+    cudaStreamSynchronize(dev_ctx.stream());
+
+    auto tmh_weights = phi::memory_utils::Alloc(
+        dev_ctx.GetPlace(),
+        target_neighbor_counts * sizeof(float),
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+    float* target_weights_keys_buf_ptr =
+        reinterpret_cast<float*>(tmh_weights->ptr());
+    constexpr int BLOCK_SIZE = 256;
+    WeightedSampleLargeKernel<T, BLOCK_SIZE>
+        <<<bs, BLOCK_SIZE, 0, dev_ctx.stream()>>>(out_data,
+                                                  sample_offset_ptr,
+                                                  neighbor_offset,
+                                                  target_weights_keys_buf_ptr,
+                                                  x_data,
+                                                  bs,
+                                                  row_data,
+                                                  col_ptr_data,
+                                                  weights_data,
+                                                  eids_data,
+                                                  sample_size,
+                                                  random_seed,
+                                                  out_eids_data,
+                                                  return_eids);
+    cudaStreamSynchronize(dev_ctx.stream());
+  } else if (sample_size <= 0) {
+    SampleAllKernel<T><<<bs, 64, 0, dev_ctx.stream()>>>(out_data,
+                                                        sample_offset_ptr,
+                                                        x_data,
+                                                        bs,
+                                                        row_data,
+                                                        col_ptr_data,
+                                                        eids_data,
+                                                        out_eids_data,
+                                                        return_eids);
+    cudaStreamSynchronize(dev_ctx.stream());
+  } else {  // 0 < sample_size <= SAMPLE_SIZE_THRESHOLD
+    using WeightedSampleFuncType = void (*)(T*,
+                                            const int*,
+                                            const T*,
+                                            int,
+                                            const T*,
+                                            const T*,
+                                            const float*,
+                                            const T*,
+                                            int,
+                                            unsigned long long,  // NOLINT
+                                            T*,
+                                            bool);
+    static const WeightedSampleFuncType func_array[7] = {
+        WeightedSampleKernel<T, 128, 2>,
+        WeightedSampleKernel<T, 128, 3>,
+        WeightedSampleKernel<T, 256, 2>,
+        WeightedSampleKernel<T, 256, 3>,
+        WeightedSampleKernel<T, 256, 4>,
+        WeightedSampleKernel<T, 256, 5>,
+        WeightedSampleKernel<T, 512, 3>,
+    };
+    const int block_sizes[7] = {128, 128, 256, 256, 256, 256, 512};
+    auto choose_func_idx = [](int sample_size) {
+      if (sample_size <= 128) {
+        return 0;
+      }
+      if (sample_size <= 384) {
+        return (sample_size - 129) / 64 + 1;
+      }
+      if (sample_size <= 512) {
+        return 5;
+      } else {
+        return 6;
+      }
+    };
+    int func_idx = choose_func_idx(sample_size);
+    int block_size = block_sizes[func_idx];
+    func_array[func_idx]<<<bs, block_size, 0, dev_ctx.stream()>>>(
+        out_data,
+        sample_offset_ptr,
+        x_data,
+        bs,
+        row_data,
+        col_ptr_data,
+        weights_data,
+        eids_data,
+        sample_size,
+        random_seed,
+        out_eids_data,
+        return_eids);
+    cudaStreamSynchronize(dev_ctx.stream());
+  }
+#endif
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(weighted_sample_neighbors,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::WeightedSampleNeighborsKernel,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/weighted_sample_neighbors_kernel.h b/paddle/phi/kernels/weighted_sample_neighbors_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a0402f9fc4947117d813c23e30aba967632cb47
--- /dev/null
+++ b/paddle/phi/kernels/weighted_sample_neighbors_kernel.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void WeightedSampleNeighborsKernel(
+    const Context& dev_ctx,
+    const DenseTensor& row,
+    const DenseTensor& col_ptr,
+    const DenseTensor& edge_weight,
+    const DenseTensor& x,
+    const paddle::optional<DenseTensor>& eids,
+    int sample_size,
+    bool return_eids,
+    DenseTensor* out,
+    DenseTensor* out_count,
+    DenseTensor* out_eids);
+
+}  // namespace phi
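The launcher sizes its outputs in two passes: one kernel writes per-node sample counts, an exclusive scan turns the counts into offsets (the last offset is the total output length), and the sampling kernels then scatter into a flat buffer. The same bookkeeping in NumPy (illustrative only):

import numpy as np

out_count = np.array([2, 0, 5, 3])  # per-node sample counts
offsets = np.concatenate([[0], np.cumsum(out_count)])  # exclusive scan
total = offsets[-1]  # flat output length: 10
flat = np.empty(total, dtype=np.int64)

# Node i owns flat[offsets[i]:offsets[i + 1]], so each block can write
# its slice independently, without synchronizing on a shared cursor.

This is also why out/out_count/out_eids are declared with dim -1 in InferMeta: their true lengths are only known after this scan.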
diff --git a/python/paddle/fluid/tests/unittests/test_weighted_sample_neighbors.py b/python/paddle/fluid/tests/unittests/test_weighted_sample_neighbors.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be782b9adf297573318133a00a138853b444359
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_weighted_sample_neighbors.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestWeightedSampleNeighbors(unittest.TestCase):
+    def setUp(self):
+        num_nodes = 20
+        edges = np.random.randint(num_nodes, size=(100, 2))
+        edges = np.unique(edges, axis=0)
+        self.edges_id = np.arange(0, len(edges)).astype("int64")
+        sorted_edges = edges[np.argsort(edges[:, 1])]
+
+        # Calculate the cumulative counts of dst indices, which form colptr.
+        dst_count = np.zeros(num_nodes)
+        dst_src_dict = {}
+        for dst in range(0, num_nodes):
+            true_index = sorted_edges[:, 1] == dst
+            dst_count[dst] = np.sum(true_index)
+            dst_src_dict[dst] = sorted_edges[:, 0][true_index]
+        dst_count = dst_count.astype("int64")
+        colptr = np.cumsum(dst_count)
+        colptr = np.insert(colptr, 0, 0)
+
+        self.row = sorted_edges[:, 0].astype("int64")
+        self.colptr = colptr.astype("int64")
+        self.nodes = np.unique(np.random.randint(num_nodes, size=5)).astype(
+            "int64"
+        )
+        self.weight = np.ones(self.row.shape[0]).astype("float32")
+        self.sample_size = 5
+        self.dst_src_dict = dst_src_dict
+
+    def test_sample_result(self):
+        paddle.disable_static()
+        row = paddle.to_tensor(self.row)
+        colptr = paddle.to_tensor(self.colptr)
+        nodes = paddle.to_tensor(self.nodes)
+        weight = paddle.to_tensor(self.weight)
+
+        out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(
+            row, colptr, weight, nodes, sample_size=self.sample_size
+        )
+        out_count_cumsum = paddle.cumsum(out_count)
+        for i in range(len(out_count)):
+            if i == 0:
+                neighbors = out_neighbors[0 : out_count_cumsum[i]]
+            else:
+                neighbors = out_neighbors[
+                    out_count_cumsum[i - 1] : out_count_cumsum[i]
+                ]
+            # Ensure the correct sample size.
+            self.assertTrue(
+                out_count[i] == self.sample_size
+                or out_count[i] == len(self.dst_src_dict[self.nodes[i]])
+            )
+            # Ensure no repeated sample neighbors.
+            self.assertTrue(
+                neighbors.shape[0] == paddle.unique(neighbors).shape[0]
+            )
+            # Ensure the sampled neighbors are real neighbors.
+            in_neighbors = np.isin(
+                neighbors.numpy(), self.dst_src_dict[self.nodes[i]]
+            )
+            self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
+
+    def test_sample_result_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            row = paddle.static.data(
+                name="row", shape=self.row.shape, dtype=self.row.dtype
+            )
+            colptr = paddle.static.data(
+                name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
+            )
+            weight = paddle.static.data(
+                name="weight", shape=self.weight.shape, dtype=self.weight.dtype
+            )
+            nodes = paddle.static.data(
+                name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
+            )
+
+            (
+                out_neighbors,
+                out_count,
+            ) = paddle.geometric.weighted_sample_neighbors(
+                row, colptr, weight, nodes, sample_size=self.sample_size
+            )
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            ret = exe.run(
+                feed={
+                    'row': self.row,
+                    'colptr': self.colptr,
+                    'weight': self.weight,
+                    'nodes': self.nodes,
+                },
+                fetch_list=[out_neighbors, out_count],
+            )
+            out_neighbors, out_count = ret
+            out_count_cumsum = np.cumsum(out_count)
+            out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1]
+            for neighbors, node, count in zip(
+                out_neighbors, self.nodes, out_count
+            ):
+                self.assertTrue(
+                    count == self.sample_size
+                    or count == len(self.dst_src_dict[node])
+                )
+                self.assertTrue(
+                    neighbors.shape[0] == np.unique(neighbors).shape[0]
+                )
+                in_neighbors = np.isin(neighbors, self.dst_src_dict[node])
+                self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
+
+    def test_raise_errors(self):
+        paddle.disable_static()
+        row = paddle.to_tensor(self.row)
+        colptr = paddle.to_tensor(self.colptr)
+        weight = paddle.to_tensor(self.weight)
+        nodes = paddle.to_tensor(self.nodes)
+
+        def check_eid_error():
+            paddle.geometric.weighted_sample_neighbors(
+                row,
+                colptr,
+                weight,
+                nodes,
+                sample_size=self.sample_size,
+                return_eids=True,
+            )
+
+        self.assertRaises(ValueError, check_eid_error)
+
+    def test_sample_result_with_eids(self):
+        paddle.disable_static()
+        row = paddle.to_tensor(self.row)
+        colptr = paddle.to_tensor(self.colptr)
+        weight = paddle.to_tensor(self.weight)
+        nodes = paddle.to_tensor(self.nodes)
+        eids = paddle.to_tensor(self.edges_id)
+
+        (
+            out_neighbors,
+            out_count,
+            out_eids,
+        ) = paddle.geometric.weighted_sample_neighbors(
+            row,
+            colptr,
+            weight,
+            nodes,
+            eids=eids,
+            sample_size=self.sample_size,
+            return_eids=True,
+        )
+
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            row = paddle.static.data(
+                name="row", shape=self.row.shape, dtype=self.row.dtype
+            )
+            colptr = paddle.static.data(
+                name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
+            )
+            weight = paddle.static.data(
+                name="weight", shape=self.weight.shape, dtype=self.weight.dtype
+            )
+            nodes = paddle.static.data(
+                name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
+            )
+            eids = paddle.static.data(
+                name="eids", shape=self.edges_id.shape, dtype=self.edges_id.dtype
+            )
+
+            (
+                out_neighbors,
+                out_count,
+                out_eids,
+            ) = paddle.geometric.weighted_sample_neighbors(
+                row,
+                colptr,
+                weight,
+                nodes,
+                sample_size=self.sample_size,
+                eids=eids,
+                return_eids=True,
+            )
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            ret = exe.run(
+                feed={
+                    'row': self.row,
+                    'colptr': self.colptr,
+                    'weight': self.weight,
+                    'nodes': self.nodes,
+                    'eids': self.edges_id,
+                },
+                fetch_list=[out_neighbors, out_count, out_eids],
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
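The unit test checks validity (count, uniqueness, membership) but not the weighting itself. A quick empirical check, not part of the patch, could repeatedly sample a single node and compare selection frequencies against the weights:

import numpy as np
import paddle

# Hypothetical toy graph: node 0 has neighbors [3, 7] with weights [1.0, 3.0].
row = paddle.to_tensor([3, 7], dtype="int64")
colptr = paddle.to_tensor([0, 2], dtype="int64")
weight = paddle.to_tensor([1.0, 3.0], dtype="float32")
nodes = paddle.to_tensor([0], dtype="int64")

hits = {3: 0, 7: 0}
for _ in range(2000):
    out, _ = paddle.geometric.weighted_sample_neighbors(
        row, colptr, weight, nodes, sample_size=1
    )
    hits[int(out[0])] += 1
# Expect neighbor 7 to be drawn roughly 3x as often as neighbor 3.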
diff --git a/python/paddle/geometric/__init__.py b/python/paddle/geometric/__init__.py
index 9618bc57a203ea4c724b9ee6c406ac41c7f2ab65..6c132a529bc37ff0b54a11d993a84b0284fbbf9f 100644
--- a/python/paddle/geometric/__init__.py
+++ b/python/paddle/geometric/__init__.py
@@ -22,6 +22,7 @@ from .math import segment_max  # noqa: F401
 from .reindex import reindex_graph  # noqa: F401
 from .reindex import reindex_heter_graph  # noqa: F401
 from .sampling import sample_neighbors  # noqa: F401
+from .sampling import weighted_sample_neighbors  # noqa: F401
 
 __all__ = [
     'send_u_recv',
@@ -34,4 +35,5 @@ __all__ = [
     'reindex_graph',
     'reindex_heter_graph',
     'sample_neighbors',
+    'weighted_sample_neighbors',
 ]
diff --git a/python/paddle/geometric/sampling/__init__.py b/python/paddle/geometric/sampling/__init__.py
index 2e5b24fdd60b7fbc366f23a048b0645e2df9284c..ee7bacfc9047f607f798e2722de380a6ad4b99d4 100644
--- a/python/paddle/geometric/sampling/__init__.py
+++ b/python/paddle/geometric/sampling/__init__.py
@@ -13,5 +13,6 @@
 # limitations under the License.
 
 from .neighbors import sample_neighbors  # noqa: F401
+from .neighbors import weighted_sample_neighbors  # noqa: F401
 
 __all__ = []
diff --git a/python/paddle/geometric/sampling/neighbors.py b/python/paddle/geometric/sampling/neighbors.py
index 093fd39617af3e2d9336cfa3ebe44c3aaa598f94..c8d907c078bad147e44aa7b89217a8aa69efa8a3 100644
--- a/python/paddle/geometric/sampling/neighbors.py
+++ b/python/paddle/geometric/sampling/neighbors.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _legacy_C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.fluid.data_feeder import check_variable_and_dtype
-from paddle.fluid.framework import _non_static_mode
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
 from paddle.fluid.layer_helper import LayerHelper
 
 __all__ = []
@@ -170,3 +170,151 @@ def sample_neighbors(
     if return_eids:
         return out_neighbors, out_count, out_eids
     return out_neighbors, out_count
+
+
+def weighted_sample_neighbors(
+    row,
+    colptr,
+    edge_weight,
+    input_nodes,
+    sample_size=-1,
+    eids=None,
+    return_eids=False,
+    name=None,
+):
+    """
+    Graph Weighted Sample Neighbors API.
+
+    This API is mainly used in the Graph Learning domain, and its main purpose is
+    to provide a high-performance graph weighted-sampling method. We take the
+    CSC (Compressed Sparse Column) format of the input graph edges as `row` and
+    `colptr`, which converts the graph into a format suitable for sampling, and
+    the input `edge_weight` should also match the CSC format. `input_nodes` holds
+    the nodes we need to sample neighbors for, and `sample_size` is the number of
+    neighbors to sample per node. This API returns the weighted sampled
+    neighbors; the probability of a neighbor being selected is related to its
+    weight: the higher the weight, the higher the probability.
+
+    Args:
+        row (Tensor): One of the components of the CSC format of the input graph, and
+                      the shape should be [num_edges, 1] or [num_edges]. The available
+                      data type is int32, int64.
+        colptr (Tensor): One of the components of the CSC format of the input graph,
+                         and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
+                         The data type should be the same with `row`.
+        edge_weight (Tensor): The edge weight of the CSC format graph edges. And the shape
+                              should be [num_edges, 1] or [num_edges]. The available data
+                              type is float32.
+        input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
+                              data type should be the same with `row`.
+        sample_size (int, optional): The number of neighbors we need to sample. Default value is -1,
+                                     which means returning all the neighbors of the input nodes.
+        eids (Tensor, optional): The eid information of the input graph. If return_eids is True,
+                                 then `eids` should not be None. The data type should be the
+                                 same with `row`. Default is None.
+        return_eids (bool, optional): Whether to return eid information of sample edges. Default is False.
+        name (str, optional): Name for the operation (optional, default is None).
+                              For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        - out_neighbors (Tensor), the sample neighbors of the input nodes.
+
+        - out_count (Tensor), the number of sampled neighbors of each input node, and the shape
+          should be the same with `input_nodes`.
+
+        - out_eids (Tensor), if `return_eids` is True, we will return the eid information of the
+          sample edges.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4),
+            #        (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8)
+            row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
+            colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
+            weight = [0.1, 0.5, 0.2, 0.5, 0.9, 1.9, 2.0, 2.1, 0.01, 0.9, 0.12, 0.59, 0.67]
+            nodes = [0, 8, 1, 2]
+            sample_size = 2
+            row = paddle.to_tensor(row, dtype="int64")
+            colptr = paddle.to_tensor(colptr, dtype="int64")
+            weight = paddle.to_tensor(weight, dtype="float32")
+            nodes = paddle.to_tensor(nodes, dtype="int64")
+            out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(row, colptr, weight, nodes, sample_size=sample_size)
+
+    """
+
+    if return_eids:
+        if eids is None:
+            raise ValueError(
+                "`eids` should not be None if `return_eids` is True."
+            )
+
+    if in_dygraph_mode():
+        (
+            out_neighbors,
+            out_count,
+            out_eids,
+        ) = _C_ops.weighted_sample_neighbors(
+            row,
+            colptr,
+            edge_weight,
+            input_nodes,
+            eids,
+            sample_size,
+            return_eids,
+        )
+        if return_eids:
+            return out_neighbors, out_count, out_eids
+        return out_neighbors, out_count
+
+    check_variable_and_dtype(
+        row, "row", ("int32", "int64"), "weighted_sample_neighbors"
+    )
+    check_variable_and_dtype(
+        colptr, "colptr", ("int32", "int64"), "weighted_sample_neighbors"
+    )
+    check_variable_and_dtype(
+        edge_weight,
+        "edge_weight",
+        "float32",
+        "weighted_sample_neighbors",
+    )
+    check_variable_and_dtype(
+        input_nodes,
+        "input_nodes",
+        ("int32", "int64"),
+        "weighted_sample_neighbors",
+    )
+    if return_eids:
+        check_variable_and_dtype(
+            eids, "eids", ("int32", "int64"), "weighted_sample_neighbors"
+        )
+
+    helper = LayerHelper("weighted_sample_neighbors", **locals())
+    out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype)
+    out_count = helper.create_variable_for_type_inference(dtype=row.dtype)
+    out_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
+    helper.append_op(
+        type="weighted_sample_neighbors",
+        inputs={
+            "row": row,
+            "colptr": colptr,
+            "edge_weight": edge_weight,
+            "input_nodes": input_nodes,
+            "eids": eids if return_eids else None,
+        },
+        outputs={
+            "out_neighbors": out_neighbors,
+            "out_count": out_count,
+            "out_eids": out_eids,
+        },
+        attrs={
+            "sample_size": sample_size,
+            "return_eids": return_eids,
+        },
+    )
+    if return_eids:
+        return out_neighbors, out_count, out_eids
+    return out_neighbors, out_count
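Since each call samples a single hop, multi-hop sampling composes naturally: feed the deduplicated neighbors of one call back in as the `input_nodes` of the next. A sketch (illustrative only; reuses the docstring's toy graph):

import paddle

row = paddle.to_tensor([3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7], dtype="int64")
colptr = paddle.to_tensor([0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13], dtype="int64")
weight = paddle.to_tensor(
    [0.1, 0.5, 0.2, 0.5, 0.9, 1.9, 2.0, 2.1, 0.01, 0.9, 0.12, 0.59, 0.67],
    dtype="float32",
)
seeds = paddle.to_tensor([0, 8], dtype="int64")

# Hop 1, then hop 2 from the unique frontier of hop 1.
hop1, _ = paddle.geometric.weighted_sample_neighbors(
    row, colptr, weight, seeds, sample_size=2
)
frontier = paddle.unique(hop1)
hop2, hop2_count = paddle.geometric.weighted_sample_neighbors(
    row, colptr, weight, frontier, sample_size=2
)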