Unverified commit b0398c8e, authored by Siming Dai, committed by GitHub

Add graph apis (#40809)

* Add graph_reindex API

* add graph_sample_neighbors api

* Add buffer

* delete VLOG

* delete thrust::copy for output

* add ShareDataWith

* delete graph_reindex hashtable output

* add graph_reindex dispensable

* add reindex unittest, move memset to cuda kernel, change api

* fix conflict

* add reindex buffer for gpu version note

* fix conflicts for op_func_generator

* Add fisher_yates sampling, add dispensable, change infermeta

* add dtype for edge_id

* fix rocm ci and static check ci

* add unittest

* fix unittest

* fix unittest

* fix bug
Parent 36f97cdc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class GraphReindexOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class GraphReindexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The destination nodes of the input graph.");
AddInput("Neighbors", "The neighbor nodes of the destination nodes `X`.");
AddInput("Count", "The number of neighbor nodes of each destination node.");
// Note(daisiming): If using the buffer hashtable, the number of nodes of the
// input graph must be no larger than the maximum value of int32.
AddInput("HashTable_Value",
"One of the buffer tensor of hashtable for reindex")
.AsDispensable();
AddInput("HashTable_Index",
"One of the buffer tensor of hashtable for reindex")
.AsDispensable();
AddAttr<bool>("flag_buffer_hashtable",
"Define whether using the buffer hashtable.")
.SetDefault(false);
AddOutput("Reindex_Src",
"The source node index of graph edges after reindex.");
AddOutput("Reindex_Dst",
"The destination node index of graph edges after reindex.");
AddOutput("Out_Nodes", "The original index of graph nodes before reindex");
AddComment(R"DOC(
Graph Reindex operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_reindex, GraphReindexInferShapeFunctor,
PD_INFER_META(phi::GraphReindexInferMeta));
REGISTER_OPERATOR(
graph_reindex, ops::GraphReindexOP, ops::GraphReindexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GraphReindexInferShapeFunctor);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class GraphSampleNeighborsOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
ctx.device_context());
}
};
class GraphSampleNeighborsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Row",
"One of the components of the CSC format of the input graph.");
AddInput("Col_Ptr",
"One of the components of the CSC format of the input graph.");
AddInput("X", "The input center nodes index tensor.");
AddInput("Eids", "The edge ids of the input graph.").AsDispensable();
AddInput("Perm_Buffer", "Permutation buffer for fisher-yates sampling.")
.AsDispensable();
AddOutput("Out", "The neighbors of input nodes X after sampling.");
AddOutput("Out_Count",
"The number of sample neighbors of input nodes respectively.");
AddOutput("Out_Eids", "The eids of the sample edges");
AddAttr<int>(
    "sample_size",
    "The sample size of the graph sample neighbors method. "
    "The default value is -1, which means returning all neighbors of nodes.")
    .SetDefault(-1);
AddAttr<bool>("return_eids",
"Whether to return the eid of the sample edges.")
.SetDefault(false);
AddAttr<bool>("flag_perm_buffer",
"Using the permutation for fisher-yates sampling in GPU"
"Set default value as false, means not using it.")
.SetDefault(false);
AddComment(R"DOC(
Graph Learning Sample Neighbors operator, for the GraphSAGE sampling method.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(graph_sample_neighbors,
GraphSampleNeighborsInferShapeFunctor,
PD_INFER_META(phi::GraphSampleNeighborsInferMeta));
REGISTER_OPERATOR(
graph_sample_neighbors, ops::GraphSampleNeighborsOP,
ops::GraphSampleNeighborsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
GraphSampleNeighborsInferShapeFunctor);
@@ -105,6 +105,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"linear_chain_crf", {"Emission", "Transition", "Label", "Length"}},
{"crf_decoding", {"Emission", "Transition", "Label", "Length"}},
{"chunk_eval", {"Inference", "Label", "SeqLength"}},
{"graph_reindex",
{"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"}},
{"graph_sample_neighbors", {"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
@@ -1775,6 +1775,103 @@ void WhereInferMeta(const MetaTensor& condition,
out->share_meta(x);
}
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
paddle::optional<const MetaTensor&> hashtable_value,
paddle::optional<const MetaTensor&> hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes) {
auto GraphReindexShapeCheck = [](const phi::DDim& dims,
std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
    "The %s should be 1D when it is not 2D, but got %dD.",
    tensor_name,
    dims.size()));
}
};
GraphReindexShapeCheck(x.dims(), "X");
GraphReindexShapeCheck(neighbors.dims(), "Neighbors");
GraphReindexShapeCheck(count.dims(), "Count");
if (flag_buffer_hashtable) {
GraphReindexShapeCheck(hashtable_value->dims(), "HashTable_Value");
GraphReindexShapeCheck(hashtable_index->dims(), "HashTable_Index");
}
reindex_src->set_dims({-1});
reindex_src->set_dtype(neighbors.dtype());
reindex_dst->set_dims({-1});
reindex_dst->set_dtype(neighbors.dtype());
out_nodes->set_dims({-1});
out_nodes->set_dtype(x.dtype());
}
void GraphSampleNeighborsInferMeta(
const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
paddle::optional<const MetaTensor&> eids,
paddle::optional<const MetaTensor&> perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids) {
// GSN: GraphSampleNeighbors
auto GSNShapeCheck = [](const phi::DDim& dims, std::string tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(
dims[1],
1,
phi::errors::InvalidArgument("The last dim of %s should be 1 when it "
"is 2D, but we get %d",
tensor_name,
dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(),
1,
phi::errors::InvalidArgument(
    "The %s should be 1D when it is not 2D, but got %dD.",
    tensor_name,
    dims.size()));
}
};
GSNShapeCheck(row.dims(), "Row");
GSNShapeCheck(col_ptr.dims(), "Col_Ptr");
GSNShapeCheck(x.dims(), "X");
if (return_eids) {
GSNShapeCheck(eids->dims(), "Eids");
out_eids->set_dims({-1});
out_eids->set_dtype(row.dtype());
}
if (flag_perm_buffer) {
GSNShapeCheck(perm_buffer->dims(), "Perm_Buffer");
}
out->set_dims({-1});
out->set_dtype(row.dtype());
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box,
const MetaTensor& gt_label,
......
@@ -265,6 +265,29 @@ void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& y,
MetaTensor* out);
void GraphReindexInferMeta(const MetaTensor& x,
const MetaTensor& neighbors,
const MetaTensor& count,
paddle::optional<const MetaTensor&> hashtable_value,
paddle::optional<const MetaTensor&> hashtable_index,
bool flag_buffer_hashtable,
MetaTensor* reindex_src,
MetaTensor* reindex_dst,
MetaTensor* out_nodes);
void GraphSampleNeighborsInferMeta(
const MetaTensor& row,
const MetaTensor& col_ptr,
const MetaTensor& x,
paddle::optional<const MetaTensor&> eids,
paddle::optional<const MetaTensor&> perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
MetaTensor* out,
MetaTensor* out_count,
MetaTensor* out_eids);
void Yolov3LossInferMeta(const MetaTensor& x,
const MetaTensor& gt_box,
const MetaTensor& gt_label,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "paddle/phi/kernels/graph_reindex_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& neighbors,
const DenseTensor& count,
paddle::optional<const DenseTensor&> hashtable_value,
paddle::optional<const DenseTensor&> hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes) {
const T* x_data = x.data<T>();
const T* neighbors_data = neighbors.data<T>();
const int* count_data = count.data<int>();
const int bs = x.dims()[0];
const int num_edges = neighbors.dims()[0];
std::unordered_map<T, T> node_map;
std::vector<T> unique_nodes;
int reindex_id = 0;
for (int i = 0; i < bs; i++) {
T node = x_data[i];
unique_nodes.emplace_back(node);
node_map[node] = reindex_id++;
}
// Reindex Src
std::vector<T> src(num_edges);
std::vector<T> dst(num_edges);
for (int i = 0; i < num_edges; i++) {
T node = neighbors_data[i];
if (node_map.find(node) == node_map.end()) {
unique_nodes.emplace_back(node);
node_map[node] = reindex_id++;
}
src[i] = node_map[node];
}
// Reindex Dst
int cnt = 0;
for (int i = 0; i < bs; i++) {
for (int j = 0; j < count_data[i]; j++) {
T node = x_data[i];
dst[cnt++] = node_map[node];
}
}
reindex_src->Resize({num_edges});
T* reindex_src_data = dev_ctx.template Alloc<T>(reindex_src);
std::copy(src.begin(), src.end(), reindex_src_data);
reindex_dst->Resize({num_edges});
T* reindex_dst_data = dev_ctx.template Alloc<T>(reindex_dst);
std::copy(dst.begin(), dst.end(), reindex_dst_data);
out_nodes->Resize({static_cast<int>(unique_nodes.size())});
T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
std::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data);
}
} // namespace phi
PD_REGISTER_KERNEL(
graph_reindex, CPU, ALL_LAYOUT, phi::GraphReindexKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <iterator>
#include <limits>
#include <random>
#include <vector>
#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
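// Randomly pick `num_samples` elements from [begin, end) by repeatedly
// swapping a randomly chosen remaining element to the front, i.e. a partial
// Fisher-Yates shuffle over the neighbor list.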
template <class bidiiter>
void SampleUniqueNeighbors(
bidiiter begin,
bidiiter end,
int num_samples,
std::mt19937& rng,
std::uniform_int_distribution<int>& dice_distribution) {
int left_num = std::distance(begin, end);
for (int i = 0; i < num_samples; i++) {
bidiiter r = begin;
int random_step = dice_distribution(rng) % left_num;
std::advance(r, random_step);
std::swap(*begin, *r);
++begin;
--left_num;
}
}
template <typename T>
void SampleNeighbors(const T* row,
const T* col_ptr,
const T* input,
std::vector<T>* output,
std::vector<int>* output_count,
int sample_size,
int bs) {
// Allocate the memory of output
// Collect the neighbors size
std::vector<std::vector<T>> out_src_vec;
// `sample_cumsum_sizes` records the start and end position of each node's
// samples in the output.
std::vector<int> sample_cumsum_sizes(bs + 1);
// `total_neighbors` is the total size of the output after sampling.
int total_neighbors = 0;
sample_cumsum_sizes[0] = total_neighbors;
for (int i = 0; i < bs; i++) {
T node = input[i];
int cap = col_ptr[node + 1] - col_ptr[node];
int k = cap > sample_size ? sample_size : cap;
total_neighbors += k;
sample_cumsum_sizes[i + 1] = total_neighbors;
std::vector<T> out_src;
out_src.resize(cap);
out_src_vec.emplace_back(out_src);
}
output_count->resize(bs);
output->resize(total_neighbors);
std::random_device rd;
std::mt19937 rng{rd()};
std::uniform_int_distribution<int> dice_distribution(
0, std::numeric_limits<int>::max());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Sample the neighbors in parallel.
for (int i = 0; i < bs; i++) {
T node = input[i];
T begin = col_ptr[node], end = col_ptr[node + 1];
int cap = end - begin;
if (sample_size < cap) {
std::copy(row + begin, row + end, out_src_vec[i].begin());
// TODO(daisiming): Check whether this is correct.
SampleUniqueNeighbors(out_src_vec[i].begin(),
out_src_vec[i].end(),
sample_size,
rng,
dice_distribution);
*(output_count->data() + i) = sample_size;
} else {
std::copy(row + begin, row + end, out_src_vec[i].begin());
*(output_count->data() + i) = cap;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Copy the results in parallel.
for (int i = 0; i < bs; i++) {
int k = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i];
std::copy(out_src_vec[i].begin(),
out_src_vec[i].begin() + k,
output->data() + sample_cumsum_sizes[i]);
}
}
template <typename T, typename Context>
void GraphSampleNeighborsKernel(
const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& x,
paddle::optional<const DenseTensor&> eids,
paddle::optional<const DenseTensor&> perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids) {
const T* row_data = row.data<T>();
const T* col_ptr_data = col_ptr.data<T>();
const T* x_data = x.data<T>();
int bs = x.dims()[0];
std::vector<T> output;
std::vector<int> output_count;
SampleNeighbors<T>(
row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs);
out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data);
out_count->Resize({bs});
int* out_count_data = dev_ctx.template Alloc<int>(out_count);
std::copy(output_count.begin(), output_count.end(), out_count_data);
}
} // namespace phi
PD_REGISTER_KERNEL(graph_sample_neighbors,
CPU,
ALL_LAYOUT,
phi::GraphSampleNeighborsKernel,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/graph_reindex_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
namespace phi {
template <typename T>
inline __device__ size_t Hash(T id, int64_t size) {
return id % size;
}
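// Try to claim slot `pos` for `id` with an atomicCAS. If the slot was empty
// (-1) or already holds `id`, record the smallest inserting index in
// `key_index` and return true; otherwise return false so the caller can
// probe another slot.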
template <typename T>
inline __device__ bool AttemptInsert(
size_t pos, T id, int index, T* keys, int* key_index) {
if (sizeof(T) == 4) {
const T key = atomicCAS(reinterpret_cast<unsigned int*>(&keys[pos]),
static_cast<unsigned int>(-1),
static_cast<unsigned int>(id));
if (key == -1 || key == id) {
atomicMin(reinterpret_cast<unsigned int*>(&key_index[pos]), // NOLINT
static_cast<unsigned int>(index)); // NOLINT
return true;
} else {
return false;
}
} else if (sizeof(T) == 8) {
const T key = atomicCAS(
reinterpret_cast<unsigned long long int*>(&keys[pos]), // NOLINT
static_cast<unsigned long long int>(-1), // NOLINT
static_cast<unsigned long long int>(id)); // NOLINT
if (key == -1 || key == id) {
atomicMin(reinterpret_cast<unsigned int*>(&key_index[pos]), // NOLINT
static_cast<unsigned int>(index)); // NOLINT
return true;
} else {
return false;
}
}
}
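// Open-addressing insert: start from Hash(id) and keep probing with an
// increasing step until AttemptInsert claims a slot.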
template <typename T>
inline __device__ void Insert(
T id, int index, int64_t size, T* keys, int* key_index) {
size_t pos = Hash(id, size);
size_t delta = 1;
while (!AttemptInsert(pos, id, index, keys, key_index)) {
pos = Hash(pos + delta, size);
delta += 1;
}
}
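// Look up the slot that holds `id`, following the same probing sequence
// used by Insert.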
template <typename T>
inline __device__ int64_t Search(T id, const T* keys, int64_t size) {
int64_t pos = Hash(id, size);
int64_t delta = 1;
while (keys[pos] != id) {
pos = Hash(pos + delta, size);
delta += 1;
}
return pos;
}
template <typename T>
__global__ void BuildHashTable(
const T* items, int num_items, int64_t size, T* keys, int* key_index) {
CUDA_KERNEL_LOOP(index, num_items) {
Insert(items[index], index, size, keys, key_index);
}
}
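// Buffer-hashtable variant: node ids index `key_index` directly, which is why
// the node ids must fit within the buffer (see the int32 note in the op maker).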
template <typename T>
__global__ void BuildHashTable(const T* items, int num_items, int* key_index) {
CUDA_KERNEL_LOOP(index, num_items) {
atomicMin(
reinterpret_cast<unsigned int*>(&key_index[items[index]]), // NOLINT
static_cast<unsigned int>(index)); // NOLINT
}
}
template <typename T>
__global__ void ResetHashTable(const T* items,
int num_items,
int* key_index,
int* values) {
CUDA_KERNEL_LOOP(index, num_items) {
key_index[items[index]] = -1;
values[items[index]] = -1;
}
}
template <typename T>
__global__ void GetItemIndexCount(const T* items,
int* item_count,
int num_items,
int64_t size,
const T* keys,
int* key_index) {
CUDA_KERNEL_LOOP(i, num_items) {
int64_t pos = Search(items[i], keys, size);
if (key_index[pos] == i) {
item_count[i] = 1;
}
}
}
template <typename T>
__global__ void GetItemIndexCount(const T* items,
int* item_count,
int num_items,
int* key_index) {
CUDA_KERNEL_LOOP(i, num_items) {
if (key_index[items[i]] == i) {
item_count[i] = 1;
}
}
}
template <typename T>
__global__ void FillUniqueItems(const T* items,
int num_items,
int64_t size,
T* unique_items,
const int* item_count,
const T* keys,
int* values,
int* key_index) {
CUDA_KERNEL_LOOP(i, num_items) {
int64_t pos = Search(items[i], keys, size);
if (key_index[pos] == i) {
values[pos] = item_count[i];
unique_items[item_count[i]] = items[i];
}
}
}
template <typename T>
__global__ void FillUniqueItems(const T* items,
int num_items,
T* unique_items,
const int* item_count,
int* values,
int* key_index) {
CUDA_KERNEL_LOOP(i, num_items) {
if (key_index[items[i]] == i) {
values[items[i]] = item_count[i];
unique_items[item_count[i]] = items[i];
}
}
}
template <typename T>
__global__ void ReindexSrcOutput(T* src_output,
int num_items,
int64_t size,
const T* keys,
const int* values) {
CUDA_KERNEL_LOOP(i, num_items) {
int64_t pos = Search(src_output[i], keys, size);
src_output[i] = values[pos];
}
}
template <typename T>
__global__ void ReindexSrcOutput(T* src_output,
int num_items,
const int* values) {
CUDA_KERNEL_LOOP(i, num_items) { src_output[i] = values[src_output[i]]; }
}
template <typename T>
__global__ void ReindexInputNodes(const T* nodes,
int num_items,
T* reindex_nodes,
int64_t size,
const T* keys,
const int* values) {
CUDA_KERNEL_LOOP(i, num_items) {
int64_t pos = Search(nodes[i], keys, size);
reindex_nodes[i] = values[pos];
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include "paddle/phi/kernels/gpu/graph_reindex_funcs.h"
#include "paddle/phi/kernels/graph_reindex_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
constexpr int WARP_SIZE = 32;
template <typename T, typename Context>
void FillHashTable(const Context& dev_ctx,
const T* input,
int num_input,
int64_t len_hashtable,
thrust::device_vector<T>* unique_items,
T* keys,
int* values,
int* key_index) {
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (num_input + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
// Insert data into keys and values.
BuildHashTable<T><<<grid, block, 0, dev_ctx.stream()>>>(
input, num_input, len_hashtable, keys, key_index);
// Get item index count.
thrust::device_vector<int> item_count(num_input + 1, 0);
GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
input,
thrust::raw_pointer_cast(item_count.data()),
num_input,
len_hashtable,
keys,
key_index);
thrust::exclusive_scan(
item_count.begin(), item_count.end(), item_count.begin());
size_t total_unique_items = item_count[num_input];
unique_items->resize(total_unique_items);
// Get unique items
FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
input,
num_input,
len_hashtable,
thrust::raw_pointer_cast(unique_items->data()),
thrust::raw_pointer_cast(item_count.data()),
keys,
values,
key_index);
}
template <typename T, typename Context>
void FillBufferHashTable(const Context& dev_ctx,
const T* input,
int num_input,
thrust::device_vector<T>* unique_items,
int* values,
int* key_index) {
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (num_input + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
// Insert data.
BuildHashTable<T><<<grid, block, 0, dev_ctx.stream()>>>(
input, num_input, key_index);
// Get item index count.
thrust::device_vector<int> item_count(num_input + 1, 0);
GetItemIndexCount<T><<<grid, block, 0, dev_ctx.stream()>>>(
input, thrust::raw_pointer_cast(item_count.data()), num_input, key_index);
thrust::exclusive_scan(
item_count.begin(), item_count.end(), item_count.begin());
size_t total_unique_items = item_count[num_input];
unique_items->resize(total_unique_items);
// Get unique items
FillUniqueItems<T><<<grid, block, 0, dev_ctx.stream()>>>(
input,
num_input,
thrust::raw_pointer_cast(unique_items->data()),
thrust::raw_pointer_cast(item_count.data()),
values,
key_index);
}
template <typename T, typename Context>
void ResetBufferHashTable(const Context& dev_ctx,
const T* input,
int num_input,
thrust::device_vector<T>* unique_items,
int* values,
int* key_index) {
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (unique_items->size() + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
ResetHashTable<T><<<grid, block, 0, dev_ctx.stream()>>>(
thrust::raw_pointer_cast(unique_items->data()),
unique_items->size(),
key_index,
values);
}
template <typename T, typename Context>
void Reindex(const Context& dev_ctx,
const T* inputs,
thrust::device_ptr<T> src_outputs,
thrust::device_vector<T>* out_nodes,
int num_inputs,
int num_edges) {
out_nodes->resize(num_inputs + num_edges);
thrust::copy(inputs, inputs + num_inputs, out_nodes->begin());
thrust::copy(
src_outputs, src_outputs + num_edges, out_nodes->begin() + num_inputs);
thrust::device_vector<T> unique_nodes;
unique_nodes.clear();
// Fill hash table
int64_t num = out_nodes->size();
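// Size the hash table to the next power of two above `num`, so the
// open-addressing table always keeps spare slots for probing.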
int64_t log_num = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
int64_t table_size = log_num << 1;
T* keys;
int *values, *key_index;
#ifdef PADDLE_WITH_HIP
hipMalloc(&keys, table_size * sizeof(T));
hipMalloc(&values, table_size * sizeof(int));
hipMalloc(&key_index, table_size * sizeof(int));
hipMemset(keys, -1, table_size * sizeof(T));
hipMemset(values, -1, table_size * sizeof(int));
hipMemset(key_index, -1, table_size * sizeof(int));
#else
cudaMalloc(&keys, table_size * sizeof(T));
cudaMalloc(&values, table_size * sizeof(int));
cudaMalloc(&key_index, table_size * sizeof(int));
cudaMemset(keys, -1, table_size * sizeof(T));
cudaMemset(values, -1, table_size * sizeof(int));
cudaMemset(key_index, -1, table_size * sizeof(int));
#endif
FillHashTable<T, Context>(dev_ctx,
thrust::raw_pointer_cast(out_nodes->data()),
out_nodes->size(),
table_size,
&unique_nodes,
keys,
values,
key_index);
out_nodes->resize(unique_nodes.size());
thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes->begin());
// Fill outputs with reindex result.
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (num_edges + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
ReindexSrcOutput<T><<<grid, block, 0, dev_ctx.stream()>>>(
thrust::raw_pointer_cast(src_outputs),
num_edges,
table_size,
keys,
values);
#ifdef PADDLE_WITH_HIP
hipFree(keys);
hipFree(values);
hipFree(key_index);
#else
cudaFree(keys);
cudaFree(values);
cudaFree(key_index);
#endif
}
template <typename T, typename Context>
void BufferReindex(const Context& dev_ctx,
const T* inputs,
thrust::device_ptr<T> src_outputs,
thrust::device_vector<T>* out_nodes,
int num_inputs,
int* hashtable_value,
int* hashtable_index,
int num_edges) {
out_nodes->resize(num_inputs + num_edges);
thrust::copy(inputs, inputs + num_inputs, out_nodes->begin());
thrust::copy(
src_outputs, src_outputs + num_edges, out_nodes->begin() + num_inputs);
thrust::device_vector<T> unique_nodes;
unique_nodes.clear();
// Fill hash table
FillBufferHashTable<T, Context>(dev_ctx,
thrust::raw_pointer_cast(out_nodes->data()),
out_nodes->size(),
&unique_nodes,
hashtable_value,
hashtable_index);
out_nodes->resize(unique_nodes.size());
thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes->begin());
// Fill outputs with reindex result.
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (num_edges + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
ReindexSrcOutput<T><<<grid, block, 0, dev_ctx.stream()>>>(
thrust::raw_pointer_cast(src_outputs), num_edges, hashtable_value);
ResetBufferHashTable<T, Context>(dev_ctx,
thrust::raw_pointer_cast(out_nodes->data()),
out_nodes->size(),
&unique_nodes,
hashtable_value,
hashtable_index);
}
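// Each warp handles one destination node: it writes that node's reindexed id
// into all the output slots reserved for its sampled edges, using `dst_ptr`
// (the exclusive prefix sum of `dst_counts`) as the per-node output offset.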
template <typename T, int BLOCK_WARPS, int TILE_SIZE>
__global__ void GetDstEdgeCUDAKernel(const int64_t num_rows,
const int* in_rows,
const int* dst_counts,
const int* dst_ptr,
T* dst_outputs) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
while (out_row < last_row) {
const int row = in_rows[out_row];
const int dst_sample_size = dst_counts[out_row];
const int out_row_start = dst_ptr[out_row];
for (int idx = threadIdx.x; idx < dst_sample_size; idx += WARP_SIZE) {
dst_outputs[out_row_start + idx] = row;
}
out_row += BLOCK_WARPS;
}
}
template <typename T, typename Context>
void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& neighbors,
const DenseTensor& count,
paddle::optional<const DenseTensor&> hashtable_value,
paddle::optional<const DenseTensor&> hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes) {
const T* x_data = x.data<T>();
const T* neighbors_data = neighbors.data<T>();
const int* count_data = count.data<int>();
const int bs = x.dims()[0];
const int num_edges = neighbors.dims()[0];
reindex_src->Resize({num_edges});
T* reindex_src_data = dev_ctx.template Alloc<T>(reindex_src);
thrust::device_ptr<T> src_outputs(reindex_src_data);
thrust::device_vector<T> unique_nodes;
thrust::copy(neighbors_data, neighbors_data + num_edges, src_outputs);
if (flag_buffer_hashtable) {
// Here we directly use buffer tensor to act as a hash table.
DenseTensor hashtable_value_out(hashtable_value->type());
const auto* ph_value = hashtable_value.get_ptr();
hashtable_value_out.ShareDataWith(*ph_value);
DenseTensor hashtable_index_out(hashtable_index->type());
const auto* ph_index = hashtable_index.get_ptr();
hashtable_index_out.ShareDataWith(*ph_index);
int* hashtable_value_data =
hashtable_value_out.mutable_data<int>(dev_ctx.GetPlace());
int* hashtable_index_data =
hashtable_index_out.mutable_data<int>(dev_ctx.GetPlace());
BufferReindex<T, Context>(dev_ctx,
x_data,
src_outputs,
&unique_nodes,
bs,
hashtable_value_data,
hashtable_index_data,
num_edges);
} else {
Reindex<T, Context>(
dev_ctx, x_data, src_outputs, &unique_nodes, bs, num_edges);
}
// Get reindex dst edge.
thrust::device_vector<int> unique_dst_reindex(bs);
thrust::sequence(unique_dst_reindex.begin(), unique_dst_reindex.end());
thrust::device_vector<int> dst_ptr(bs);
thrust::exclusive_scan(count_data, count_data + bs, dst_ptr.begin());
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
reindex_dst->Resize({num_edges});
T* reindex_dst_data = dev_ctx.template Alloc<T>(reindex_dst);
GetDstEdgeCUDAKernel<T,
BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
bs,
thrust::raw_pointer_cast(unique_dst_reindex.data()),
count_data,
thrust::raw_pointer_cast(dst_ptr.data()),
reindex_dst_data);
out_nodes->Resize({static_cast<int>(unique_nodes.size())});
T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
thrust::copy(unique_nodes.begin(), unique_nodes.end(), out_nodes_data);
}
} // namespace phi
PD_REGISTER_KERNEL(
graph_reindex, GPU, ALL_LAYOUT, phi::GraphReindexKernel, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#else
#include <cuda_runtime.h>
#include <curand_kernel.h>
#endif
#include "paddle/phi/kernels/graph_sample_neighbors_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
struct DegreeFunctor {
const T* col_ptr;
HOSTDEVICE explicit inline DegreeFunctor(const T* x) { this->col_ptr = x; }
HOSTDEVICE inline int operator()(T i) const {
return col_ptr[i + 1] - col_ptr[i];
}
};
struct MaxFunctor {
int cap;
HOSTDEVICE explicit inline MaxFunctor(int cap) { this->cap = cap; }
HOSTDEVICE inline int operator()(int x) const {
if (x > cap) {
return cap;
}
return x;
}
};
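// Warp-level neighbor sampling: each warp owns one center node. If the node
// degree is at most k, all neighbors are copied directly; otherwise the warp
// fills a k-slot index buffer and randomly replaces entries (reservoir style,
// via atomicMax across lanes) before gathering the chosen neighbors from
// `row`.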
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void SampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_nodes,
const T* nodes,
const T* row,
const T* col_ptr,
T* output,
int* output_ptr,
int* output_idxs) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_nodes);
#ifdef PADDLE_WITH_HIP
hiprandState rng;
hiprand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x,
0,
&rng);
#else
curandState rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x,
0,
&rng);
#endif
while (out_row < last_row) {
T node = nodes[out_row];
T in_row_start = col_ptr[node];
int deg = col_ptr[node + 1] - in_row_start;
int out_row_start = output_ptr[out_row];
if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
output[out_row_start + idx] = row[in_row_start + idx];
}
} else {
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
output_idxs[out_row_start + idx] = idx;
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) {
#ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1);
#else
const int num = curand(&rng) % (idx + 1);
#endif
if (num < k) {
atomicMax(reinterpret_cast<unsigned int*>( // NOLINT
output_idxs + out_row_start + num),
static_cast<unsigned int>(idx)); // NOLINT
}
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
T perm_idx = output_idxs[out_row_start + idx] + in_row_start;
output[out_row_start + idx] = row[perm_idx];
}
}
out_row += BLOCK_WARPS;
}
}
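// Compute the per-node sample count (node degree capped at `sample_size`) and
// return the total, which determines the size of the output tensor.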
template <typename T, typename Context>
int GetTotalSampleNum(const thrust::device_ptr<const T> input,
const T* col_ptr,
thrust::device_ptr<int> output_count,
int sample_size,
int bs) {
thrust::transform(input, input + bs, output_count, DegreeFunctor<T>(col_ptr));
if (sample_size >= 0) {
thrust::transform(
output_count, output_count + bs, output_count, MaxFunctor(sample_size));
}
int total_sample_num = thrust::reduce(output_count, output_count + bs);
return total_sample_num;
}
template <typename T, typename Context>
void SampleNeighbors(const Context& dev_ctx,
const T* row,
const T* col_ptr,
const thrust::device_ptr<const T> input,
thrust::device_ptr<T> output,
thrust::device_ptr<int> output_count,
int sample_size,
int bs,
int total_sample_num) {
thrust::device_vector<int> output_ptr;
thrust::device_vector<int> output_idxs;
output_ptr.resize(bs);
output_idxs.resize(total_sample_num);
thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0);
constexpr int WARP_SIZE = 32;
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
SampleKernel<T,
WARP_SIZE,
BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
0,
sample_size,
bs,
thrust::raw_pointer_cast(input),
row,
col_ptr,
thrust::raw_pointer_cast(output),
thrust::raw_pointer_cast(output_ptr.data()),
thrust::raw_pointer_cast(output_idxs.data()));
}
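// Fisher-Yates style shuffle on each node's slice of the permutation buffer:
// for nodes with degree larger than k, randomly swap entries so that a
// contiguous range of k shuffled positions can later be gathered by
// GatherEdge.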
template <typename T>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_rows,
const T* in_rows,
T* src,
const T* dst_count) {
#ifdef PADDLE_WITH_HIP
hiprandState rng;
hiprand_init(
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#else
curandState rng;
curand_init(
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif
CUDA_KERNEL_LOOP(out_row, num_rows) {
const T row = in_rows[out_row];
const T in_row_start = dst_count[row];
const int deg = dst_count[row + 1] - in_row_start;
int split;
T tmp;
if (k < deg) {
if (deg < 2 * k) {
split = k;
} else {
split = deg - k;
}
for (int idx = deg - 1; idx >= split; idx--) {
#ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1);
#else
const int num = curand(&rng) % (idx + 1);
#endif
src[in_row_start + idx] = static_cast<T>(
atomicExch(reinterpret_cast<unsigned long long int*>( // NOLINT
src + in_row_start + num),
static_cast<unsigned long long int>( // NOLINT
src[in_row_start + idx])));
}
}
}
}
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void GatherEdge(int k,
int64_t num_rows,
const T* in_rows,
const T* src,
const T* dst_count,
T* outputs,
int* output_ptr,
T* perm_data) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
while (out_row < last_row) {
const T row = in_rows[out_row];
const T in_row_start = dst_count[row];
const int deg = dst_count[row + 1] - in_row_start;
const T out_row_start = output_ptr[out_row];
if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
const T in_idx = in_row_start + idx;
outputs[out_row_start + idx] = src[in_idx];
}
} else {
int split = k;
int begin, end;
if (deg < 2 * k) {
begin = 0;
end = k;
} else {
begin = deg - k;
end = deg;
}
for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) {
outputs[out_row_start + idx - begin] =
src[perm_data[in_row_start + idx]];
}
}
out_row += BLOCK_WARPS;
}
}
template <typename T, typename Context>
void FisherYatesSampleNeighbors(const Context& dev_ctx,
const T* row,
const T* col_ptr,
T* perm_data,
const thrust::device_ptr<const T> input,
thrust::device_ptr<T> output,
thrust::device_ptr<int> output_count,
int sample_size,
int bs,
int total_sample_num) {
thrust::device_vector<int> output_ptr;
output_ptr.resize(bs);
thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0);
#ifdef PADDLE_WITH_HIP
int block = 256;
#else
int block = 1024;
#endif
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (bs + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
FisherYatesSampleKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
0, sample_size, bs, thrust::raw_pointer_cast(input), perm_data, col_ptr);
constexpr int GATHER_WARP_SIZE = 32;
constexpr int GATHER_BLOCK_WARPS = 128 / GATHER_WARP_SIZE;
constexpr int GATHER_TILE_SIZE = GATHER_BLOCK_WARPS * 16;
const dim3 gather_block(GATHER_WARP_SIZE, GATHER_BLOCK_WARPS);
const dim3 gather_grid((bs + GATHER_TILE_SIZE - 1) / GATHER_TILE_SIZE);
GatherEdge<
T,
GATHER_WARP_SIZE,
GATHER_BLOCK_WARPS,
GATHER_TILE_SIZE><<<gather_grid, gather_block, 0, dev_ctx.stream()>>>(
sample_size,
bs,
thrust::raw_pointer_cast(input),
row,
col_ptr,
thrust::raw_pointer_cast(output),
thrust::raw_pointer_cast(output_ptr.data()),
perm_data);
}
template <typename T, typename Context>
void GraphSampleNeighborsKernel(
const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& x,
paddle::optional<const DenseTensor&> eids,
paddle::optional<const DenseTensor&> perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids) {
auto* row_data = row.data<T>();
auto* col_ptr_data = col_ptr.data<T>();
auto* x_data = x.data<T>();
int bs = x.dims()[0];
const thrust::device_ptr<const T> input(x_data);
out_count->Resize({bs});
int* out_count_data = dev_ctx.template Alloc<int>(out_count);
thrust::device_ptr<int> output_count(out_count_data);
int total_sample_size = GetTotalSampleNum<T, Context>(
input, col_ptr_data, output_count, sample_size, bs);
out->Resize({static_cast<int>(total_sample_size)});
T* out_data = dev_ctx.template Alloc<T>(out);
thrust::device_ptr<T> output(out_data);
if (!flag_perm_buffer) {
SampleNeighbors<T, Context>(dev_ctx,
row_data,
col_ptr_data,
input,
output,
output_count,
sample_size,
bs,
total_sample_size);
} else {
DenseTensor perm_buffer_out(perm_buffer->type());
const auto* p_perm_buffer = perm_buffer.get_ptr();
perm_buffer_out.ShareDataWith(*p_perm_buffer);
T* perm_buffer_out_data =
perm_buffer_out.mutable_data<T>(dev_ctx.GetPlace());
FisherYatesSampleNeighbors<T, Context>(dev_ctx,
row_data,
col_ptr_data,
perm_buffer_out_data,
input,
output,
output_count,
sample_size,
bs,
total_sample_size);
}
}
} // namespace phi
PD_REGISTER_KERNEL(graph_sample_neighbors,
GPU,
ALL_LAYOUT,
phi::GraphSampleNeighborsKernel,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphReindexKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& neighbors,
const DenseTensor& count,
paddle::optional<const DenseTensor&> hashtable_value,
paddle::optional<const DenseTensor&> hashtable_index,
bool flag_buffer_hashtable,
DenseTensor* reindex_src,
DenseTensor* reindex_dst,
DenseTensor* out_nodes);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GraphSampleNeighborsKernel(
const Context& dev_ctx,
const DenseTensor& row,
const DenseTensor& col_ptr,
const DenseTensor& x,
paddle::optional<const DenseTensor&> eids,
paddle::optional<const DenseTensor&> perm_buffer,
int sample_size,
bool return_eids,
bool flag_perm_buffer,
DenseTensor* out,
DenseTensor* out_count,
DenseTensor* out_eids);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GraphReindexOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_reindex",
{"X", "Neighbors", "Count", "HashTable_Value", "HashTable_Index"},
{"flag_buffer_hashtable"},
{"Reindex_Src", "Reindex_Dst", "Out_Nodes"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_reindex, phi::GraphReindexOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GraphSampleNeighborsOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("graph_sample_neighbors",
{"Row", "Col_Ptr", "X", "Eids", "Perm_Buffer"},
{"sample_size", "return_eids", "flag_perm_buffer"},
{"Out", "Out_Count", "Out_Eids"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(graph_sample_neighbors,
phi::GraphSampleNeighborsOpArgumentMapping);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestGraphReindex(unittest.TestCase):
def setUp(self):
self.x = np.arange(5).astype("int64")
self.neighbors = np.random.randint(100, size=20).astype("int64")
self.count = np.array([2, 8, 4, 3, 3], dtype="int32")
# Get numpy result.
out_nodes = list(self.x)
for neighbor in self.neighbors:
if neighbor not in out_nodes:
out_nodes.append(neighbor)
self.out_nodes = np.array(out_nodes, dtype="int64")
reindex_dict = {node: ind for ind, node in enumerate(self.out_nodes)}
self.reindex_src = np.array(
[reindex_dict[node] for node in self.neighbors])
reindex_dst = []
for node, c in zip(self.x, self.count):
for i in range(c):
reindex_dst.append(reindex_dict[node])
self.reindex_dst = np.array(reindex_dst, dtype="int64")
self.num_nodes = np.max(np.concatenate([self.x, self.neighbors])) + 1
def test_reindex_result(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
neighbors = paddle.to_tensor(self.neighbors)
count = paddle.to_tensor(self.count)
value_buffer = paddle.full([self.num_nodes], -1, dtype="int32")
index_buffer = paddle.full([self.num_nodes], -1, dtype="int32")
reindex_src, reindex_dst, out_nodes = \
paddle.incubate.graph_reindex(x, neighbors, count)
self.assertTrue(np.allclose(self.reindex_src, reindex_src))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst))
self.assertTrue(np.allclose(self.out_nodes, out_nodes))
reindex_src, reindex_dst, out_nodes = \
paddle.incubate.graph_reindex(x, neighbors, count,
value_buffer, index_buffer,
flag_buffer_hashtable=True)
self.assertTrue(np.allclose(self.reindex_src, reindex_src))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst))
self.assertTrue(np.allclose(self.out_nodes, out_nodes))
def test_reindex_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.x.dtype)
neighbors = paddle.static.data(
name="neighbors",
shape=self.neighbors.shape,
dtype=self.neighbors.dtype)
count = paddle.static.data(
name="count", shape=self.count.shape, dtype=self.count.dtype)
value_buffer = paddle.static.data(
name="value_buffer", shape=[self.num_nodes], dtype="int32")
index_buffer = paddle.static.data(
name="index_buffer", shape=[self.num_nodes], dtype="int32")
reindex_src_1, reindex_dst_1, out_nodes_1 = \
paddle.incubate.graph_reindex(x, neighbors, count)
reindex_src_2, reindex_dst_2, out_nodes_2 = \
paddle.incubate.graph_reindex(x, neighbors, count,
value_buffer, index_buffer,
flag_buffer_hashtable=True)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
'x': self.x,
'neighbors': self.neighbors,
'count': self.count,
'value_buffer': np.full(
[self.num_nodes], -1, dtype="int32"),
'index_buffer': np.full(
[self.num_nodes], -1, dtype="int32")
},
fetch_list=[
reindex_src_1, reindex_dst_1, out_nodes_1,
reindex_src_2, reindex_dst_2, out_nodes_2
])
reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \
reindex_dst_2, out_nodes_2 = ret
self.assertTrue(np.allclose(self.reindex_src, reindex_src_1))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst_1))
self.assertTrue(np.allclose(self.out_nodes, out_nodes_1))
self.assertTrue(np.allclose(self.reindex_src, reindex_src_2))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst_2))
self.assertTrue(np.allclose(self.out_nodes, out_nodes_2))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestGraphSampleNeighbors(unittest.TestCase):
def setUp(self):
num_nodes = 20
edges = np.random.randint(num_nodes, size=(100, 2))
edges = np.unique(edges, axis=0)
self.edges_id = np.arange(0, len(edges)).astype("int64")
sorted_edges = edges[np.argsort(edges[:, 1])]
# Calculate the cumulative sum of dst counts, which is also the colptr.
dst_count = np.zeros(num_nodes)
dst_src_dict = {}
for dst in range(0, num_nodes):
true_index = sorted_edges[:, 1] == dst
dst_count[dst] = np.sum(true_index)
dst_src_dict[dst] = sorted_edges[:, 0][true_index]
dst_count = dst_count.astype("int64")
colptr = np.cumsum(dst_count)
colptr = np.insert(colptr, 0, 0)
self.row = sorted_edges[:, 0].astype("int64")
self.colptr = colptr.astype("int64")
self.nodes = np.unique(np.random.randint(
num_nodes, size=5)).astype("int64")
self.sample_size = 5
self.dst_src_dict = dst_src_dict
def test_sample_result(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
out_neighbors, out_count = paddle.incubate.graph_sample_neighbors(
row, colptr, nodes, sample_size=self.sample_size)
out_count_cumsum = paddle.cumsum(out_count)
for i in range(len(out_count)):
if i == 0:
neighbors = out_neighbors[0:out_count_cumsum[i]]
else:
neighbors = out_neighbors[out_count_cumsum[i - 1]:
out_count_cumsum[i]]
# Ensure the correct sample size.
self.assertTrue(
out_count[i] == self.sample_size or
out_count[i] == len(self.dst_src_dict[self.nodes[i]]))
# Ensure no repetitive sample neighbors.
self.assertTrue(
neighbors.shape[0] == paddle.unique(neighbors).shape[0])
# Ensure the correct sample neighbors.
in_neighbors = np.isin(neighbors.numpy(),
self.dst_src_dict[self.nodes[i]])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_fisher_yates_sampling(self):
paddle.disable_static()
if fluid.core.is_compiled_with_cuda():
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
perm_buffer = paddle.to_tensor(self.edges_id)
out_neighbors, out_count = paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
perm_buffer=perm_buffer,
sample_size=self.sample_size,
flag_perm_buffer=True)
out_count_cumsum = paddle.cumsum(out_count)
for i in range(len(out_count)):
if i == 0:
neighbors = out_neighbors[0:out_count_cumsum[i]]
else:
neighbors = out_neighbors[out_count_cumsum[i - 1]:
out_count_cumsum[i]]
# Ensure the correct sample size.
self.assertTrue(
out_count[i] == self.sample_size or
out_count[i] == len(self.dst_src_dict[self.nodes[i]]))
# Ensure no repetitive sample neighbors.
self.assertTrue(
neighbors.shape[0] == paddle.unique(neighbors).shape[0])
# Ensure the correct sample neighbors.
in_neighbors = np.isin(neighbors.numpy(),
self.dst_src_dict[self.nodes[i]])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype)
out_neighbors, out_count = paddle.incubate.graph_sample_neighbors(
row, colptr, nodes, sample_size=self.sample_size)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
'row': self.row,
'colptr': self.colptr,
'nodes': self.nodes
},
fetch_list=[out_neighbors, out_count])
out_neighbors, out_count = ret
out_count_cumsum = np.cumsum(out_count)
out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1]
for neighbors, node, count in zip(out_neighbors, self.nodes,
out_count):
self.assertTrue(count == self.sample_size or
count == len(self.dst_src_dict[node]))
self.assertTrue(
neighbors.shape[0] == np.unique(neighbors).shape[0])
in_neighbors = np.isin(neighbors, self.dst_src_dict[node])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_raise_errors(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
def check_eid_error():
paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
sample_size=self.sample_size,
return_eids=True)
def check_perm_buffer_error():
paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
sample_size=self.sample_size,
flag_perm_buffer=True)
self.assertRaises(ValueError, check_eid_error)
self.assertRaises(ValueError, check_perm_buffer_error)
def test_sample_result_with_eids(self):
# Note: Currently the returned eids result is not initialized.
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
eids = paddle.to_tensor(self.edges_id)
out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
eids=eids,
sample_size=self.sample_size,
return_eids=True)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype)
eids = paddle.static.data(
name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype)
out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
eids,
sample_size=self.sample_size,
return_eids=True)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
'row': self.row,
'colptr': self.colptr,
'nodes': self.nodes,
'eids': self.edges_id
},
fetch_list=[out_neighbors, out_count])
if __name__ == "__main__":
unittest.main()
@@ -21,6 +21,8 @@ from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
from .operators import softmax_mask_fuse # noqa: F401
from .operators import graph_send_recv
from .operators import graph_khop_sampler
from .operators import graph_sample_neighbors
from .operators import graph_reindex
from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
@@ -37,6 +39,8 @@ __all__ = [
'softmax_mask_fuse',
'graph_send_recv',
'graph_khop_sampler',
'graph_sample_neighbors',
'graph_reindex',
'segment_sum',
'segment_mean',
'segment_max',
@@ -17,3 +17,5 @@ from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
from .resnet_unit import ResNetUnit #noqa: F401
from .graph_send_recv import graph_send_recv #noqa: F401
from .graph_khop_sampler import graph_khop_sampler #noqa: F401
from .graph_sample_neighbors import graph_sample_neighbors #noqa: F401
from .graph_reindex import graph_reindex #noqa: F401
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
from paddle import _C_ops
def graph_reindex(x,
neighbors,
count,
value_buffer=None,
index_buffer=None,
flag_buffer_hashtable=False,
name=None):
"""
Graph Reindex API.
This API is mainly used in the Graph Learning domain and should be used
in conjunction with the `graph_sample_neighbors` API. Its main purpose
is to reindex the ids of the input nodes and return the corresponding
graph edges after reindexing.
Take input nodes x = [0, 1, 2] as an example.
If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
then we know that the neighbors of 0 are [8, 9], the neighbors of 1
are [0, 4, 7], and the neighbors of 2 are [6, 7].
Args:
x (Tensor): The input nodes which we sample neighbors for. The available
data type is int32, int64.
neighbors (Tensor): The neighbors of the input nodes `x`. The data type
should be the same with `x`.
count (Tensor): The neighbor count of the input nodes `x`. And the
data type should be int32.
value_buffer (Tensor|None): Value buffer for hashtable. The data type should
be int32, and should be filled with -1.
index_buffer (Tensor|None): Index buffer for hashtable. The data type should
be int32, and should be filled with -1.
flag_buffer_hashtable (bool): Whether to use the buffer hashtable to speed up reindexing.
Default is False. Currently only effective for the GPU version.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
reindex_src (Tensor): The source node index of graph edges after reindex.
reindex_dst (Tensor): The destination node index of graph edges after reindex.
out_nodes (Tensor): The index of unique input nodes and neighbors before reindex,
where we put the input nodes `x` in the front, and put neighbor
nodes in the back.
Examples:
.. code-block:: python
import paddle
x = [0, 1, 2]
neighbors = [8, 9, 0, 4, 7, 6, 7]
count = [2, 3, 2]
x = paddle.to_tensor(x, dtype="int64")
neighbors = paddle.to_tensor(neighbors, dtype="int64")
count = paddle.to_tensor(count, dtype="int32")
reindex_src, reindex_dst, out_nodes = \
paddle.incubate.graph_reindex(x, neighbors, count)
# reindex_src: [3, 4, 0, 5, 6, 7, 6]
# reindex_dst: [0, 0, 1, 1, 1, 2, 2]
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]
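# Each id in `neighbors` is replaced by its position in `out_nodes`
# (the input nodes `x` first, then newly-seen neighbors in order of
# appearance): 0->0, 1->1, 2->2, 8->3, 9->4, 4->5, 7->6, 6->7, which
# gives reindex_src above; reindex_dst repeats each node of `x`
# according to `count`.

# A minimal sketch of the buffer-hashtable path (GPU only). The buffer
# length used here, the total number of graph nodes, is an assumption;
# both buffers are int32 tensors pre-filled with -1.
# value_buffer = paddle.full([num_nodes], -1, dtype="int32")
# index_buffer = paddle.full([num_nodes], -1, dtype="int32")
# reindex_src, reindex_dst, out_nodes = \
#     paddle.incubate.graph_reindex(x, neighbors, count,
#                                   value_buffer, index_buffer,
#                                   flag_buffer_hashtable=True)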
"""
if flag_buffer_hashtable:
if value_buffer is None or index_buffer is None:
raise ValueError(f"`value_buffer` and `index_buffer` should not"
"be None if `flag_buffer_hashtable` is True.")
if _non_static_mode():
reindex_src, reindex_dst, out_nodes = \
_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer,
"flag_buffer_hashtable", flag_buffer_hashtable)
return reindex_src, reindex_dst, out_nodes
check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"),
"graph_reindex")
check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex")
if flag_buffer_hashtable:
check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"),
"graph_reindex")
check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"),
"graph_reindex")
helper = LayerHelper("graph_reindex", **locals())
reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="graph_reindex",
inputs={
"X": x,
"Neighbors": neighbors,
"Count": count,
"HashTable_Value": value_buffer if flag_buffer_hashtable else None,
"HashTable_Index": index_buffer if flag_buffer_hashtable else None,
},
outputs={
"Reindex_Src": reindex_src,
"Reindex_Dst": reindex_dst,
"Out_Nodes": out_nodes
},
attrs={"flag_buffer_hashtable": flag_buffer_hashtable})
return reindex_src, reindex_dst, out_nodes
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
from paddle import _C_ops
def graph_sample_neighbors(row,
colptr,
input_nodes,
eids=None,
perm_buffer=None,
sample_size=-1,
return_eids=False,
flag_perm_buffer=False,
name=None):
"""
Graph Sample Neighbors API.
This API is mainly used in the Graph Learning domain, and its main purpose is to
provide a high-performance graph sampling method. For example, we take the
CSC (Compressed Sparse Column) format of the input graph edges as `row` and
`colptr`, so as to convert the graph data into a suitable format for sampling.
`input_nodes` means the nodes we need to sample neighbors for, and `sample_size`
means the number of neighbors we want to sample for each input node.
Besides, we support Fisher-Yates sampling in the GPU version.
Args:
row (Tensor): One of the components of the CSC format of the input graph, and
the shape should be [num_edges, 1] or [num_edges]. The available
data type is int32, int64.
colptr (Tensor): One of the components of the CSC format of the input graph,
and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
The data type should be the same with `row`.
input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
data type should be the same with `row`.
eids (Tensor): The eid information of the input graph. If return_eids is True,
then `eids` should not be None. The data type should be the
same with `row`. Default is None.
perm_buffer (Tensor): Permutation buffer for fisher-yates sampling. If `flag_perm_buffer`
is True, then `perm_buffer` should not be None. The data type should
be the same with `row`. Default is None.
sample_size (int): The number of neighbors we need to sample. Default value is
-1, which means returning all the neighbors of the input nodes.
return_eids (bool): Whether to return eid information of sample edges. Default is False.
flag_perm_buffer (bool): Whether to use the permutation buffer for Fisher-Yates
sampling on GPU. Default is False, which means it is not used.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out_neighbors (Tensor): The sample neighbors of the input nodes.
out_count (Tensor): The number of sampled neighbors of each input node. The shape
should be the same as `input_nodes`.
out_eids (Tensor): If `return_eids` is True, we will return the eid information of the
sample edges.
Examples:
.. code-block:: python
import paddle
# edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4),
# (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8)
row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
nodes = [0, 8, 1, 2]
sample_size = 2
row = paddle.to_tensor(row, dtype="int64")
colptr = paddle.to_tensor(colptr, dtype="int64")
nodes = paddle.to_tensor(nodes, dtype="int64")
out_neighbors, out_count = \
paddle.incubate.graph_sample_neighbors(row, colptr, nodes,
sample_size=sample_size)
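# In CSC format, the neighbors of node i are row[colptr[i]:colptr[i+1]].
# Here the input nodes [0, 8, 1, 2] have [2, 2, 2, 1] neighbors respectively,
# so with sample_size = 2 the expected out_count is [2, 2, 2, 1]: node 2 has
# only one neighbor, and all of a node's neighbors are returned when it has
# fewer than sample_size.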
"""
if return_eids:
if eids is None:
raise ValueError(
"`eids` should not be None if `return_eids` is True.")
if flag_perm_buffer:
if perm_buffer is None:
raise ValueError(
"`perm_buffer` should not be None if `flag_perm_buffer` "
"is True.")
if _non_static_mode():
out_neighbors, out_count, out_eids = _C_ops.graph_sample_neighbors(
row, colptr, input_nodes, eids, perm_buffer, "sample_size",
sample_size, "return_eids", return_eids, "flag_perm_buffer",
flag_perm_buffer)
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count
check_variable_and_dtype(row, "Row", ("int32", "int64"),
"graph_sample_neighbors")
check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"),
"graph_sample_neighbors")
check_variable_and_dtype(input_nodes, "X", ("int32", "int64"),
"graph_sample_neighbors")
if return_eids:
check_variable_and_dtype(eids, "Eids", ("int32", "int64"),
"graph_sample_neighbors")
if flag_perm_buffer:
check_variable_and_dtype(perm_buffer, "Perm_Buffer", ("int32", "int64"),
"graph_sample_neighbors")
helper = LayerHelper("graph_sample_neighbors", **locals())
out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype)
out_count = helper.create_variable_for_type_inference(dtype=row.dtype)
out_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
helper.append_op(
type="graph_sample_neighbors",
inputs={
"Row": row,
"Col_Ptr": colptr,
"X": input_nodes,
"Eids": eids if return_eids else None,
"Perm_Buffer": perm_buffer if flag_perm_buffer else None
},
outputs={
"Out": out_neighbors,
"Out_Count": out_count,
"Out_Eids": out_eids
},
attrs={
"sample_size": sample_size,
"return_eids": return_eids,
"flag_perm_buffer": flag_perm_buffer
})
if return_eids:
return out_neighbors, out_count, out_eids
return out_neighbors, out_count