Unverified · Commit 8f455ca0, authored by: Siming Dai, committed by: GitHub

Add weighted sample (#52013) (#53276)

Add paddle.geometric.weighted_sample_neighbors API
Parent: 5a69ddb9
...@@ -2081,6 +2081,15 @@
intermediate: warprnntgrad
backward : warprnnt_grad
- op : weighted_sample_neighbors
args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids)
infer_meta :
func : WeightedSampleNeighborsInferMeta
kernel :
func : weighted_sample_neighbors
optional: eids
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
......
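The op entry above wires `paddle.geometric.weighted_sample_neighbors` to WeightedSampleNeighborsInferMeta and the weighted_sample_neighbors kernels, with `eids` declared optional. A minimal usage sketch of the two call forms, with illustrative values and assuming a Paddle build that contains this commit:

import paddle

# Tiny CSC graph with 5 nodes and 5 edges (values are made up for illustration).
row = paddle.to_tensor([3, 4, 0, 2, 1], dtype="int64")
colptr = paddle.to_tensor([0, 2, 4, 5, 5, 5], dtype="int64")
weight = paddle.to_tensor([0.1, 0.5, 0.2, 0.5, 0.9], dtype="float32")
nodes = paddle.to_tensor([0, 1], dtype="int64")
eids = paddle.to_tensor([0, 1, 2, 3, 4], dtype="int64")

# Two-output form: sampled neighbors plus per-node counts.
neighbors, count = paddle.geometric.weighted_sample_neighbors(
    row, colptr, weight, nodes, sample_size=2)

# Three-output form: also returns the sampled edge ids (eids must be passed).
neighbors, count, out_eids = paddle.geometric.weighted_sample_neighbors(
    row, colptr, weight, nodes, sample_size=2, eids=eids, return_eids=True)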
...@@ -3263,5 +3263,52 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
const MetaTensor& x,
const MetaTensor& eids,
int sample_size,
bool return_eids,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids) {
// GSN: GraphSampleNeighbors
auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GSNShapeCheck(row.dims(), "row");
GSNShapeCheck(col_ptr.dims(), "colptr");
GSNShapeCheck(edge_weight.dims(), "edge_weight");
GSNShapeCheck(x.dims(), "input_nodes");
if (return_eids) {
GSNShapeCheck(eids.dims(), "eids");
out_eids->set_dims({-1});
out_eids->set_dtype(row.dtype());
}
out->set_dims({-1});
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
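The lambda above accepts 1-D tensors or 2-D tensors whose last dim is 1, and every output dim is set to -1 because the output length is only known after sampling. A small Python analog of the accepted shapes (illustrative helper, not part of the PR):

def gsn_shape_ok(shape):
    # Mirrors GSNShapeCheck: either 1-D, or 2-D with the last dim equal to 1.
    if len(shape) == 2:
        return shape[1] == 1
    return len(shape) == 1

assert gsn_shape_ok([13])         # [num_edges]
assert gsn_shape_ok([13, 1])      # [num_edges, 1]
assert not gsn_shape_ok([13, 2])  # rejected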
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
...@@ -557,6 +557,17 @@ void WarprnntInferMeta(const MetaTensor& input,
MetaTensor* loss,
MetaTensor* warpctcgrad);
void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
const MetaTensor& x,
const MetaTensor& eids,
int sample_size,
bool return_eids,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids);
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
#include <cmath>
#include <queue>
#include <random>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct GraphWeightedNode {
T node_id;
float weight_key;
T eid;
GraphWeightedNode() {
node_id = 0;
weight_key = 0;
eid = 0;
}
GraphWeightedNode(T node_id, float weight_key, T eid = 0)
: node_id(node_id), weight_key(weight_key), eid(eid) {}
void operator=(const GraphWeightedNode<T>& other) {
node_id = other.node_id;
weight_key = other.weight_key;
eid = other.eid;
}
friend bool operator>(const GraphWeightedNode<T>& n1,
const GraphWeightedNode<T>& n2) {
return n1.weight_key > n2.weight_key;
}
};
template <typename T>
void SampleWeightedNeighbors(
std::vector<T>& out_src, // NOLINT
const std::vector<float>& out_weight,
std::vector<T>& out_eids, // NOLINT
int sample_size,
std::mt19937& rng, // NOLINT
std::uniform_real_distribution<float>& dice_distribution, // NOLINT
bool return_eids) {
std::priority_queue<phi::GraphWeightedNode<T>,
std::vector<phi::GraphWeightedNode<T>>,
std::greater<phi::GraphWeightedNode<T>>>
min_heap;
for (size_t i = 0; i < out_src.size(); i++) {
float weight_key = log2(dice_distribution(rng)) * (1 / out_weight[i]);
if (static_cast<int>(i) < sample_size) {
if (!return_eids) {
min_heap.push(phi::GraphWeightedNode<T>(out_src[i], weight_key));
} else {
min_heap.push(
phi::GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]));
}
} else {
const phi::GraphWeightedNode<T>& small = min_heap.top();
phi::GraphWeightedNode<T> cmp;
if (!return_eids) {
cmp = GraphWeightedNode<T>(out_src[i], weight_key);
} else {
cmp = GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]);
}
bool flag = cmp > small;
if (flag) {
min_heap.pop();
min_heap.push(cmp);
}
}
}
int cnt = 0;
while (!min_heap.empty()) {
const phi::GraphWeightedNode<T>& tmp = min_heap.top();
out_src[cnt] = tmp.node_id;
if (return_eids) {
out_eids[cnt] = tmp.eid;
}
cnt++;
min_heap.pop();
}
}
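SampleWeightedNeighbors keeps the `sample_size` largest keys log2(u) / w in a min-heap, which is the Efraimidis-Spirakis (A-Res) scheme for weighted sampling without replacement. A host-side Python sketch of the same idea (helper names are mine, not the kernel's):

import heapq
import math
import random

def weighted_sample_without_replacement(neighbors, weights, k, rng=random):
    # A-Res: key_i = log(u_i) / w_i with u_i ~ U(0, 1); keep the k largest keys
    # in a min-heap, just like the C++ min_heap above.
    heap = []  # entries are (key, neighbor_id)
    for nid, w in zip(neighbors, weights):
        u = max(rng.random(), 1e-12)  # clamp to avoid log(0)
        key = math.log(u) / w
        if len(heap) < k:
            heapq.heappush(heap, (key, nid))
        elif key > heap[0][0]:
            heapq.heapreplace(heap, (key, nid))
    return [nid for _, nid in heap]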
template <typename T>
void SampleNeighbors(const T* row,
const T* col_ptr,
const float* edge_weight,
const T* eids,
const T* input,
std::vector<T>* output,
std::vector<int>* output_count,
std::vector<T>* output_eids,
int sample_size,
int bs,
bool return_eids) {
std::vector<std::vector<T>> out_src_vec;
std::vector<std::vector<float>> out_weight_vec;
std::vector<std::vector<T>> out_eids_vec;
// `sample_cumsum_sizes` records the start and end positions of each node's
// samples in the flattened output.
std::vector<int> sample_cumsum_sizes(bs + 1);
// `total_neighbors` is the total size of the output after sampling.
int total_neighbors = 0;
sample_cumsum_sizes[0] = total_neighbors;
for (int i = 0; i < bs; i++) {
T node = input[i];
int cap = col_ptr[node + 1] - col_ptr[node];
int k = cap > sample_size ? sample_size : cap;
total_neighbors += k;
sample_cumsum_sizes[i + 1] = total_neighbors;
std::vector<T> out_src;
out_src.resize(cap);
out_src_vec.emplace_back(out_src);
std::vector<float> out_weight;
out_weight.resize(cap);
out_weight_vec.emplace_back(out_weight);
if (return_eids) {
std::vector<T> out_eids;
out_eids.resize(cap);
out_eids_vec.emplace_back(out_eids);
}
}
output_count->resize(bs);
output->resize(total_neighbors);
if (return_eids) {
output_eids->resize(total_neighbors);
}
std::random_device rd;
std::mt19937 rng{rd()};
std::uniform_real_distribution<float> dice_distribution(0, 1);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Sample the neighbors in parallel.
for (int i = 0; i < bs; i++) {
T node = input[i];
T begin = col_ptr[node], end = col_ptr[node + 1];
int cap = end - begin;
if (sample_size < cap) { // sample_size < neighbor_len
std::copy(row + begin, row + end, out_src_vec[i].begin());
std::copy(
edge_weight + begin, edge_weight + end, out_weight_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
SampleWeightedNeighbors(out_src_vec[i],
out_weight_vec[i],
out_eids_vec[i],
sample_size,
rng,
dice_distribution,
return_eids);
*(output_count->data() + i) = sample_size;
} else { // sample_size >= neighbor_len, directly copy
std::copy(row + begin, row + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
*(output_count->data() + i) = cap;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Copy the results in parallel.
for (int i = 0; i < bs; i++) {
int k = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i];
std::copy(out_src_vec[i].begin(),
out_src_vec[i].begin() + k,
output->data() + sample_cumsum_sizes[i]);
if (return_eids) {
std::copy(out_eids_vec[i].begin(),
out_eids_vec[i].begin() + k,
output_eids->data() + sample_cumsum_sizes[i]);
}
}
}
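SampleNeighbors first sizes the output with a running prefix sum (`sample_cumsum_sizes`), then samples every node independently and packs the per-node results into one flat buffer. A NumPy sketch of the offset/packing bookkeeping, with the sampling itself elided (helper name is mine):

import numpy as np

def pack_samples(per_node_samples):
    # per_node_samples: one 1-D array of sampled neighbors per input node.
    counts = np.array([len(s) for s in per_node_samples], dtype=np.int64)
    offsets = np.zeros(len(counts) + 1, dtype=np.int64)  # sample_cumsum_sizes
    offsets[1:] = np.cumsum(counts)
    flat = np.empty(offsets[-1], dtype=np.int64)
    for i, s in enumerate(per_node_samples):
        flat[offsets[i]:offsets[i + 1]] = s
    return flat, counts, offsets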
template <typename T, typename Context>
void WeightedSampleNeighborsKernel(const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& edge_weight,
const DenseTensor& x,
const paddle::optional<DenseTensor>& eids,
int sample_size,
bool return_eids,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids) {
const T* row_data = row.data<T>();
const T* col_ptr_data = col_ptr.data<T>();
const float* weights_data = edge_weight.data<float>();
const T* x_data = x.data<T>();
const T* eids_data =
(eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
int bs = x.dims()[0];
std::vector<T> output;
std::vector<int> output_count;
std::vector<T> output_eids;
SampleNeighbors<T>(row_data,
col_ptr_data,
weights_data,
eids_data,
x_data,
&output,
&output_count,
&output_eids,
sample_size,
bs,
return_eids);
if (return_eids) {
out_eids->Resize({static_cast<int>(output_eids.size())});
T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
std::copy(output_eids.begin(), output_eids.end(), out_eids_data);
}
out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data);
out_count->Resize({bs});
int* out_count_data = dev_ctx.template Alloc<int>(out_count);
std::copy(output_count.begin(), output_count.end(), out_count_data);
}
} // namespace phi
PD_REGISTER_KERNEL(weighted_sample_neighbors,
CPU,
ALL_LAYOUT,
phi::WeightedSampleNeighborsKernel,
int,
int64_t) {}
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
namespace paddle {
namespace framework {
template<
typename KeyT,
int BLOCK_SIZE,
bool GREATER = true,
int RADIX_BITS = 8>
class BlockRadixTopKGlobalMemory {
static_assert(cub::PowerOfTwo<RADIX_BITS>::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)),
"RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)");
static_assert(cub::PowerOfTwo<BLOCK_SIZE>::VALUE, "BLOCK_SIZE should be power of 2");
using KeyTraits = cub::Traits<KeyT>;
using UnsignedBits = typename KeyTraits::UnsignedBits;
using BlockScanT = cub::BlockScan<int, BLOCK_SIZE>;
static constexpr int RADIX_SIZE = (1 << RADIX_BITS);
static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE;
using BinBlockLoad = cub::BlockLoad<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BinBlockStore = cub::BlockStore<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
struct _TempStorage {
typename BlockScanT::TempStorage scan_storage;
union {
typename BinBlockLoad::TempStorage load_storage;
typename BinBlockStore::TempStorage store_storage;
} load_store;
union {
int shared_bins[RADIX_SIZE];
};
int share_target_k;
int share_bucket_id;
};
public:
struct TempStorage : cub::Uninitialized<_TempStorage> {
};
__device__ __forceinline__ BlockRadixTopKGlobalMemory(TempStorage &temp_storage)
: temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){};
__device__ __forceinline__ void radixTopKGetThreshold(const KeyT *data, int k, int size, KeyT &topK, bool &topk_is_unique) {
assert(k < size && k > 0);
int target_k = k;
UnsignedBits key_pattern = 0;
int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS;
for (; digit_pos >= 0; digit_pos -= RADIX_BITS) {
UpdateSharedBins(data, size, digit_pos, key_pattern);
InclusiveScanBins();
UpdateTopK(digit_pos, target_k, key_pattern);
if (target_k == 0) break;
}
if (target_k == 0) {
key_pattern -= 1;
topk_is_unique = true;
} else {
topk_is_unique = false;
}
if (GREATER) key_pattern = ~key_pattern;
UnsignedBits topK_unsigned = KeyTraits::TwiddleOut(key_pattern);
topK = reinterpret_cast<KeyT &>(topK_unsigned);
}
private:
__device__ __forceinline__ void UpdateSharedBins(const KeyT *key, int size, int digit_pos, UnsignedBits key_pattern) {
for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) {
temp_storage_.shared_bins[id] = 0;
}
cub::CTA_SYNC();
UnsignedBits key_mask = ((UnsignedBits)(-1)) << ((UnsignedBits)(digit_pos + RADIX_BITS));
#pragma unroll
for (int idx = tid_; idx < size; idx += BLOCK_SIZE) {
KeyT key_data = key[idx];
UnsignedBits twiddled_data = KeyTraits::TwiddleIn(reinterpret_cast<UnsignedBits &>(key_data));
if (GREATER) twiddled_data = ~twiddled_data;
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(twiddled_data, digit_pos, RADIX_BITS);
if ((twiddled_data & key_mask) == (key_pattern & key_mask)) {
atomicAdd(&temp_storage_.shared_bins[digit_in_radix], 1);
}
}
cub::CTA_SYNC();
}
__device__ __forceinline__ void InclusiveScanBins() {
int items[SCAN_ITEMS_PER_THREAD];
BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0);
cub::CTA_SYNC();
BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items);
cub::CTA_SYNC();
BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE);
cub::CTA_SYNC();
}
__device__ __forceinline__ void UpdateTopK(int digit_pos,
int &target_k,
UnsignedBits &target_pattern) {
for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) {
int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1];
int cur_count = temp_storage_.shared_bins[idx];
if (prev_count <= target_k && cur_count > target_k) {
temp_storage_.share_target_k = target_k - prev_count;
temp_storage_.share_bucket_id = idx;
}
}
cub::CTA_SYNC();
target_k = temp_storage_.share_target_k;
int target_bucket_id = temp_storage_.share_bucket_id;
UnsignedBits key_segment = ((UnsignedBits) target_bucket_id) << ((UnsignedBits) digit_pos);
target_pattern |= key_segment;
}
_TempStorage &temp_storage_;
int tid_;
};
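BlockRadixTopKGlobalMemory narrows a bit pattern one radix digit at a time (MSB first): each pass histograms the keys that still match the fixed prefix, scans the histogram, and picks the bucket containing the k-th element. A simplified single-threaded Python analog for unsigned integer keys (the CUDA class additionally bit-twiddles floating-point keys via cub::Traits and reports whether the threshold is unique):

def radix_kth_largest(keys, k, radix_bits=8, key_bits=32):
    # Returns the value of the k-th largest key by fixing one digit per pass,
    # mirroring the prefix ("key_pattern") narrowing in radixTopKGetThreshold.
    radix = 1 << radix_bits
    pattern, mask, target_k = 0, 0, k
    for pos in range(key_bits - radix_bits, -1, -radix_bits):
        counts = [0] * radix
        for x in keys:
            if (x & mask) == (pattern & mask):  # key still matches the prefix
                counts[(x >> pos) & (radix - 1)] += 1
        for digit in range(radix - 1, -1, -1):  # walk buckets, largest first
            if target_k <= counts[digit]:
                pattern |= digit << pos
                mask |= (radix - 1) << pos
                break
            target_k -= counts[digit]
    return pattern

assert radix_kth_largest([5, 1, 9, 7, 3], k=2, key_bits=8) == 7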
template<
typename KeyT,
int BLOCK_SIZE,
int ITEMS_PER_THREAD,
bool GREATER = true,
typename ValueT = cub::NullType,
int RADIX_BITS = 8>
class BlockRadixTopKRegister {
static_assert(cub::PowerOfTwo<RADIX_BITS>::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)),
"RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)");
static_assert(cub::PowerOfTwo<BLOCK_SIZE>::VALUE, "BLOCK_SIZE should be power of 2");
using KeyTraits = cub::Traits<KeyT>;
using UnsignedBits = typename KeyTraits::UnsignedBits;
using BlockScanT = cub::BlockScan<int, BLOCK_SIZE>;
static constexpr int RADIX_SIZE = (1 << RADIX_BITS);
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE;
using BinBlockLoad = cub::BlockLoad<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BinBlockStore = cub::BlockStore<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BlockExchangeKey = cub::BlockExchange<KeyT, BLOCK_SIZE, ITEMS_PER_THREAD>;
using BlockExchangeValue = cub::BlockExchange<ValueT, BLOCK_SIZE, ITEMS_PER_THREAD>;
using _ExchangeKeyTempStorage = typename BlockExchangeKey::TempStorage;
using _ExchangeValueTempStorage = typename BlockExchangeValue::TempStorage;
typedef union ExchangeKeyTempStorageType {
_ExchangeKeyTempStorage key_storage;
} ExchKeyTempStorageType;
typedef union ExchangeKeyValueTempStorageType {
_ExchangeKeyTempStorage key_storage;
_ExchangeValueTempStorage value_storage;
} ExchKeyValueTempStorageType;
using _ExchangeType = typename std::conditional<KEYS_ONLY, ExchKeyTempStorageType, ExchKeyValueTempStorageType>::type;
struct _TempStorage {
typename BlockScanT::TempStorage scan_storage;
union {
typename BinBlockLoad::TempStorage load_storage;
typename BinBlockStore::TempStorage store_storage;
} load_store;
union {
int shared_bins[RADIX_SIZE];
_ExchangeType exchange_storage;
};
int share_target_k;
int share_bucket_id;
int share_prev_count;
};
public:
struct TempStorage : cub::Uninitialized<_TempStorage> {
};
__device__ __forceinline__ BlockRadixTopKRegister(TempStorage &temp_storage)
: temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){};
__device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
const int k, const int valid_count) {
TopKGenRank(keys, k, valid_count);
int is_valid[ITEMS_PER_THREAD];
GenValidArray(is_valid, k);
BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid);
cub::CTA_SYNC();
}
__device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], ValueT (&values)[ITEMS_PER_THREAD],
const int k, const int valid_count) {
TopKGenRank(keys, k, valid_count);
int is_valid[ITEMS_PER_THREAD];
GenValidArray(is_valid, k);
BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid);
cub::CTA_SYNC();
BlockExchangeValue{temp_storage_.exchange_storage.value_storage}.ScatterToStripedFlagged(values, values, ranks_, is_valid);
cub::CTA_SYNC();
}
private:
__device__ __forceinline__ void TopKGenRank(KeyT (&keys)[ITEMS_PER_THREAD], const int k, const int valid_count) {
assert(k <= BLOCK_SIZE * ITEMS_PER_THREAD);
assert(k <= valid_count);
if (k == valid_count) return;
UnsignedBits(&unsigned_keys)[ITEMS_PER_THREAD] = reinterpret_cast<UnsignedBits(&)[ITEMS_PER_THREAD]>(keys);
search_mask_ = 0;
top_k_mask_ = 0;
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
int idx = KEY * BLOCK_SIZE + tid_;
unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY];
if (idx < valid_count) search_mask_ |= (1U << KEY);
}
int target_k = k;
int prefix_k = 0;
for (int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) {
UpdateSharedBins(unsigned_keys, digit_pos, prefix_k);
InclusiveScanBins();
UpdateTopK(unsigned_keys, digit_pos, target_k, prefix_k, digit_pos == 0);
if (target_k == 0) break;
}
#pragma unroll
for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY];
unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
}
}
__device__ __forceinline__ void GenValidArray(int (&is_valid)[ITEMS_PER_THREAD], int k) {
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if ((top_k_mask_ & (1U << KEY)) && ranks_[KEY] < k) {
is_valid[KEY] = 1;
} else {
is_valid[KEY] = 0;
}
}
}
__device__ __forceinline__ void UpdateSharedBins(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
int digit_pos, int prefix_k) {
for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) {
temp_storage_.shared_bins[id] = 0;
}
cub::CTA_SYNC();
//#define USE_MATCH
#ifdef USE_MATCH
int lane_mask = cub::LaneMaskLt();
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
bool is_search = search_mask_ & (1U << KEY);
int bucket_idx = -1;
if (is_search) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
bucket_idx = (int) digit_in_radix;
}
int warp_match_mask = __match_any_sync(0xffffffff, bucket_idx);
int same_count = __popc(warp_match_mask);
int idx_in_same_bucket = __popc(warp_match_mask & lane_mask);
int same_bucket_root_lane = __ffs(warp_match_mask) - 1;
int same_bucket_start_idx;
if (idx_in_same_bucket == 0 && is_search) {
same_bucket_start_idx = atomicAdd(&temp_storage_.shared_bins[bucket_idx], same_count);
}
same_bucket_start_idx = __shfl_sync(0xffffffff, same_bucket_start_idx, same_bucket_root_lane, 32);
if (is_search) {
ranks_[KEY] = same_bucket_start_idx + idx_in_same_bucket + prefix_k;
}
}
#else
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
bool is_search = search_mask_ & (1U << KEY);
int bucket_idx = -1;
if (is_search) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
bucket_idx = (int) digit_in_radix;
ranks_[KEY] = atomicAdd(&temp_storage_.shared_bins[bucket_idx], 1) + prefix_k;
}
}
#endif
cub::CTA_SYNC();
}
__device__ __forceinline__ void InclusiveScanBins() {
int items[SCAN_ITEMS_PER_THREAD];
BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0);
cub::CTA_SYNC();
BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items);
cub::CTA_SYNC();
BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE);
cub::CTA_SYNC();
}
__device__ __forceinline__ void UpdateTopK(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
int digit_pos,
int &target_k,
int &prefix_k,
bool mark_equal) {
for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) {
int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1];
int cur_count = temp_storage_.shared_bins[idx];
if (prev_count <= target_k && cur_count > target_k) {
temp_storage_.share_target_k = target_k - prev_count;
temp_storage_.share_bucket_id = idx;
temp_storage_.share_prev_count = prev_count;
}
}
cub::CTA_SYNC();
target_k = temp_storage_.share_target_k;
prefix_k += temp_storage_.share_prev_count;
int target_bucket_id = temp_storage_.share_bucket_id;
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if (search_mask_ & (1U << KEY)) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
if (digit_in_radix < target_bucket_id) {
top_k_mask_ |= (1U << KEY);
search_mask_ &= ~(1U << KEY);
} else if (digit_in_radix > target_bucket_id) {
search_mask_ &= ~(1U << KEY);
} else {
if (mark_equal) top_k_mask_ |= (1U << KEY);
}
if (digit_in_radix <= target_bucket_id) {
int prev_count = (digit_in_radix == 0) ? 0 : temp_storage_.shared_bins[digit_in_radix - 1];
ranks_[KEY] += prev_count;
}
}
}
cub::CTA_SYNC();
}
_TempStorage &temp_storage_;
int tid_;
int ranks_[ITEMS_PER_THREAD];
unsigned int search_mask_;
unsigned int top_k_mask_;
};
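The register variant additionally records, per thread-local item, whether it belongs to the top-k and at which rank, so keys (and optional values) can be scattered into a striped layout. Ignoring the layout, its host-side effect is simply "the indices of the k largest keys, in unspecified order" (helper name is mine):

import numpy as np

def topk_item_indices(keys, k):
    # Host-side analog of what radixTopKToStriped selects, order not guaranteed.
    return np.argpartition(np.asarray(keys), -k)[-k:]

idx = topk_item_indices([0.3, 2.0, 0.7, 1.5, 0.1], k=2)
assert set(int(i) for i in idx) == {1, 3}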
}; // end namespace framework
}; // end namespace paddle
#endif
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef __NVCC__
#include <cuda_runtime_api.h> // NOLINT
#endif
class RandomNumGen {
public:
__host__ __device__ __forceinline__ RandomNumGen(int gid, unsigned long long seed) {
next_random = seed + gid;
next_random ^= next_random >> 33U;
next_random *= 0xff51afd7ed558ccdUL;
next_random ^= next_random >> 33U;
next_random *= 0xc4ceb9fe1a85ec53UL;
next_random ^= next_random >> 33U;
}
__host__ __device__ __forceinline__ ~RandomNumGen() = default;
__host__ __device__ __forceinline__ void SetSeed(int seed) {
next_random = seed;
NextValue();
}
__host__ __device__ __forceinline__ unsigned long long SaveState() const {
return next_random;
}
__host__ __device__ __forceinline__ void LoadState(unsigned long long state) {
next_random = state;
}
__host__ __device__ __forceinline__ int Random() {
int ret_value = (int) (next_random & 0x7fffffffULL);
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ int RandomMod(int mod) {
return Random() % mod;
}
__host__ __device__ __forceinline__ int64_t Random64() {
int64_t ret_value = (next_random & 0x7FFFFFFFFFFFFFFFLL);
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ int64_t RandomMod64(int64_t mod) {
return Random64() % mod;
}
__host__ __device__ __forceinline__ float RandomUniformFloat(float max = 1.0f, float min = 0.0f) {
int value = (int) (next_random & 0xffffff);
auto ret_value = (float) value;
ret_value /= 0xffffffL;
ret_value *= (max - min);
ret_value += min;
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ bool RandomBool(float true_prob) {
float value = RandomUniformFloat();
return value <= true_prob;
}
__host__ __device__ __forceinline__ void NextValue() {
//next_random = next_random * (unsigned long long)0xc4ceb9fe1a85ec53UL + generator_id;
//next_random = next_random * (unsigned long long)25214903917ULL + 11;
next_random = next_random * (unsigned long long) 13173779397737131ULL + 1023456798976543201ULL;
}
private:
unsigned long long next_random = 1;
};
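RandomNumGen hashes (seed + gid) with a murmur/splitmix-style finalizer and then advances a 64-bit LCG; RandomUniformFloat only consumes the low 24 bits of the state. A direct Python port for experimentation (class name is mine; 64-bit wraparound is emulated with a mask):

MASK64 = (1 << 64) - 1

class RandomNumGenPy:
    # Python port of the device-side RandomNumGen above.
    def __init__(self, gid, seed):
        x = (seed + gid) & MASK64
        x ^= x >> 33
        x = (x * 0xFF51AFD7ED558CCD) & MASK64
        x ^= x >> 33
        x = (x * 0xC4CEB9FE1A85EC53) & MASK64
        x ^= x >> 33
        self.state = x

    def next_value(self):
        self.state = (self.state * 13173779397737131 + 1023456798976543201) & MASK64

    def random_uniform_float(self, max_v=1.0, min_v=0.0):
        value = self.state & 0xFFFFFF  # low 24 bits, as in the CUDA code
        ret = value / float(0xFFFFFF) * (max_v - min_v) + min_v
        self.next_value()
        return ret

    def random64(self):
        ret = self.state & 0x7FFFFFFFFFFFFFFF
        self.next_value()
        return ret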
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "cub/cub.cuh"
#endif
#include "math.h" // NOLINT
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/block_radix_topk.cuh"
#include "paddle/phi/kernels/funcs/random.cuh"
#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
#define SAMPLE_SIZE_THRESHOLD 1024
namespace phi {
#ifdef PADDLE_WITH_CUDA
__device__ __forceinline__ float GenKeyFromWeight(
const float weight,
RandomNumGen& rng) { // NOLINT
rng.NextValue();
float u = -rng.RandomUniformFloat(1.0f, 0.5f);
long long random_num2 = 0; // NOLINT
int seed_count = -1;
do {
random_num2 = rng.Random64();
seed_count++;
} while (!random_num2);
int one_bit = __clzll(random_num2) + seed_count * 64;
u *= exp2f(-one_bit);
float logk = (log1pf(u) / logf(2.0)) * (1 / weight);
return logk;
}
#endif
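GenKeyFromWeight builds, from a separately drawn mantissa and exponent, a key that plays the role of log2(u) / weight (u ~ U(0, 1)) in the Efraimidis-Spirakis scheme: ranking neighbors by such keys and keeping the top sample_size yields weighted sampling without replacement. A host-side Monte Carlo sanity check of that idea using the plain formula (NumPy sketch, not the device construction):

import numpy as np

def weighted_topk_sample(weights, k, rng):
    keys = np.log2(rng.random(len(weights))) / weights
    return np.argpartition(keys, -k)[-k:]

rng = np.random.default_rng(0)
w = np.array([1.0, 2.0, 4.0])
hits = np.zeros(3)
for _ in range(20000):
    hits[weighted_topk_sample(w, 1, rng)] += 1
print(hits / hits.sum())  # roughly w / w.sum() = [0.14, 0.29, 0.57]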
template <typename T, bool NeedNeighbor = false>
__global__ void GetSampleCountAndNeighborCountKernel(const T* col_ptr,
const T* input_nodes,
int* actual_size,
int* neighbor_count,
int sample_size,
int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n) return;
T nid = input_nodes[i];
int neighbor_size = static_cast<int>(col_ptr[nid + 1] - col_ptr[nid]);
// sample_size < 0 means sample all.
int k = neighbor_size;
if (sample_size >= 0) {
k = min(neighbor_size, sample_size);
}
actual_size[i] = k;
if (NeedNeighbor) {
neighbor_count[i] = (neighbor_size <= sample_size) ? 0 : neighbor_size;
}
}
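This first kernel only sizes the problem: per input node it records min(degree, sample_size) (or the full degree when sample_size < 0), plus, for the large-sample path, the degree of every node that actually needs sampling. A NumPy sketch, assuming `colptr` and `input_nodes` are NumPy arrays (helper name is mine):

import numpy as np

def sample_and_neighbor_counts(colptr, input_nodes, sample_size, need_neighbor=False):
    degree = colptr[input_nodes + 1] - colptr[input_nodes]
    actual = degree if sample_size < 0 else np.minimum(degree, sample_size)
    neighbor = None
    if need_neighbor:
        # Only nodes whose degree exceeds sample_size need scratch space for keys.
        neighbor = np.where(degree <= sample_size, 0, degree)
    return actual, neighbor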
#ifdef PADDLE_WITH_CUDA
template <typename T, unsigned int BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE) __global__
void WeightedSampleLargeKernel(T* sample_output,
const int* sample_offset,
const int* target_neighbor_offset,
float* weight_keys_buf,
const T* input_nodes,
int input_node_count,
const T* in_rows,
const T* col_ptr,
const float* edge_weight,
const T* eids,
int max_sample_count,
unsigned long long random_seed, // NOLINT
T* out_eids,
bool return_eids) {
int i = blockIdx.x;
if (i >= input_node_count) return;
int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
T nid = input_nodes[i];
T start = col_ptr[nid];
T end = col_ptr[nid + 1];
int neighbor_count = static_cast<int>(end - start);
float* weight_keys_local_buff = weight_keys_buf + target_neighbor_offset[i];
int offset = sample_offset[i];
if (neighbor_count <= max_sample_count) {
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
sample_output[offset + j] = in_rows[start + j];
if (return_eids) {
out_eids[offset + j] = eids[start + j];
}
}
} else {
RandomNumGen rng(gidx, random_seed);
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
float thread_weight = edge_weight[start + j];
weight_keys_local_buff[j] =
static_cast<float>(GenKeyFromWeight(thread_weight, rng));
}
__syncthreads();
float topk_val;
bool topk_is_unique;
using BlockRadixSelectT =
paddle::framework::BlockRadixTopKGlobalMemory<float, BLOCK_SIZE, true>;
__shared__ typename BlockRadixSelectT::TempStorage share_storage;
BlockRadixSelectT{share_storage}.radixTopKGetThreshold(
weight_keys_local_buff,
max_sample_count,
neighbor_count,
topk_val,
topk_is_unique);
__shared__ int cnt;
if (threadIdx.x == 0) {
cnt = 0;
}
__syncthreads();
// We use atomicAdd(&cnt, 1) instead of a block scan to compute the write
// index, since we do not need to preserve the relative order of the
// elements.
if (topk_is_unique) {
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
float key = weight_keys_local_buff[j];
bool has_topk = (key >= topk_val);
if (has_topk) {
int write_index = atomicAdd(&cnt, 1);
sample_output[offset + write_index] = in_rows[start + j];
if (return_eids) {
out_eids[offset + write_index] = eids[start + j];
}
}
}
} else {
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
float key = weight_keys_local_buff[j];
bool has_topk = (key > topk_val);
if (has_topk) {
int write_index = atomicAdd(&cnt, 1);
sample_output[offset + write_index] = in_rows[start + j];
if (return_eids) {
out_eids[offset + write_index] = eids[start + j];
}
}
}
__syncthreads();
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
float key = weight_keys_local_buff[j];
bool has_topk = (key == topk_val);
if (has_topk) {
int write_index = atomicAdd(&cnt, 1);
if (write_index >= max_sample_count) {
break;
}
sample_output[offset + write_index] = in_rows[start + j];
if (return_eids) {
out_eids[offset + write_index] = eids[start + j];
}
}
}
}
}
}
#endif
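For sample_size > SAMPLE_SIZE_THRESHOLD (1024), the kernel first computes a key threshold with the global-memory radix selector, then emits neighbors whose key is strictly above it and tops up with ties until max_sample_count is reached. A host-side sketch of that selection rule (helper name is mine):

import numpy as np

def select_by_threshold(keys, k):
    # Take everything strictly above the k-th largest key, then fill remaining
    # slots with ties (mirrors the topk_is_unique == false branch).
    keys = np.asarray(keys)
    thresh = np.partition(keys, -k)[-k]
    above = np.flatnonzero(keys > thresh)
    ties = np.flatnonzero(keys == thresh)
    return np.concatenate([above, ties[: k - above.size]])

sel = select_by_threshold([0.9, 0.1, 0.5, 0.5, 0.5], k=3)
assert len(sel) == 3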
template <typename T>
__global__ void SampleAllKernel(T* sample_output,
const int* sample_offset,
const T* input_nodes,
int input_node_count,
const T* in_rows,
const T* col_ptr,
const T* eids,
T* out_eids,
bool return_eids) {
int i = blockIdx.x;
if (i >= input_node_count) return;
T nid = input_nodes[i];
T start = col_ptr[nid];
T end = col_ptr[nid + 1];
int neighbor_count = static_cast<int>(end - start);
if (neighbor_count <= 0) return;
int offset = sample_offset[i];
for (int j = threadIdx.x; j < neighbor_count; j += blockDim.x) {
sample_output[offset + j] = in_rows[start + j];
if (return_eids) {
out_eids[offset + j] = eids[start + j];
}
}
}
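When sample_size <= 0, no sampling happens and each node's full neighbor slice is copied. The whole path reduces to simple slicing of the CSC arrays (host-side sketch):

def all_neighbors(row, colptr, node, eids=None):
    # Full neighbor (and eid) slice of `node` in CSC storage.
    s, e = colptr[node], colptr[node + 1]
    return (row[s:e], None if eids is None else eids[s:e])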
// A-Res algorithm (weighted reservoir sampling, Efraimidis-Spirakis).
#ifdef PADDLE_WITH_CUDA
template <typename T, unsigned int ITEMS_PER_THREAD, unsigned int BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE) __global__
void WeightedSampleKernel(T* sample_output,
const int* sample_offset,
const T* input_nodes,
int input_node_count,
const T* in_rows,
const T* col_ptr,
const float* edge_weight,
const T* eids,
int max_sample_count,
unsigned long long random_seed, // NOLINT
T* out_eids,
bool return_eids) {
int i = blockIdx.x;
if (i >= input_node_count) return;
int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE;
T nid = input_nodes[i];
T start = col_ptr[nid];
T end = col_ptr[nid + 1];
int neighbor_count = static_cast<int>(end - start);
int offset = sample_offset[i];
if (neighbor_count <= max_sample_count) {
for (int j = threadIdx.x; j < neighbor_count; j += BLOCK_SIZE) {
sample_output[offset + j] = in_rows[start + j];
if (return_eids) {
out_eids[offset + j] = eids[start + j];
}
}
} else {
RandomNumGen rng(gidx, random_seed);
float weight_keys[ITEMS_PER_THREAD];
int neighbor_idxs[ITEMS_PER_THREAD];
using BlockRadixTopKT = paddle::framework::
BlockRadixTopKRegister<float, BLOCK_SIZE, ITEMS_PER_THREAD, true, int>;
__shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage;
const int tx = threadIdx.x;
#pragma unroll
for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int idx = BLOCK_SIZE * j + tx;
if (idx < neighbor_count) {
float thread_weight = edge_weight[start + idx];
weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
neighbor_idxs[j] = idx;
}
}
const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD))
? neighbor_count
: (BLOCK_SIZE * ITEMS_PER_THREAD);
BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
weight_keys, neighbor_idxs, max_sample_count, valid_count);
__syncthreads();
const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count;
for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE;
idx_offset < neighbor_count;
idx_offset += stride) {
#pragma unroll
for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int local_idx = BLOCK_SIZE * j + tx - max_sample_count;
int target_idx = idx_offset + local_idx;
if (local_idx >= 0 && target_idx < neighbor_count) {
float thread_weight = edge_weight[start + target_idx];
weight_keys[j] = GenKeyFromWeight(thread_weight, rng);
neighbor_idxs[j] = target_idx;
}
}
const int iter_valid_count =
((neighbor_count - idx_offset) >= stride)
? (BLOCK_SIZE * ITEMS_PER_THREAD)
: (max_sample_count + neighbor_count - idx_offset);
BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped(
weight_keys, neighbor_idxs, max_sample_count, iter_valid_count);
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int idx = j * BLOCK_SIZE + tx;
if (idx < max_sample_count) {
sample_output[offset + idx] = in_rows[start + neighbor_idxs[j]];
if (return_eids) {
out_eids[offset + idx] = eids[start + neighbor_idxs[j]];
}
}
}
}
}
#endif
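For moderate sample sizes, the kernel keeps the current best max_sample_count keys in registers and streams the remaining neighbors through the free slots (stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count), re-running the block top-k per chunk. A host-side sketch of the streaming top-k invariant (helper name is mine):

import numpy as np

def streaming_topk(keys, k, capacity):
    # Maintain the top-k of everything seen so far while only ever holding
    # `capacity` keys at once; equivalent to a single global top-k.
    keys = np.asarray(keys, dtype=np.float64)
    best = keys[:capacity]
    if best.size > k:
        best = best[np.argpartition(best, -k)[-k:]]
    stride = capacity - k
    for pos in range(capacity, keys.size, stride):
        cand = np.concatenate([best, keys[pos:pos + stride]])
        if cand.size > k:
            cand = cand[np.argpartition(cand, -k)[-k:]]
        best = cand
    return np.sort(best)

x = np.random.default_rng(0).random(1000)
assert np.allclose(streaming_topk(x, 8, 32), np.sort(x)[-8:])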
template <typename T, typename Context>
void WeightedSampleNeighborsKernel(const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& edge_weight,
const DenseTensor& x,
const paddle::optional<DenseTensor>& eids,
int sample_size,
bool return_eids,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids) {
auto* row_data = row.data<T>();
auto* col_ptr_data = col_ptr.data<T>();
auto* weights_data = edge_weight.data<float>();
auto* x_data = x.data<T>();
auto* eids_data =
(eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
int bs = x.dims()[0];
thread_local std::random_device rd;
thread_local std::mt19937 gen(rd());
thread_local std::uniform_int_distribution<unsigned long long> // NOLINT
distrib;
unsigned long long random_seed = distrib(gen); // NOLINT
const bool need_neighbor_count = sample_size > SAMPLE_SIZE_THRESHOLD;
out_count->Resize({bs});
int* out_count_data =
dev_ctx.template Alloc<int>(out_count); // finally copy sample_count
int* neighbor_count_ptr = nullptr;
std::shared_ptr<phi::Allocation> neighbor_count;
auto sample_count = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
(bs + 1) * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int* sample_count_ptr = reinterpret_cast<int*>(sample_count->ptr());
int grid_size = (bs + 127) / 128;
if (need_neighbor_count) {
neighbor_count = phi::memory_utils::AllocShared(
dev_ctx.GetPlace(),
(bs + 1) * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
neighbor_count_ptr = reinterpret_cast<int*>(neighbor_count->ptr());
GetSampleCountAndNeighborCountKernel<T, true>
<<<grid_size, 128, 0, dev_ctx.stream()>>>(col_ptr_data,
x_data,
sample_count_ptr,
neighbor_count_ptr,
sample_size,
bs);
} else {
GetSampleCountAndNeighborCountKernel<T, false>
<<<grid_size, 128, 0, dev_ctx.stream()>>>(
col_ptr_data, x_data, sample_count_ptr, nullptr, sample_size, bs);
}
auto sample_offset = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
(bs + 1) * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int* sample_offset_ptr = reinterpret_cast<int*>(sample_offset->ptr());
#ifdef PADDLE_WITH_CUDA
const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
thrust::exclusive_scan(exec_policy,
sample_count_ptr,
sample_count_ptr + bs + 1,
sample_offset_ptr);
int total_sample_size = 0;
#ifdef PADDLE_WITH_CUDA
cudaMemcpyAsync(&total_sample_size,
sample_offset_ptr + bs,
sizeof(int),
cudaMemcpyDeviceToHost,
dev_ctx.stream());
cudaMemcpyAsync(out_count_data,
sample_count_ptr,
sizeof(int) * bs,
cudaMemcpyDeviceToDevice,
dev_ctx.stream());
cudaStreamSynchronize(dev_ctx.stream());
#else
hipMemcpyAsync(&total_sample_size,
sample_offset_ptr + bs,
sizeof(int),
hipMemcpyDeviceToHost,
dev_ctx.stream());
hipMemcpyAsync(out_count_data,
sample_count_ptr,
sizeof(int) * bs,
hipMemcpyDeviceToDevice,
dev_ctx.stream());
hipStreamSynchronize(dev_ctx.stream());
#endif
out->Resize({static_cast<int>(total_sample_size)});
T* out_data = dev_ctx.template Alloc<T>(out);
T* out_eids_data = nullptr;
if (return_eids) {
out_eids->Resize({static_cast<int>(total_sample_size)});
out_eids_data = dev_ctx.template Alloc<T>(out_eids);
}
// large sample size
#ifdef PADDLE_WITH_CUDA
if (sample_size > SAMPLE_SIZE_THRESHOLD) {
thrust::exclusive_scan(exec_policy,
neighbor_count_ptr,
neighbor_count_ptr + bs + 1,
neighbor_count_ptr);
int* neighbor_offset = neighbor_count_ptr;
int target_neighbor_counts;
cudaMemcpyAsync(&target_neighbor_counts,
neighbor_offset + bs,
sizeof(int),
cudaMemcpyDeviceToHost,
dev_ctx.stream());
cudaStreamSynchronize(dev_ctx.stream());
auto tmh_weights = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
target_neighbor_counts * sizeof(float),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
float* target_weights_keys_buf_ptr =
reinterpret_cast<float*>(tmh_weights->ptr());
constexpr int BLOCK_SIZE = 256;
WeightedSampleLargeKernel<T, BLOCK_SIZE>
<<<bs, BLOCK_SIZE, 0, dev_ctx.stream()>>>(out_data,
sample_offset_ptr,
neighbor_offset,
target_weights_keys_buf_ptr,
x_data,
bs,
row_data,
col_ptr_data,
weights_data,
eids_data,
sample_size,
random_seed,
out_eids_data,
return_eids);
cudaStreamSynchronize(dev_ctx.stream());
} else if (sample_size <= 0) {
SampleAllKernel<T><<<bs, 64, 0, dev_ctx.stream()>>>(out_data,
sample_offset_ptr,
x_data,
bs,
row_data,
col_ptr_data,
eids_data,
out_eids_data,
return_eids);
cudaStreamSynchronize(dev_ctx.stream());
} else {  // 0 < sample_size <= SAMPLE_SIZE_THRESHOLD
using WeightedSampleFuncType = void (*)(T*,
const int*,
const T*,
int,
const T*,
const T*,
const float*,
const T*,
int,
unsigned long long, // NOLINT
T*,
bool);
static const WeightedSampleFuncType func_array[7] = {
WeightedSampleKernel<T, 4, 128>,
WeightedSampleKernel<T, 6, 128>,
WeightedSampleKernel<T, 4, 256>,
WeightedSampleKernel<T, 5, 256>,
WeightedSampleKernel<T, 6, 256>,
WeightedSampleKernel<T, 8, 256>,
WeightedSampleKernel<T, 8, 512>,
};
const int block_sizes[7] = {128, 128, 256, 256, 256, 256, 512};
auto choose_func_idx = [](int sample_size) {
if (sample_size <= 128) {
return 0;
}
if (sample_size <= 384) {
return (sample_size - 129) / 64 + 4;
}
if (sample_size <= 512) {
return 5;
} else {
return 6;
}
};
int func_idx = choose_func_idx(sample_size);
int block_size = block_sizes[func_idx];
func_array[func_idx]<<<bs, block_size, 0, dev_ctx.stream()>>>(
out_data,
sample_offset_ptr,
x_data,
bs,
row_data,
col_ptr_data,
weights_data,
eids_data,
sample_size,
random_seed,
out_eids_data,
return_eids);
cudaStreamSynchronize(dev_ctx.stream());
}
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(weighted_sample_neighbors,
GPU,
ALL_LAYOUT,
phi::WeightedSampleNeighborsKernel,
int,
int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void WeightedSampleNeighborsKernel(
const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& edge_weight,
const DenseTensor& x,
const paddle::optional<DenseTensor>& eids,
int sample_size,
bool return_eids,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids);
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestWeightedSampleNeighbors(unittest.TestCase):
def setUp(self):
num_nodes = 20
edges = np.random.randint(num_nodes, size=(100, 2))
edges = np.unique(edges, axis=0)
self.edges_id = np.arange(0, len(edges)).astype("int64")
sorted_edges = edges[np.argsort(edges[:, 1])]
# Count edges per dst node; the cumulative sum of these counts is the colptr.
dst_count = np.zeros(num_nodes)
dst_src_dict = {}
for dst in range(0, num_nodes):
true_index = sorted_edges[:, 1] == dst
dst_count[dst] = np.sum(true_index)
dst_src_dict[dst] = sorted_edges[:, 0][true_index]
dst_count = dst_count.astype("int64")
colptr = np.cumsum(dst_count)
colptr = np.insert(colptr, 0, 0)
self.row = sorted_edges[:, 0].astype("int64")
self.colptr = colptr.astype("int64")
self.nodes = np.unique(np.random.randint(num_nodes, size=5)).astype(
"int64"
)
self.weight = np.ones(self.row.shape[0]).astype("float32")
self.sample_size = 5
self.dst_src_dict = dst_src_dict
def test_sample_result(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
weight = paddle.to_tensor(self.weight)
out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(
row, colptr, weight, nodes, sample_size=self.sample_size
)
out_count_cumsum = paddle.cumsum(out_count)
for i in range(len(out_count)):
if i == 0:
neighbors = out_neighbors[0 : out_count_cumsum[i]]
else:
neighbors = out_neighbors[
out_count_cumsum[i - 1] : out_count_cumsum[i]
]
# Ensure the correct sample size.
self.assertTrue(
out_count[i] == self.sample_size
or out_count[i] == len(self.dst_src_dict[self.nodes[i]])
)
# Ensure no repetitive sample neighbors.
self.assertTrue(
neighbors.shape[0] == paddle.unique(neighbors).shape[0]
)
# Ensure the correct sample neighbors.
in_neighbors = np.isin(
neighbors.numpy(), self.dst_src_dict[self.nodes[i]]
)
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype
)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
)
weight = paddle.static.data(
name="weight", shape=self.weight.shape, dtype=self.weight.dtype
)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
)
(
out_neighbors,
out_count,
) = paddle.geometric.weighted_sample_neighbors(
row, colptr, weight, nodes, sample_size=self.sample_size
)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(
feed={
'row': self.row,
'colptr': self.colptr,
'weight': self.weight,
'nodes': self.nodes,
},
fetch_list=[out_neighbors, out_count],
)
out_neighbors, out_count = ret
out_count_cumsum = np.cumsum(out_count)
out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1]
for neighbors, node, count in zip(
out_neighbors, self.nodes, out_count
):
self.assertTrue(
count == self.sample_size
or count == len(self.dst_src_dict[node])
)
self.assertTrue(
neighbors.shape[0] == np.unique(neighbors).shape[0]
)
in_neighbors = np.isin(neighbors, self.dst_src_dict[node])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_raise_errors(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
weight = paddle.to_tensor(self.weight)
nodes = paddle.to_tensor(self.nodes)
def check_eid_error():
paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
sample_size=self.sample_size,
return_eids=True,
)
self.assertRaises(ValueError, check_eid_error)
def test_sample_result_with_eids(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
weight = paddle.to_tensor(self.weight)
nodes = paddle.to_tensor(self.nodes)
eids = paddle.to_tensor(self.edges_id)
(
out_neighbors,
out_count,
out_eids,
) = paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
eids=eids,
sample_size=self.sample_size,
return_eids=True,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype
)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
)
weight = paddle.static.data(
name="weight", shape=self.weight.shape, dtype=self.weight.dtype
)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
)
eids = paddle.static.data(
name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype
)
(
out_neighbors,
out_count,
out_eids,
) = paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
sample_size=self.sample_size,
eids=eids,
return_eids=True,
)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(
feed={
'row': self.row,
'colptr': self.colptr,
'weight': self.weight,
'nodes': self.nodes,
'eids': self.edges_id,
},
fetch_list=[out_neighbors, out_count, out_eids],
)
if __name__ == "__main__":
unittest.main()
...@@ -22,6 +22,7 @@ from .math import segment_max # noqa: F401
from .reindex import reindex_graph # noqa: F401
from .reindex import reindex_heter_graph # noqa: F401
from .sampling import sample_neighbors # noqa: F401
from .sampling import weighted_sample_neighbors # noqa: F401
__all__ = [
'send_u_recv',
...@@ -34,4 +35,5 @@ __all__ = [
'reindex_graph',
'reindex_heter_graph',
'sample_neighbors',
'weighted_sample_neighbors',
]
...@@ -13,5 +13,6 @@
# limitations under the License.
from .neighbors import sample_neighbors # noqa: F401
from .neighbors import weighted_sample_neighbors # noqa: F401
__all__ = []
...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
__all__ = []
...@@ -170,3 +170,151 @@ def sample_neighbors(
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
def weighted_sample_neighbors(
row,
colptr,
edge_weight,
input_nodes,
sample_size=-1,
eids=None,
return_eids=False,
name=None,
):
"""
Graph Weighted Sample Neighbors API.
This API is mainly used in the Graph Learning domain, and its main purpose is to
provide a high-performance graph weighted-sampling method. For example, we take the
CSC (Compressed Sparse Column) format of the input graph edges as `row` and
`colptr`, so as to convert the graph data into a suitable format for sampling, and the
input `edge_weight` should also match the CSC format. Besides, `input_nodes` means
the nodes for which we need to sample neighbors, and `sample_size` means the number
of neighbors we want to sample for each input node. This API finally returns the
weighted sampled neighbors, and the probability of a neighbor being selected is
related to its weight: the higher the weight, the higher the probability.
Args:
row (Tensor): One of the components of the CSC format of the input graph, and
the shape should be [num_edges, 1] or [num_edges]. The available
data types are int32 and int64.
colptr (Tensor): One of the components of the CSC format of the input graph,
and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
The data type should be the same as `row`.
edge_weight (Tensor): The edge weight of the CSC format graph edges. And the shape
should be [num_edges, 1] or [num_edges]. The available data
type is float32.
input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
data type should be the same as `row`.
sample_size (int, optional): The number of neighbors we need to sample. Default value is -1,
which means returning all the neighbors of the input nodes.
eids (Tensor, optional): The eid information of the input graph. If return_eids is True,
then `eids` should not be None. The data type should be the
same as `row`. Default is None.
return_eids (bool, optional): Whether to return the eid information of the sampled edges. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- out_neighbors (Tensor), the sampled neighbors of the input nodes.
- out_count (Tensor), the number of sampled neighbors of each input node; the shape
is the same as `input_nodes`.
- out_eids (Tensor), only returned if `return_eids` is True, the eid information of the
sampled edges.
Examples:
.. code-block:: python
import paddle
# edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4),
# (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8)
row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
weight = [0.1, 0.5, 0.2, 0.5, 0.9, 1.9, 2.0, 2.1, 0.01, 0.9, 0.12, 0.59, 0.67]
nodes = [0, 8, 1, 2]
sample_size = 2
row = paddle.to_tensor(row, dtype="int64")
colptr = paddle.to_tensor(colptr, dtype="int64")
weight = paddle.to_tensor(weight, dtype="float32")
nodes = paddle.to_tensor(nodes, dtype="int64")
out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(row, colptr, weight, nodes, sample_size=sample_size)
"""
if return_eids:
if eids is None:
raise ValueError(
"`eids` should not be None if `return_eids` is True."
)
if in_dygraph_mode():
(
out_neighbors,
out_count,
out_eids,
) = _C_ops.weighted_sample_neighbors(
row,
colptr,
edge_weight,
input_nodes,
eids,
sample_size,
return_eids,
)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
check_variable_and_dtype(
row, "row", ("int32", "int64"), "weighted_sample_neighbors"
)
check_variable_and_dtype(
colptr, "colptr", ("int32", "int64"), "weighted_sample_neighbors"
)
check_variable_and_dtype(
edge_weight,
"edge_weight",
("float32"),
"weighted_sample_neighbors",
)
check_variable_and_dtype(
input_nodes,
"input_nodes",
("int32", "int64"),
"weighted_sample_neighbors",
)
if return_eids:
check_variable_and_dtype(
eids, "eids", ("int32", "int64"), "weighted_sample_neighbors"
)
helper = LayerHelper("weighted_sample_neighbors", **locals())
out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype)
out_count = helper.create_variable_for_type_inference(dtype=row.dtype)
out_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
helper.append_op(
type="weighted_sample_neighbors",
inputs={
"row": row,
"colptr": colptr,
"edge_weight": edge_weight,
"input_nodes": input_nodes,
"eids": eids if return_eids else None,
},
outputs={
"out_neighbors": out_neighbors,
"out_count": out_count,
"out_eids": out_eids,
},
attrs={
"sample_size": sample_size,
"return_eids": return_eids,
},
)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
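Since `out_neighbors` is returned flat, callers typically split it per node with a cumulative sum of `out_count`, as the unit test above does. A small helper sketch (name is mine):

import numpy as np

def split_by_count(out_neighbors, out_count):
    # Split the flat neighbor array into one array per input node.
    return np.split(np.asarray(out_neighbors), np.cumsum(out_count))[:-1]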