未验证 提交 6a8d98e0 编写于 作者: S Siming Dai 提交者: GitHub

Add weighted sample (#52013)

Add paddle.geometric.weighted_sample_neighbors API
上级 987fb2d8
......@@ -2144,6 +2144,15 @@
intermediate: warprnntgrad
backward : warprnnt_grad
- op : weighted_sample_neighbors
args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids)
infer_meta :
func : WeightedSampleNeighborsInferMeta
kernel :
func : weighted_sample_neighbors
optional: eids
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
......
......@@ -3249,5 +3249,52 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
const MetaTensor& x,
const MetaTensor& eids,
int sample_size,
bool return_eids,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids) {
// GSN: GraphSampleNeighbors
auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
"The %s should be 1D, when it is not 2D, but we get %d",
tensor_name,
dims.size()));
}
};
GSNShapeCheck(row.dims(), "row");
GSNShapeCheck(col_ptr.dims(), "colptr");
GSNShapeCheck(edge_weight.dims(), "edge_weight");
GSNShapeCheck(x.dims(), "input_nodes");
if (return_eids) {
GSNShapeCheck(eids.dims(), "eids");
out_eids->set_dims({-1});
out_eids->set_dtype(row.dtype());
}
out->set_dims({-1});
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
......@@ -550,6 +550,17 @@ void WarprnntInferMeta(const MetaTensor& input,
MetaTensor* loss,
MetaTensor* warpctcgrad);
void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& edge_weight,
const MetaTensor& x,
const MetaTensor& eids,
int sample_size,
bool return_eids,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids);
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/weighted_sample_neighbors_kernel.h"
#include <cmath>
#include <queue>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct GraphWeightedNode {
T node_id;
float weight_key;
T eid;
GraphWeightedNode() {
node_id = 0;
weight_key = 0;
eid = 0;
}
GraphWeightedNode(T node_id, float weight_key, T eid = 0)
: node_id(node_id), weight_key(weight_key), eid(eid) {}
void operator=(const GraphWeightedNode<T>& other) {
node_id = other.node_id;
weight_key = other.weight_key;
eid = other.eid;
}
friend bool operator>(const GraphWeightedNode<T>& n1,
const GraphWeightedNode<T>& n2) {
return n1.weight_key > n2.weight_key;
}
};
template <typename T>
void SampleWeightedNeighbors(
std::vector<T>& out_src, // NOLINT
const std::vector<float>& out_weight,
std::vector<T>& out_eids, // NOLINT
int sample_size,
std::mt19937& rng, // NOLINT
std::uniform_real_distribution<float>& dice_distribution, // NOLINT
bool return_eids) {
std::priority_queue<phi::GraphWeightedNode<T>,
std::vector<phi::GraphWeightedNode<T>>,
std::greater<phi::GraphWeightedNode<T>>>
min_heap;
for (size_t i = 0; i < out_src.size(); i++) {
float weight_key = log2(dice_distribution(rng)) * (1 / out_weight[i]);
if (static_cast<int>(i) < sample_size) {
if (!return_eids) {
min_heap.push(phi::GraphWeightedNode<T>(out_src[i], weight_key));
} else {
min_heap.push(
phi::GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]));
}
} else {
const phi::GraphWeightedNode<T>& small = min_heap.top();
phi::GraphWeightedNode<T> cmp;
if (!return_eids) {
cmp = GraphWeightedNode<T>(out_src[i], weight_key);
} else {
cmp = GraphWeightedNode<T>(out_src[i], weight_key, out_eids[i]);
}
bool flag = cmp > small;
if (flag) {
min_heap.pop();
min_heap.push(cmp);
}
}
}
int cnt = 0;
while (!min_heap.empty()) {
const phi::GraphWeightedNode<T>& tmp = min_heap.top();
out_src[cnt] = tmp.node_id;
if (return_eids) {
out_eids[cnt] = tmp.eid;
}
cnt++;
min_heap.pop();
}
}
template <typename T>
void SampleNeighbors(const T* row,
const T* col_ptr,
const float* edge_weight,
const T* eids,
const T* input,
std::vector<T>* output,
std::vector<int>* output_count,
std::vector<T>* output_eids,
int sample_size,
int bs,
bool return_eids) {
std::vector<std::vector<T>> out_src_vec;
std::vector<std::vector<float>> out_weight_vec;
std::vector<std::vector<T>> out_eids_vec;
// `sample_cumsum_sizes` record the start position and end position
// after sampling.
std::vector<int> sample_cumsum_sizes(bs + 1);
// `total_neighbors` the size of output after sample.
int total_neighbors = 0;
sample_cumsum_sizes[0] = total_neighbors;
for (int i = 0; i < bs; i++) {
T node = input[i];
int cap = col_ptr[node + 1] - col_ptr[node];
int k = cap > sample_size ? sample_size : cap;
total_neighbors += k;
sample_cumsum_sizes[i + 1] = total_neighbors;
std::vector<T> out_src;
out_src.resize(cap);
out_src_vec.emplace_back(out_src);
std::vector<float> out_weight;
out_weight.resize(cap);
out_weight_vec.emplace_back(out_weight);
if (return_eids) {
std::vector<T> out_eids;
out_eids.resize(cap);
out_eids_vec.emplace_back(out_eids);
}
}
output_count->resize(bs);
output->resize(total_neighbors);
if (return_eids) {
output_eids->resize(total_neighbors);
}
std::random_device rd;
std::mt19937 rng{rd()};
std::uniform_real_distribution<float> dice_distribution(0, 1);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Sample the neighbors in parallelism.
for (int i = 0; i < bs; i++) {
T node = input[i];
T begin = col_ptr[node], end = col_ptr[node + 1];
int cap = end - begin;
if (sample_size < cap) { // sample_size < neighbor_len
std::copy(row + begin, row + end, out_src_vec[i].begin());
std::copy(
edge_weight + begin, edge_weight + end, out_weight_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
SampleWeightedNeighbors(out_src_vec[i],
out_weight_vec[i],
out_eids_vec[i],
sample_size,
rng,
dice_distribution,
return_eids);
*(output_count->data() + i) = sample_size;
} else { // sample_size >= neighbor_len, directly copy
std::copy(row + begin, row + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
*(output_count->data() + i) = cap;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Copy the results parallelism
for (int i = 0; i < bs; i++) {
int k = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i];
std::copy(out_src_vec[i].begin(),
out_src_vec[i].begin() + k,
output->data() + sample_cumsum_sizes[i]);
if (return_eids) {
std::copy(out_eids_vec[i].begin(),
out_eids_vec[i].begin() + k,
output_eids->data() + sample_cumsum_sizes[i]);
}
}
}
template <typename T, typename Context>
void WeightedSampleNeighborsKernel(const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& edge_weight,
const DenseTensor& x,
const paddle::optional<DenseTensor>& eids,
int sample_size,
bool return_eids,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids) {
const T* row_data = row.data<T>();
const T* col_ptr_data = col_ptr.data<T>();
const float* weights_data = edge_weight.data<float>();
const T* x_data = x.data<T>();
const T* eids_data =
(eids.get_ptr() == nullptr ? nullptr : eids.get_ptr()->data<T>());
int bs = x.dims()[0];
std::vector<T> output;
std::vector<int> output_count;
std::vector<T> output_eids;
SampleNeighbors<T>(row_data,
col_ptr_data,
weights_data,
eids_data,
x_data,
&output,
&output_count,
&output_eids,
sample_size,
bs,
return_eids);
if (return_eids) {
out_eids->Resize({static_cast<int>(output_eids.size())});
T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
std::copy(output_eids.begin(), output_eids.end(), out_eids_data);
}
out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data);
out_count->Resize({bs});
int* out_count_data = dev_ctx.template Alloc<int>(out_count);
std::copy(output_count.begin(), output_count.end(), out_count_data);
}
} // namespace phi
PD_REGISTER_KERNEL(weighted_sample_neighbors,
CPU,
ALL_LAYOUT,
phi::WeightedSampleNeighborsKernel,
int,
int64_t) {}
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
namespace paddle {
namespace framework {
template<
typename KeyT,
int BLOCK_SIZE,
bool GREATER = true,
int RADIX_BITS = 8>
class BlockRadixTopKGlobalMemory {
static_assert(cub::PowerOfTwo<RADIX_BITS>::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)),
"RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)");
static_assert(cub::PowerOfTwo<BLOCK_SIZE>::VALUE, "BLOCK_SIZE should be power of 2");
using KeyTraits = cub::Traits<KeyT>;
using UnsignedBits = typename KeyTraits::UnsignedBits;
using BlockScanT = cub::BlockScan<int, BLOCK_SIZE>;
static constexpr int RADIX_SIZE = (1 << RADIX_BITS);
static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE;
using BinBlockLoad = cub::BlockLoad<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BinBlockStore = cub::BlockStore<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
struct _TempStorage {
typename BlockScanT::TempStorage scan_storage;
union {
typename BinBlockLoad::TempStorage load_storage;
typename BinBlockStore::TempStorage store_storage;
} load_store;
union {
int shared_bins[RADIX_SIZE];
};
int share_target_k;
int share_bucket_id;
};
public:
struct TempStorage : cub::Uninitialized<_TempStorage> {
};
__device__ __forceinline__ BlockRadixTopKGlobalMemory(TempStorage &temp_storage)
: temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){};
__device__ __forceinline__ void radixTopKGetThreshold(const KeyT *data, int k, int size, KeyT &topK, bool &topk_is_unique) {
assert(k < size && k > 0);
int target_k = k;
UnsignedBits key_pattern = 0;
int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS;
for (; digit_pos >= 0; digit_pos -= RADIX_BITS) {
UpdateSharedBins(data, size, digit_pos, key_pattern);
InclusiveScanBins();
UpdateTopK(digit_pos, target_k, key_pattern);
if (target_k == 0) break;
}
if (target_k == 0) {
key_pattern -= 1;
topk_is_unique = true;
} else {
topk_is_unique = false;
}
if (GREATER) key_pattern = ~key_pattern;
UnsignedBits topK_unsigned = KeyTraits::TwiddleOut(key_pattern);
topK = reinterpret_cast<KeyT &>(topK_unsigned);
}
private:
__device__ __forceinline__ void UpdateSharedBins(const KeyT *key, int size, int digit_pos, UnsignedBits key_pattern) {
for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) {
temp_storage_.shared_bins[id] = 0;
}
cub::CTA_SYNC();
UnsignedBits key_mask = ((UnsignedBits)(-1)) << ((UnsignedBits)(digit_pos + RADIX_BITS));
#pragma unroll
for (int idx = tid_; idx < size; idx += BLOCK_SIZE) {
KeyT key_data = key[idx];
UnsignedBits twiddled_data = KeyTraits::TwiddleIn(reinterpret_cast<UnsignedBits &>(key_data));
if (GREATER) twiddled_data = ~twiddled_data;
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(twiddled_data, digit_pos, RADIX_BITS);
if ((twiddled_data & key_mask) == (key_pattern & key_mask)) {
atomicAdd(&temp_storage_.shared_bins[digit_in_radix], 1);
}
}
cub::CTA_SYNC();
}
__device__ __forceinline__ void InclusiveScanBins() {
int items[SCAN_ITEMS_PER_THREAD];
BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0);
cub::CTA_SYNC();
BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items);
cub::CTA_SYNC();
BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE);
cub::CTA_SYNC();
}
__device__ __forceinline__ void UpdateTopK(int digit_pos,
int &target_k,
UnsignedBits &target_pattern) {
for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) {
int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1];
int cur_count = temp_storage_.shared_bins[idx];
if (prev_count <= target_k && cur_count > target_k) {
temp_storage_.share_target_k = target_k - prev_count;
temp_storage_.share_bucket_id = idx;
}
}
cub::CTA_SYNC();
target_k = temp_storage_.share_target_k;
int target_bucket_id = temp_storage_.share_bucket_id;
UnsignedBits key_segment = ((UnsignedBits) target_bucket_id) << ((UnsignedBits) digit_pos);
target_pattern |= key_segment;
}
_TempStorage &temp_storage_;
int tid_;
};
template<
typename KeyT,
int BLOCK_SIZE,
int ITEMS_PER_THREAD,
bool GREATER = true,
typename ValueT = cub::NullType,
int RADIX_BITS = 8>
class BlockRadixTopKRegister {
static_assert(cub::PowerOfTwo<RADIX_BITS>::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)),
"RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)");
static_assert(cub::PowerOfTwo<BLOCK_SIZE>::VALUE, "BLOCK_SIZE should be power of 2");
using KeyTraits = cub::Traits<KeyT>;
using UnsignedBits = typename KeyTraits::UnsignedBits;
using BlockScanT = cub::BlockScan<int, BLOCK_SIZE>;
static constexpr int RADIX_SIZE = (1 << RADIX_BITS);
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE;
using BinBlockLoad = cub::BlockLoad<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BinBlockStore = cub::BlockStore<int, BLOCK_SIZE, SCAN_ITEMS_PER_THREAD>;
using BlockExchangeKey = cub::BlockExchange<KeyT, BLOCK_SIZE, ITEMS_PER_THREAD>;
using BlockExchangeValue = cub::BlockExchange<ValueT, BLOCK_SIZE, ITEMS_PER_THREAD>;
using _ExchangeKeyTempStorage = typename BlockExchangeKey::TempStorage;
using _ExchangeValueTempStorage = typename BlockExchangeValue::TempStorage;
typedef union ExchangeKeyTempStorageType {
_ExchangeKeyTempStorage key_storage;
} ExchKeyTempStorageType;
typedef union ExchangeKeyValueTempStorageType {
_ExchangeKeyTempStorage key_storage;
_ExchangeValueTempStorage value_storage;
} ExchKeyValueTempStorageType;
using _ExchangeType = typename std::conditional<KEYS_ONLY, ExchKeyTempStorageType, ExchKeyValueTempStorageType>::type;
struct _TempStorage {
typename BlockScanT::TempStorage scan_storage;
union {
typename BinBlockLoad::TempStorage load_storage;
typename BinBlockStore::TempStorage store_storage;
} load_store;
union {
int shared_bins[RADIX_SIZE];
_ExchangeType exchange_storage;
};
int share_target_k;
int share_bucket_id;
int share_prev_count;
};
public:
struct TempStorage : cub::Uninitialized<_TempStorage> {
};
__device__ __forceinline__ BlockRadixTopKRegister(TempStorage &temp_storage)
: temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){};
__device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
const int k, const int valid_count) {
TopKGenRank(keys, k, valid_count);
int is_valid[ITEMS_PER_THREAD];
GenValidArray(is_valid, k);
BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid);
cub::CTA_SYNC();
}
__device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], ValueT (&values)[ITEMS_PER_THREAD],
const int k, const int valid_count) {
TopKGenRank(keys, k, valid_count);
int is_valid[ITEMS_PER_THREAD];
GenValidArray(is_valid, k);
BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged(keys, keys, ranks_, is_valid);
cub::CTA_SYNC();
BlockExchangeValue{temp_storage_.exchange_storage.value_storage}.ScatterToStripedFlagged(values, values, ranks_, is_valid);
cub::CTA_SYNC();
}
private:
__device__ __forceinline__ void TopKGenRank(KeyT (&keys)[ITEMS_PER_THREAD], const int k, const int valid_count) {
assert(k <= BLOCK_SIZE * ITEMS_PER_THREAD);
assert(k <= valid_count);
if (k == valid_count) return;
UnsignedBits(&unsigned_keys)[ITEMS_PER_THREAD] = reinterpret_cast<UnsignedBits(&)[ITEMS_PER_THREAD]>(keys);
search_mask_ = 0;
top_k_mask_ = 0;
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
int idx = KEY * BLOCK_SIZE + tid_;
unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY];
if (idx < valid_count) search_mask_ |= (1U << KEY);
}
int target_k = k;
int prefix_k = 0;
for (int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) {
UpdateSharedBins(unsigned_keys, digit_pos, prefix_k);
InclusiveScanBins();
UpdateTopK(unsigned_keys, digit_pos, target_k, prefix_k, digit_pos == 0);
if (target_k == 0) break;
}
#pragma unroll
for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY];
unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
}
}
__device__ __forceinline__ void GenValidArray(int (&is_valid)[ITEMS_PER_THREAD], int k) {
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if ((top_k_mask_ & (1U << KEY)) && ranks_[KEY] < k) {
is_valid[KEY] = 1;
} else {
is_valid[KEY] = 0;
}
}
}
__device__ __forceinline__ void UpdateSharedBins(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
int digit_pos, int prefix_k) {
for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) {
temp_storage_.shared_bins[id] = 0;
}
cub::CTA_SYNC();
//#define USE_MATCH
#ifdef USE_MATCH
int lane_mask = cub::LaneMaskLt();
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
bool is_search = search_mask_ & (1U << KEY);
int bucket_idx = -1;
if (is_search) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
bucket_idx = (int) digit_in_radix;
}
int warp_match_mask = __match_any_sync(0xffffffff, bucket_idx);
int same_count = __popc(warp_match_mask);
int idx_in_same_bucket = __popc(warp_match_mask & lane_mask);
int same_bucket_root_lane = __ffs(warp_match_mask) - 1;
int same_bucket_start_idx;
if (idx_in_same_bucket == 0 && is_search) {
same_bucket_start_idx = atomicAdd(&temp_storage_.shared_bins[bucket_idx], same_count);
}
same_bucket_start_idx = __shfl_sync(0xffffffff, same_bucket_start_idx, same_bucket_root_lane, 32);
if (is_search) {
ranks_[KEY] = same_bucket_start_idx + idx_in_same_bucket + prefix_k;
}
}
#else
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
bool is_search = search_mask_ & (1U << KEY);
int bucket_idx = -1;
if (is_search) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
bucket_idx = (int) digit_in_radix;
ranks_[KEY] = atomicAdd(&temp_storage_.shared_bins[bucket_idx], 1) + prefix_k;
}
}
#endif
cub::CTA_SYNC();
}
__device__ __forceinline__ void InclusiveScanBins() {
int items[SCAN_ITEMS_PER_THREAD];
BinBlockLoad(temp_storage_.load_store.load_storage).Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0);
cub::CTA_SYNC();
BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items);
cub::CTA_SYNC();
BinBlockStore(temp_storage_.load_store.store_storage).Store(temp_storage_.shared_bins, items, RADIX_SIZE);
cub::CTA_SYNC();
}
__device__ __forceinline__ void UpdateTopK(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
int digit_pos,
int &target_k,
int &prefix_k,
bool mark_equal) {
for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) {
int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1];
int cur_count = temp_storage_.shared_bins[idx];
if (prev_count <= target_k && cur_count > target_k) {
temp_storage_.share_target_k = target_k - prev_count;
temp_storage_.share_bucket_id = idx;
temp_storage_.share_prev_count = prev_count;
}
}
cub::CTA_SYNC();
target_k = temp_storage_.share_target_k;
prefix_k += temp_storage_.share_prev_count;
int target_bucket_id = temp_storage_.share_bucket_id;
#pragma unroll
for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) {
if (search_mask_ & (1U << KEY)) {
UnsignedBits digit_in_radix = cub::BFE<UnsignedBits>(unsigned_keys[KEY], digit_pos, RADIX_BITS);
if (digit_in_radix < target_bucket_id) {
top_k_mask_ |= (1U << KEY);
search_mask_ &= ~(1U << KEY);
} else if (digit_in_radix > target_bucket_id) {
search_mask_ &= ~(1U << KEY);
} else {
if (mark_equal) top_k_mask_ |= (1U << KEY);
}
if (digit_in_radix <= target_bucket_id) {
int prev_count = (digit_in_radix == 0) ? 0 : temp_storage_.shared_bins[digit_in_radix - 1];
ranks_[KEY] += prev_count;
}
}
}
cub::CTA_SYNC();
}
_TempStorage &temp_storage_;
int tid_;
int ranks_[ITEMS_PER_THREAD];
unsigned int search_mask_;
unsigned int top_k_mask_;
};
}; // end namespace framework
}; // end namespace paddle
#endif
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef __NVCC__
#include <cuda_runtime_api.h> // NOLINT
#endif
class RandomNumGen {
public:
__host__ __device__ __forceinline__ RandomNumGen(int gid, unsigned long long seed) {
next_random = seed + gid;
next_random ^= next_random >> 33U;
next_random *= 0xff51afd7ed558ccdUL;
next_random ^= next_random >> 33U;
next_random *= 0xc4ceb9fe1a85ec53UL;
next_random ^= next_random >> 33U;
}
__host__ __device__ __forceinline__ ~RandomNumGen() = default;
__host__ __device__ __forceinline__ void SetSeed(int seed) {
next_random = seed;
NextValue();
}
__host__ __device__ __forceinline__ unsigned long long SaveState() const {
return next_random;
}
__host__ __device__ __forceinline__ void LoadState(unsigned long long state) {
next_random = state;
}
__host__ __device__ __forceinline__ int Random() {
int ret_value = (int) (next_random & 0x7fffffffULL);
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ int RandomMod(int mod) {
return Random() % mod;
}
__host__ __device__ __forceinline__ int64_t Random64() {
int64_t ret_value = (next_random & 0x7FFFFFFFFFFFFFFFLL);
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ int64_t RandomMod64(int64_t mod) {
return Random64() % mod;
}
__host__ __device__ __forceinline__ float RandomUniformFloat(float max = 1.0f, float min = 0.0f) {
int value = (int) (next_random & 0xffffff);
auto ret_value = (float) value;
ret_value /= 0xffffffL;
ret_value *= (max - min);
ret_value += min;
NextValue();
return ret_value;
}
__host__ __device__ __forceinline__ bool RandomBool(float true_prob) {
float value = RandomUniformFloat();
return value <= true_prob;
}
__host__ __device__ __forceinline__ void NextValue() {
//next_random = next_random * (unsigned long long)0xc4ceb9fe1a85ec53UL + generator_id;
//next_random = next_random * (unsigned long long)25214903917ULL + 11;
next_random = next_random * (unsigned long long) 13173779397737131ULL + 1023456798976543201ULL;
}
private:
unsigned long long next_random = 1;
};
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphWeightedSampleNeighborsKernel(
const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& edge_weight,
const DenseTensor& x,
const paddle::optional<DenseTensor>& eids,
int sample_size,
bool return_eids,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids);
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestWeightedSampleNeighbors(unittest.TestCase):
def setUp(self):
num_nodes = 20
edges = np.random.randint(num_nodes, size=(100, 2))
edges = np.unique(edges, axis=0)
self.edges_id = np.arange(0, len(edges)).astype("int64")
sorted_edges = edges[np.argsort(edges[:, 1])]
# Calculate dst index cumsum counts, also means colptr
dst_count = np.zeros(num_nodes)
dst_src_dict = {}
for dst in range(0, num_nodes):
true_index = sorted_edges[:, 1] == dst
dst_count[dst] = np.sum(true_index)
dst_src_dict[dst] = sorted_edges[:, 0][true_index]
dst_count = dst_count.astype("int64")
colptr = np.cumsum(dst_count)
colptr = np.insert(colptr, 0, 0)
self.row = sorted_edges[:, 0].astype("int64")
self.colptr = colptr.astype("int64")
self.nodes = np.unique(np.random.randint(num_nodes, size=5)).astype(
"int64"
)
self.weight = np.ones(self.row.shape[0]).astype("float32")
self.sample_size = 5
self.dst_src_dict = dst_src_dict
def test_sample_result(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
weight = paddle.to_tensor(self.weight)
out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(
row, colptr, weight, nodes, sample_size=self.sample_size
)
out_count_cumsum = paddle.cumsum(out_count)
for i in range(len(out_count)):
if i == 0:
neighbors = out_neighbors[0 : out_count_cumsum[i]]
else:
neighbors = out_neighbors[
out_count_cumsum[i - 1] : out_count_cumsum[i]
]
# Ensure the correct sample size.
self.assertTrue(
out_count[i] == self.sample_size
or out_count[i] == len(self.dst_src_dict[self.nodes[i]])
)
# Ensure no repetitive sample neighbors.
self.assertTrue(
neighbors.shape[0] == paddle.unique(neighbors).shape[0]
)
# Ensure the correct sample neighbors.
in_neighbors = np.isin(
neighbors.numpy(), self.dst_src_dict[self.nodes[i]]
)
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype
)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
)
weight = paddle.static.data(
name="weight", shape=self.weight.shape, dtype=self.weight.dtype
)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
)
(
out_neighbors,
out_count,
) = paddle.geometric.weighted_sample_neighbors(
row, colptr, weight, nodes, sample_size=self.sample_size
)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(
feed={
'row': self.row,
'colptr': self.colptr,
'weight': self.weight,
'nodes': self.nodes,
},
fetch_list=[out_neighbors, out_count],
)
out_neighbors, out_count = ret
out_count_cumsum = np.cumsum(out_count)
out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1]
for neighbors, node, count in zip(
out_neighbors, self.nodes, out_count
):
self.assertTrue(
count == self.sample_size
or count == len(self.dst_src_dict[node])
)
self.assertTrue(
neighbors.shape[0] == np.unique(neighbors).shape[0]
)
in_neighbors = np.isin(neighbors, self.dst_src_dict[node])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_raise_errors(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
weight = paddle.to_tensor(self.weight)
nodes = paddle.to_tensor(self.nodes)
def check_eid_error():
paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
sample_size=self.sample_size,
return_eids=True,
)
self.assertRaises(ValueError, check_eid_error)
def test_sample_result_with_eids(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
weight = paddle.to_tensor(self.weight)
nodes = paddle.to_tensor(self.nodes)
eids = paddle.to_tensor(self.edges_id)
(
out_neighbors,
out_count,
out_eids,
) = paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
eids=eids,
sample_size=self.sample_size,
return_eids=True,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype
)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype
)
weight = paddle.static.data(
name="weight", shape=self.weight.shape, dtype=self.weight.dtype
)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype
)
eids = paddle.static.data(
name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype
)
(
out_neighbors,
out_count,
out_eids,
) = paddle.geometric.weighted_sample_neighbors(
row,
colptr,
weight,
nodes,
sample_size=self.sample_size,
eids=eids,
return_eids=True,
)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(
feed={
'row': self.row,
'colptr': self.colptr,
'weight': self.weight,
'nodes': self.nodes,
'eids': self.edges_id,
},
fetch_list=[out_neighbors, out_count, out_eids],
)
if __name__ == "__main__":
unittest.main()
......@@ -22,6 +22,7 @@ from .math import segment_max # noqa: F401
from .reindex import reindex_graph # noqa: F401
from .reindex import reindex_heter_graph # noqa: F401
from .sampling import sample_neighbors # noqa: F401
from .sampling import weighted_sample_neighbors # noqa: F401
__all__ = [
'send_u_recv',
......@@ -34,4 +35,5 @@ __all__ = [
'reindex_graph',
'reindex_heter_graph',
'sample_neighbors',
'weighted_sample_neighbors',
]
......@@ -13,5 +13,6 @@
# limitations under the License.
from .neighbors import sample_neighbors # noqa: F401
from .neighbors import weighted_sample_neighbors # noqa: F401
__all__ = []
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _legacy_C_ops
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
__all__ = []
......@@ -170,3 +170,151 @@ def sample_neighbors(
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
def weighted_sample_neighbors(
row,
colptr,
edge_weight,
input_nodes,
sample_size=-1,
eids=None,
return_eids=False,
name=None,
):
"""
Graph Weighted Sample Neighbors API.
This API is mainly used in Graph Learning domain, and the main purpose is to
provide high performance of graph weighted-sampling method. For example, we get the
CSC(Compressed Sparse Column) format of the input graph edges as `row` and
`colptr`, so as to convert graph data into a suitable format for sampling, and the
input `edge_weight` should also match the CSC format. Besides, `input_nodes` means
the nodes we need to sample neighbors, and `sample_sizes` means the number of neighbors
and number of layers we want to sample. This API will finally return the weighted sampled
neighbors, and the probability of being selected as a neighbor is related to its weight,
with higher weight and higher probability.
Args:
row (Tensor): One of the components of the CSC format of the input graph, and
the shape should be [num_edges, 1] or [num_edges]. The available
data type is int32, int64.
colptr (Tensor): One of the components of the CSC format of the input graph,
and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
The data type should be the same with `row`.
edge_weight (Tensor): The edge weight of the CSC format graph edges. And the shape
should be [num_edges, 1] or [num_edges]. The available data
type is float32.
input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
data type should be the same with `row`.
sample_size (int, optional): The number of neighbors we need to sample. Default value is -1,
which means returning all the neighbors of the input nodes.
eids (Tensor, optional): The eid information of the input graph. If return_eids is True,
then `eids` should not be None. The data type should be the
same with `row`. Default is None.
return_eids (bool, optional): Whether to return eid information of sample edges. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- out_neighbors (Tensor), the sample neighbors of the input nodes.
- out_count (Tensor), the number of sampling neighbors of each input node, and the shape
should be the same with `input_nodes`.
- out_eids (Tensor), if `return_eids` is True, we will return the eid information of the
sample edges.
Examples:
.. code-block:: python
import paddle
# edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4),
# (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8)
row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
weight = [0.1, 0.5, 0.2, 0.5, 0.9, 1.9, 2.0, 2.1, 0.01, 0.9, 0,12, 0.59, 0.67]
nodes = [0, 8, 1, 2]
sample_size = 2
row = paddle.to_tensor(row, dtype="int64")
colptr = paddle.to_tensor(colptr, dtype="int64")
weight = paddle.to_tensor(weight, dtype="float32")
nodes = paddle.to_tensor(nodes, dtype="int64")
out_neighbors, out_count = paddle.geometric.weighted_sample_neighbors(row, colptr, weight, nodes, sample_size=sample_size)
"""
if return_eids:
if eids is None:
raise ValueError(
"`eids` should not be None if `return_eids` is True."
)
if in_dygraph_mode():
(
out_neighbors,
out_count,
out_eids,
) = _C_ops.weighted_sample_neighbors(
row,
colptr,
edge_weight,
input_nodes,
eids,
sample_size,
return_eids,
)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
check_variable_and_dtype(
row, "row", ("int32", "int64"), "weighted_sample_neighbors"
)
check_variable_and_dtype(
colptr, "colptr", ("int32", "int64"), "weighted_sample_neighbors"
)
check_variable_and_dtype(
edge_weight,
"edge_weight",
("float32"),
"weighted_sample_neighbors",
)
check_variable_and_dtype(
input_nodes,
"input_nodes",
("int32", "int64"),
"weighted_sample_neighbors",
)
if return_eids:
check_variable_and_dtype(
eids, "eids", ("int32", "int64"), "weighted_sample_neighbors"
)
helper = LayerHelper("weighted_sample_neighbors", **locals())
out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype)
out_count = helper.create_variable_for_type_inference(dtype=row.dtype)
out_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
helper.append_op(
type="weighted_sample_neighbors",
inputs={
"row": row,
"colptr": colptr,
"edge_weight": edge_weight,
"input_nodes": input_nodes,
"eids": eids if return_eids else None,
},
outputs={
"out_neighbors": out_neighbors,
"out_count": out_count,
"out_eids": out_eids,
},
attrs={
"sample_size": sample_size,
"return_eids": return_eids,
},
)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册