Unverified commit 35f949b5, authored by Siming Dai and committed by GitHub

Add Khop Graph Sampler API (#39146)

* add the test case for the UVA

* add the context load for the uva

* Add graph_sample kernel

* Add graph_sample commit

* add new commit for graph_sample

* add unsigned long long int

* delete some remarks

* add cpu version

* add cuda eids

* add cpu eids

* delete _uva

* optimize speed: emplace_back, last_layer

* add to_uva_tensor

* add cpu return_eids choice

* add gpu return_eids choice

* add cpu reindex_nodes

* add gpu reindex_nodes

* rename op and add OMP for cpu

* add incubate api

* fix the compile problem for PADDLE_ENFORCE and different devices

* fix the rocm and windows compile problems

* add unittest for graph_sample_neighbors

* fix cpu unittest and unique problem

* fix uva unittest, fix cuda unique problem

* fix the windows compile problem

* fix the windows rand_r compile problem

* add correct unittest, add src_eids dispensable

* delete black

* combine uva unittest

* mv Sample_index to Sample_Index; check input shape; fix random sample func

* delete memset & cudaMemset

* fix according to PR comments

* fix rocm ci

* modify function names according to the specification

* fix windows_openblas ci

* refine annotations, fix windows unittest, add default value for uva device_id, fix bug for input nodes with empty neighbors

* fix rocm ci

* rename graph_sample_neighbors as graph_khop_sampler, add incubate api doc

* add data type

* fix conflict
Co-authored-by: wawltor <fangzeyang0904@hotmail.com>
Parent 552db8dc
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace operators {
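// A device-side open-addressing hash table shared by the kernels below.
// `keys` is expected to be pre-initialized with -1 (the empty sentinel);
// collisions are resolved by probing with an increasing step. For duplicate
// ids, `key_index` keeps the smallest position of the id in the original
// item array (enforced with atomicMin), which later selects one
// representative per unique id when building the reindex mapping.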
template <typename IdType>
inline __device__ size_t Hash(IdType id, int64_t size) {
return id % size;
}
template <typename IdType>
inline __device__ bool AttemptInsert(size_t pos, IdType id, int64_t index,
IdType* keys, int64_t* key_index) {
if (sizeof(IdType) == 4) {
const IdType key =
atomicCAS(reinterpret_cast<unsigned int*>(&keys[pos]),
static_cast<unsigned int>(-1), static_cast<unsigned int>(id));
if (key == -1 || key == id) {
atomicMin(
reinterpret_cast<unsigned long long int*>(&key_index[pos]), // NOLINT
static_cast<unsigned long long int>(index)); // NOLINT
return true;
} else {
return false;
}
} else if (sizeof(IdType) == 8) {
const IdType key = atomicCAS(
reinterpret_cast<unsigned long long int*>(&keys[pos]), // NOLINT
static_cast<unsigned long long int>(-1), // NOLINT
static_cast<unsigned long long int>(id)); // NOLINT
if (key == -1 || key == id) {
atomicMin(
reinterpret_cast<unsigned long long int*>(&key_index[pos]), // NOLINT
static_cast<unsigned long long int>(index)); // NOLINT
return true;
} else {
return false;
}
}
// Unreachable for the supported 4-byte and 8-byte id types; the explicit
// return silences missing-return warnings.
return false;
}
template <typename IdType>
inline __device__ void Insert(IdType id, int64_t index, int64_t size,
IdType* keys, int64_t* key_index) {
size_t pos = Hash(id, size);
size_t delta = 1;
while (!AttemptInsert(pos, id, index, keys, key_index)) {
pos = Hash(pos + delta, size);
delta += 1;
}
}
template <typename IdType>
inline __device__ int64_t Search(IdType id, const IdType* keys, int64_t size) {
int64_t pos = Hash(id, size);
int64_t delta = 1;
while (keys[pos] != id) {
pos = Hash(pos + delta, size);
delta += 1;
}
return pos;
}
template <typename IdType>
__global__ void BuildHashTable(const IdType* items, int64_t num_items,
int64_t size, IdType* keys, int64_t* key_index) {
CUDA_KERNEL_LOOP_TYPE(index, num_items, int64_t) {
Insert(items[index], index, size, keys, key_index);
}
}
template <typename IdType>
__global__ void GetItemIndexCount(const IdType* items, int* item_count,
int64_t num_items, int64_t size,
const IdType* keys, int64_t* key_index) {
CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) {
int64_t pos = Search(items[i], keys, size);
if (key_index[pos] == i) {
item_count[i] = 1;
}
}
}
template <typename IdType>
__global__ void FillUniqueItems(const IdType* items, int64_t num_items,
int64_t size, IdType* unique_items,
const int* item_count, const IdType* keys,
IdType* values, int64_t* key_index) {
CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) {
int64_t pos = Search(items[i], keys, size);
if (key_index[pos] == i) {
values[pos] = item_count[i];
unique_items[item_count[i]] = items[i];
}
}
}
template <typename IdType>
__global__ void ReindexSrcOutput(IdType* src_output, int64_t num_items,
int64_t size, const IdType* keys,
const IdType* values) {
CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) {
int64_t pos = Search(src_output[i], keys, size);
src_output[i] = values[pos];
}
}
template <typename IdType>
__global__ void ReindexInputNodes(const IdType* nodes, int64_t num_items,
IdType* reindex_nodes, int64_t size,
const IdType* keys, const IdType* values) {
CUDA_KERNEL_LOOP_TYPE(i, num_items, int64_t) {
int64_t pos = Search(nodes[i], keys, size);
reindex_nodes[i] = values[pos];
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/graph_khop_sampler_op.h"
namespace paddle {
namespace operators {
void InputShapeCheck(const framework::DDim& dims,
                     const std::string& tensor_name) {
if (dims.size() == 2) {
PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument(
"The last dim of %s should be 1 when it "
"is 2D, but we got %d.",
tensor_name, dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dims.size(), 1,
platform::errors::InvalidArgument(
"%s should be a 1D tensor when it is not 2D, but we got %d dims.",
tensor_name, dims.size()));
}
}
class GraphKhopSamplerOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Row"), "Input", "Row", "GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasInput("Col_Ptr"), "Input", "Col_Ptr",
"GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasOutput("Out_Src"), "Output", "Out_Src",
"GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasOutput("Out_Dst"), "Output", "Out_Dst",
"GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasOutput("Sample_Index"), "Output", "Sample_Index",
"GraphKhopSampler");
OP_INOUT_CHECK(ctx->HasOutput("Reindex_X"), "Output", "Reindex_X",
"GraphKhopSampler");
// Restrict all the inputs to be 1-D tensors, or 2-D tensors with the
// second dim equal to 1.
InputShapeCheck(ctx->GetInputDim("Row"), "Row");
InputShapeCheck(ctx->GetInputDim("Col_Ptr"), "Col_Ptr");
InputShapeCheck(ctx->GetInputDim("X"), "X");
const std::vector<int>& sample_sizes =
ctx->Attrs().Get<std::vector<int>>("sample_sizes");
PADDLE_ENFORCE_EQ(
!sample_sizes.empty(), true,
platform::errors::InvalidArgument(
"The parameter 'sample_sizes' in GraphSampleOp must be set. "
"But received 'sample_sizes' is empty."));
const bool& return_eids = ctx->Attrs().Get<bool>("return_eids");
if (return_eids) {
OP_INOUT_CHECK(ctx->HasInput("Eids"), "Input", "Eids",
"GraphKhopSampler");
InputShapeCheck(ctx->GetInputDim("Eids"), "Eids");
OP_INOUT_CHECK(ctx->HasOutput("Out_Eids"), "Output", "Out_Eids",
"GraphKhopSampler");
ctx->SetOutputDim("Out_Eids", {-1});
}
ctx->SetOutputDim("Out_Src", {-1, 1});
ctx->SetOutputDim("Out_Dst", {-1, 1});
ctx->SetOutputDim("Sample_Index", {-1});
auto dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Reindex_X", dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Row"),
ctx.device_context());
}
};
class GraphKhopSamplerOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Row", "The src index tensor of graph edges after sorted by dst.");
AddInput("Eids", "The eids of the input graph edges.").AsDispensable();
AddInput("Col_Ptr",
"The cumulative sum of the number of src neighbors of dst index, "
"starts from 0, end with number of edges");
AddInput("X", "The input center nodes index tensor.");
AddOutput("Out_Src",
"The output src edges tensor after sampling and reindex.");
AddOutput("Out_Dst",
"The output dst edges tensor after sampling and reindex.");
AddOutput("Sample_Index",
"The original index of the center nodes and sampling nodes");
AddOutput("Reindex_X", "The reindex node id of the input nodes.");
AddOutput("Out_Eids", "The eids of the sample edges.").AsIntermediate();
AddAttr<std::vector<int>>(
"sample_sizes", "The sample sizes of graph sample neighbors method.")
.SetDefault({});
AddAttr<bool>("return_eids",
"Whether to return the eids of the sample edges.")
.SetDefault(false);
AddComment(R"DOC(
Graph Learning Sampling Neighbors operator, for the GraphSAGE sampling method.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(graph_khop_sampler, ops::GraphKhopSamplerOP,
ops::GraphKhopSamplerOpMaker);
REGISTER_OP_CPU_KERNEL(graph_khop_sampler,
ops::GraphKhopSamplerOpKernel<CPU, int32_t>,
ops::GraphKhopSamplerOpKernel<CPU, int64_t>);
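For orientation, here is a minimal usage sketch. It assumes the `paddle.incubate.graph_khop_sampler` Python wrapper added later in this commit and reuses the sample graph from its docstring: `Out_Src`/`Out_Dst` hold reindexed node ids, and `Sample_Index` maps them back to the original ids, here gathered with `paddle.gather` instead of the element-wise loop used in the unit tests.

import paddle

# CSC graph: `row` holds the src ids of edges sorted by dst, and
# row[colptr[i]:colptr[i + 1]] are the neighbors of node i.
row = paddle.to_tensor([3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7], dtype="int64")
colptr = paddle.to_tensor([0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13], dtype="int64")
nodes = paddle.to_tensor([0, 8, 1, 2], dtype="int64")

edge_src, edge_dst, sample_index, reindex_nodes = \
    paddle.incubate.graph_khop_sampler(row, colptr, nodes, sample_sizes=[2, 2])

# Map the reindexed edge endpoints back to the original node ids.
orig_src = paddle.gather(sample_index, edge_src.reshape([-1]))
orig_dst = paddle.gather(sample_index, edge_dst.reshape([-1]))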
This diff is collapsed.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdlib.h>
#include <numeric>
#include <random>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
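// Partially Fisher-Yates shuffle the range so that its first `num_samples`
// elements form a uniform random sample without repetition. Callers must
// ensure num_samples does not exceed the range size (the kernel below only
// calls this when k < cap).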
template <class bidiiter>
void SampleUniqueNeighbors(bidiiter begin, bidiiter end, int num_samples) {
int left_num = std::distance(begin, end);
std::random_device rd;
std::mt19937 rng{rd()};
std::uniform_int_distribution<int> dice_distribution(
0, std::numeric_limits<int>::max());
for (int i = 0; i < num_samples; i++) {
bidiiter r = begin;
int random_step = dice_distribution(rng) % left_num;
std::advance(r, random_step);
std::swap(*begin, *r);
++begin;
--left_num;
}
}
template <class bidiiter>
void SampleUniqueNeighborsWithEids(bidiiter src_begin, bidiiter src_end,
bidiiter eid_begin, bidiiter eid_end,
int num_samples) {
int left_num = std::distance(src_begin, src_end);
std::random_device rd;
std::mt19937 rng{rd()};
std::uniform_int_distribution<int> dice_distribution(
0, std::numeric_limits<int>::max());
for (int i = 0; i < num_samples; i++) {
bidiiter r1 = src_begin, r2 = eid_begin;
int random_step = dice_distribution(rng) % left_num;
std::advance(r1, random_step);
std::advance(r2, random_step);
std::swap(*src_begin, *r1);
std::swap(*eid_begin, *r2);
++src_begin;
++eid_begin;
--left_num;
}
}
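// Sample up to `k` neighbors for every node in `inputs`:
//   (1) scan neighbor counts to compute per-node output offsets
//       (`sample_cumsum_sizes`) and pre-allocate per-node buffers;
//   (2) sample in parallel (copying all neighbors when a node has no more
//       than k of them);
//   (3) gather the per-node buffers into the flat outputs;
//   (4) unless this is the last layer, replace `inputs` with the newly
//       discovered unique neighbors to form the next layer's frontier.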
template <typename T>
void SampleNeighbors(const T* src, const T* dst_count, const T* src_eids,
std::vector<T>* inputs, std::vector<T>* outputs,
std::vector<T>* output_counts,
std::vector<T>* outputs_eids, int k, bool is_first_layer,
bool is_last_layer, bool return_eids) {
const size_t bs = inputs->size();
// Allocate the memory of outputs.
// Collect each node's neighbor count.
std::vector<std::vector<T>> out_src_vec;
std::vector<std::vector<T>> out_eids_vec;
// `sample_cumsum_sizes` records the start and end positions of each node's
// samples in the flat output.
std::vector<size_t> sample_cumsum_sizes(bs + 1);
size_t total_neighbors = 0;
// `total_neighbors` is the total output size after sampling.
sample_cumsum_sizes[0] = total_neighbors;
for (size_t i = 0; i < bs; i++) {
T node = inputs->data()[i];
T begin = dst_count[node];
T end = dst_count[node + 1];
int cap = end - begin;
int sample_size = cap > k ? k : cap;
total_neighbors += sample_size;
sample_cumsum_sizes[i + 1] = total_neighbors;
std::vector<T> out_src;
out_src.resize(cap);
out_src_vec.emplace_back(out_src);
if (return_eids) {
std::vector<T> out_eids;
out_eids.resize(cap);
out_eids_vec.emplace_back(out_eids);
}
}
if (is_first_layer) {
PADDLE_ENFORCE_GT(total_neighbors, 0,
platform::errors::InvalidArgument(
"The input nodes `X` should have at "
"least one neighbor, but none of the "
"input nodes have neighbors."));
}
output_counts->resize(bs);
outputs->resize(total_neighbors);
if (return_eids) {
outputs_eids->resize(total_neighbors);
}
// Sample the neighbors in parallel.
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (size_t i = 0; i < bs; i++) {
T node = inputs->data()[i];
T begin = dst_count[node];
T end = dst_count[node + 1];
int cap = end - begin;
if (k < cap) {
std::copy(src + begin, src + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(src_eids + begin, src_eids + end, out_eids_vec[i].begin());
SampleUniqueNeighborsWithEids(
out_src_vec[i].begin(), out_src_vec[i].end(),
out_eids_vec[i].begin(), out_eids_vec[i].end(), k);
} else {
SampleUniqueNeighbors(out_src_vec[i].begin(), out_src_vec[i].end(), k);
}
*(output_counts->data() + i) = k;
} else {
std::copy(src + begin, src + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(src_eids + begin, src_eids + end, out_eids_vec[i].begin());
}
*(output_counts->data() + i) = cap;
}
}
// Copy the results in parallel.
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (size_t i = 0; i < bs; i++) {
int sample_size = sample_cumsum_sizes[i + 1] - sample_cumsum_sizes[i];
std::copy(out_src_vec[i].begin(), out_src_vec[i].begin() + sample_size,
outputs->data() + sample_cumsum_sizes[i]);
if (return_eids) {
std::copy(out_eids_vec[i].begin(), out_eids_vec[i].begin() + sample_size,
outputs_eids->data() + sample_cumsum_sizes[i]);
}
}
if (!is_last_layer) {
std::sort(inputs->begin(), inputs->end());
std::vector<T> outputs_sort(outputs->size());
std::copy(outputs->begin(), outputs->end(), outputs_sort.begin());
std::sort(outputs_sort.begin(), outputs_sort.end());
auto outputs_sort_end =
std::unique(outputs_sort.begin(), outputs_sort.end());
outputs_sort.resize(std::distance(outputs_sort.begin(), outputs_sort_end));
std::vector<T> unique_outputs(outputs_sort.size());
auto unique_outputs_end = std::set_difference(
outputs_sort.begin(), outputs_sort.end(), inputs->begin(),
inputs->end(), unique_outputs.begin());
inputs->resize(std::distance(unique_outputs.begin(), unique_outputs_end));
std::copy(unique_outputs.begin(), unique_outputs_end, inputs->begin());
}
}
template <typename DeviceContext, typename T>
class GraphKhopSamplerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// 1. Get sample neighbors operators' inputs.
auto* src = ctx.Input<Tensor>("Row");
auto* dst_count = ctx.Input<Tensor>("Col_Ptr");
auto* vertices = ctx.Input<Tensor>("X");
std::vector<int> sample_sizes = ctx.Attr<std::vector<int>>("sample_sizes");
bool return_eids = ctx.Attr<bool>("return_eids");
const T* src_data = src->data<T>();
const T* dst_count_data = dst_count->data<T>();
const T* p_vertices = vertices->data<T>();
const size_t bs = vertices->dims()[0];
// 2. Get unique input nodes (X). Note that std::unique only removes
// consecutive duplicates, so duplicate ids in X are assumed to be adjacent.
std::vector<T> inputs(bs);
std::copy(p_vertices, p_vertices + bs, inputs.begin());
auto unique_inputs_end = std::unique(inputs.begin(), inputs.end());
inputs.resize(std::distance(inputs.begin(), unique_inputs_end));
// 3. Sample neighbors. The paths with and without "Eids" are handled
// separately.
std::vector<T> outputs;
std::vector<T> output_counts;
std::vector<T> outputs_eids;
std::vector<std::vector<T>> dst_vec;
dst_vec.emplace_back(inputs);
std::vector<std::vector<T>> outputs_vec;
std::vector<std::vector<T>> output_counts_vec;
std::vector<std::vector<T>> outputs_eids_vec;
const size_t num_layers = sample_sizes.size();
bool is_last_layer = false, is_first_layer = true;
if (return_eids) {
auto* src_eids = ctx.Input<Tensor>("Eids");
const T* src_eids_data = src_eids->data<T>();
for (size_t i = 0; i < num_layers; i++) {
if (i == num_layers - 1) {
is_last_layer = true;
}
if (inputs.size() == 0) {
break;
}
if (i > 0) {
dst_vec.emplace_back(inputs);
is_first_layer = false;
}
SampleNeighbors<T>(src_data, dst_count_data, src_eids_data, &inputs,
&outputs, &output_counts, &outputs_eids,
sample_sizes[i], is_first_layer, is_last_layer,
return_eids);
outputs_vec.emplace_back(outputs);
output_counts_vec.emplace_back(output_counts);
outputs_eids_vec.emplace_back(outputs_eids);
}
} else {
for (size_t i = 0; i < num_layers; i++) {
if (i == num_layers - 1) {
is_last_layer = true;
}
if (inputs.size() == 0) {
break;
}
if (i > 0) {
is_first_layer = false;
dst_vec.emplace_back(inputs);
}
SampleNeighbors<T>(src_data, dst_count_data, nullptr, &inputs, &outputs,
&output_counts, &outputs_eids, sample_sizes[i],
is_first_layer, is_last_layer, return_eids);
outputs_vec.emplace_back(outputs);
output_counts_vec.emplace_back(output_counts);
outputs_eids_vec.emplace_back(outputs_eids);
}
}
// 4. Concat intermediate sample results.
int64_t unique_dst_size = 0, src_size = 0;
for (size_t i = 0; i < num_layers; i++) {
unique_dst_size += dst_vec[i].size();
src_size += outputs_vec[i].size();
}
std::vector<T> unique_dst_merge(unique_dst_size);
std::vector<T> src_merge(src_size);
std::vector<T> dst_sample_counts_merge(unique_dst_size);
auto unique_dst_merge_ptr = unique_dst_merge.begin();
auto src_merge_ptr = src_merge.begin();
auto dst_sample_counts_merge_ptr = dst_sample_counts_merge.begin();
// TODO(daisiming): We may try to use std::move in the future.
for (size_t i = 0; i < num_layers; i++) {
if (i == 0) {
unique_dst_merge_ptr = std::copy(dst_vec[i].begin(), dst_vec[i].end(),
unique_dst_merge.begin());
src_merge_ptr = std::copy(outputs_vec[i].begin(), outputs_vec[i].end(),
src_merge.begin());
dst_sample_counts_merge_ptr =
std::copy(output_counts_vec[i].begin(), output_counts_vec[i].end(),
dst_sample_counts_merge.begin());
} else {
unique_dst_merge_ptr = std::copy(dst_vec[i].begin(), dst_vec[i].end(),
unique_dst_merge_ptr);
src_merge_ptr = std::copy(outputs_vec[i].begin(), outputs_vec[i].end(),
src_merge_ptr);
dst_sample_counts_merge_ptr =
std::copy(output_counts_vec[i].begin(), output_counts_vec[i].end(),
dst_sample_counts_merge_ptr);
}
}
// 5. Return eids results.
if (return_eids) {
std::vector<T> eids_merge(src_size);
auto eids_merge_ptr = eids_merge.begin();
for (size_t i = 0; i < num_layers; i++) {
if (i == 0) {
eids_merge_ptr =
std::copy(outputs_eids_vec[i].begin(), outputs_eids_vec[i].end(),
eids_merge.begin());
} else {
eids_merge_ptr = std::copy(outputs_eids_vec[i].begin(),
outputs_eids_vec[i].end(), eids_merge_ptr);
}
}
auto* out_eids = ctx.Output<Tensor>("Out_Eids");
out_eids->Resize({static_cast<int>(eids_merge.size())});
T* p_out_eids = out_eids->mutable_data<T>(ctx.GetPlace());
std::copy(eids_merge.begin(), eids_merge.end(), p_out_eids);
}
int64_t num_sample_edges =
std::accumulate(dst_sample_counts_merge.begin(),
dst_sample_counts_merge.end(), static_cast<int64_t>(0));
PADDLE_ENFORCE_EQ(
src_merge.size(), num_sample_edges,
platform::errors::PreconditionNotMet(
"The number of sampled edges mismatches; the sample kernel has an "
"error."));
// 6. Reindex edges.
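// Frontier (dst) nodes are assigned reindex ids first, so the first
// `unique_dst_merge.size()` ids refer to dst nodes; src nodes that have
// not been seen yet get the subsequent ids.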
std::unordered_map<T, T> node_map;
std::vector<T> unique_nodes;
size_t reindex_id = 0;
for (size_t i = 0; i < unique_dst_merge.size(); i++) {
T node = unique_dst_merge[i];
unique_nodes.emplace_back(node);
node_map[node] = reindex_id++;
}
for (size_t i = 0; i < src_merge.size(); i++) {
T node = src_merge[i];
if (node_map.find(node) == node_map.end()) {
unique_nodes.emplace_back(node);
node_map[node] = reindex_id++;
}
src_merge[i] = node_map[node];
}
std::vector<T> dst_merge(src_merge.size());
size_t cnt = 0;
for (size_t i = 0; i < unique_dst_merge.size(); i++) {
for (T j = 0; j < dst_sample_counts_merge[i]; j++) {
T node = unique_dst_merge[i];
dst_merge[cnt++] = node_map[node];
}
}
// 7. Get Reindex_X for input nodes.
auto* reindex_x = ctx.Output<Tensor>("Reindex_X");
T* p_reindex_x = reindex_x->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < bs; i++) {
p_reindex_x[i] = node_map[p_vertices[i]];
}
// 8. Get operator's outputs.
auto* sample_index = ctx.Output<Tensor>("Sample_Index");
auto* out_src = ctx.Output<Tensor>("Out_Src");
auto* out_dst = ctx.Output<Tensor>("Out_Dst");
sample_index->Resize({static_cast<int>(unique_nodes.size())});
out_src->Resize({static_cast<int>(src_merge.size()), 1});
out_dst->Resize({static_cast<int>(src_merge.size()), 1});
T* p_sample_index = sample_index->mutable_data<T>(ctx.GetPlace());
T* p_out_src = out_src->mutable_data<T>(ctx.GetPlace());
T* p_out_dst = out_dst->mutable_data<T>(ctx.GetPlace());
std::copy(unique_nodes.begin(), unique_nodes.end(), p_sample_index);
std::copy(src_merge.begin(), src_merge.end(), p_out_src);
std::copy(dst_merge.begin(), dst_merge.end(), p_out_dst);
}
};
} // namespace operators
} // namespace paddle
......@@ -2563,6 +2563,72 @@ void BindImperative(py::module *m_ptr) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
#if defined(PADDLE_WITH_CUDA)
m.def("to_uva_tensor",
[](const py::object &obj, int device_id) {
const auto &tracer = imperative::GetCurrentTracer();
auto new_tensor = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
auto array = obj.cast<py::array>();
if (py::isinstance<py::array_t<int32_t>>(array)) {
SetUVATensorFromPyArray<int32_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int64_t>>(array)) {
SetUVATensorFromPyArray<int64_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<float>>(array)) {
SetUVATensorFromPyArray<float>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<double>>(array)) {
SetUVATensorFromPyArray<double>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int8_t>>(array)) {
SetUVATensorFromPyArray<int8_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<int16_t>>(array)) {
SetUVATensorFromPyArray<int16_t>(new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<paddle::platform::float16>>(
array)) {
SetUVATensorFromPyArray<paddle::platform::float16>(
new_tensor, array, device_id);
} else if (py::isinstance<py::array_t<bool>>(array)) {
SetUVATensorFromPyArray<bool>(new_tensor, array, device_id);
} else {
// obj may be of any type, and obj.cast<py::array>() may fail; in that
// case array.dtype would be a string of unknown meaning.
PADDLE_THROW(platform::errors::InvalidArgument(
"Input object type error or incompatible array data type. "
"tensor.set() supports arrays with bool, float16, float32, "
"float64, int8, int16, int32 and int64 data types. "
"Please check your input or the input array data type."));
}
return new_tensor;
},
py::arg("obj"), py::arg("device_id") = 0,
py::return_value_policy::reference, R"DOC(
Returns a tensor with UVA (unified virtual addressing), created from a numpy array.
Args:
obj(numpy.ndarray): The input numpy array, currently supporting bool, float16,
float32, float64, int8, int16, int32 and int64 dtypes.
device_id(int, optional): The destination GPU device id.
Default: 0, which means the current device.
Returns:
new_tensor(paddle.Tensor): The UVA Tensor with the same dtype and
shape as the input numpy array.
Examples:
.. code-block:: python
# required: gpu
import numpy as np
import paddle
data = np.random.randint(10, size=(3, 4))
tensor = paddle.fluid.core.to_uva_tensor(data)
print(tensor)
)DOC");
#endif
#if defined(PADDLE_WITH_CUDA)
m.def(
"async_write",
......
......@@ -83,6 +83,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
{"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
{"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
......@@ -448,6 +448,39 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
}
}
template <typename T>
void SetUVATensorFromPyArray(
const std::shared_ptr<paddle::imperative::VarBase> &self,
const py::array_t<T> &array, int device_id) {
#if defined(PADDLE_WITH_CUDA)
auto *self_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
std::vector<int64_t> dims;
dims.reserve(array.ndim());
int64_t numel = 1;
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.emplace_back(static_cast<int64_t>(array.shape()[i]));
numel *= static_cast<int64_t>(array.shape()[i]);
}
self_tensor->Resize(framework::make_ddim(dims));
auto data_type = framework::ToDataType(std::type_index(typeid(T)));
const auto &need_allocate_size = numel * framework::SizeOfType(data_type);
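// Allocate pinned (page-locked) host memory that is write-combined and
// mapped into the CUDA address space, then obtain the matching device
// pointer so kernels can read the array directly without an explicit
// host-to-device copy (zero-copy / UVA).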
T *data_ptr;
cudaHostAlloc(reinterpret_cast<void **>(&data_ptr), need_allocate_size,
cudaHostAllocWriteCombined | cudaHostAllocMapped);
std::memcpy(data_ptr, array.data(), array.nbytes());
void *cuda_device_pointer = nullptr;
cudaHostGetDevicePointer(reinterpret_cast<void **>(&cuda_device_pointer),
reinterpret_cast<void *>(data_ptr), 0);
std::shared_ptr<memory::allocation::Allocation> holder =
std::make_shared<memory::allocation::Allocation>(
cuda_device_pointer, need_allocate_size,
platform::CUDAPlace(device_id));
self_tensor->ResetHolderWithType(holder, data_type);
#endif
}
template <typename T, size_t D>
void _sliceCompute(const framework::Tensor *in, framework::Tensor *out,
const platform::CPUDeviceContext &ctx,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestGraphKhopSampler(unittest.TestCase):
def setUp(self):
num_nodes = 20
edges = np.random.randint(num_nodes, size=(100, 2))
edges = np.unique(edges, axis=0)
edges_id = np.arange(0, len(edges))
sorted_edges = edges[np.argsort(edges[:, 1])]
sorted_eid = edges_id[np.argsort(edges[:, 1])]
# Calculate dst index cumsum counts.
dst_count = np.zeros(num_nodes)
dst_src_dict = {}
for dst in range(0, num_nodes):
true_index = sorted_edges[:, 1] == dst
dst_count[dst] = np.sum(true_index)
dst_src_dict[dst] = sorted_edges[:, 0][true_index]
dst_count = dst_count.astype("int64")
colptr = np.cumsum(dst_count)
colptr = np.insert(colptr, 0, 0)
self.row = sorted_edges[:, 0].astype("int64")
self.colptr = colptr.astype("int64")
self.sorted_eid = sorted_eid.astype("int64")
self.nodes = np.unique(np.random.randint(
num_nodes, size=5)).astype("int64")
self.sample_sizes = [5, 5]
self.dst_src_dict = dst_src_dict
def test_sample_result(self):
paddle.disable_static()
row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
edge_src, edge_dst, sample_index, reindex_nodes = \
paddle.incubate.graph_khop_sampler(row, colptr,
nodes, self.sample_sizes,
return_eids=False)
# Reindex edge_src and edge_dst to original index.
edge_src = edge_src.reshape([-1])
edge_dst = edge_dst.reshape([-1])
sample_index = sample_index.reshape([-1])
for i in range(len(edge_src)):
edge_src[i] = sample_index[edge_src[i]]
edge_dst[i] = sample_index[edge_dst[i]]
for n in self.nodes:
edge_src_n = edge_src[edge_dst == n]
if edge_src_n.shape[0] == 0:
continue
# Ensure no repetitive sample neighbors.
self.assertTrue(
edge_src_n.shape[0] == paddle.unique(edge_src_n).shape[0])
# Ensure the correct sample size.
self.assertTrue(edge_src_n.shape[0] == self.sample_sizes[0] or
edge_src_n.shape[0] == len(self.dst_src_dict[n]))
in_neighbors = np.isin(edge_src_n.numpy(), self.dst_src_dict[n])
# Ensure the correct sample neighbors.
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_uva_sample_result(self):
paddle.disable_static()
if paddle.fluid.core.is_compiled_with_cuda():
row = paddle.fluid.core.to_uva_tensor(
self.row.astype(self.row.dtype))
sorted_eid = paddle.fluid.core.to_uva_tensor(
self.sorted_eid.astype(self.sorted_eid.dtype))
colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes)
edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \
paddle.incubate.graph_khop_sampler(row, colptr,
nodes, self.sample_sizes,
sorted_eids=sorted_eid,
return_eids=True)
edge_src = edge_src.reshape([-1])
edge_dst = edge_dst.reshape([-1])
sample_index = sample_index.reshape([-1])
for i in range(len(edge_src)):
edge_src[i] = sample_index[edge_src[i]]
edge_dst[i] = sample_index[edge_dst[i]]
for n in self.nodes:
edge_src_n = edge_src[edge_dst == n]
if edge_src_n.shape[0] == 0:
continue
self.assertTrue(
edge_src_n.shape[0] == paddle.unique(edge_src_n).shape[0])
self.assertTrue(
edge_src_n.shape[0] == self.sample_sizes[0] or
edge_src_n.shape[0] == len(self.dst_src_dict[n]))
in_neighbors = np.isin(edge_src_n.numpy(), self.dst_src_dict[n])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_static_with_eids(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype)
sorted_eids = paddle.static.data(
name="eids",
shape=self.sorted_eid.shape,
dtype=self.sorted_eid.dtype)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype)
edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \
paddle.incubate.graph_khop_sampler(row, colptr,
nodes, self.sample_sizes,
sorted_eids, True)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
'row': self.row,
'eids': self.sorted_eid,
'colptr': self.colptr,
'nodes': self.nodes
},
fetch_list=[edge_src, edge_dst, sample_index])
edge_src, edge_dst, sample_index = ret
edge_src = edge_src.reshape([-1])
edge_dst = edge_dst.reshape([-1])
sample_index = sample_index.reshape([-1])
for i in range(len(edge_src)):
edge_src[i] = sample_index[edge_src[i]]
edge_dst[i] = sample_index[edge_dst[i]]
for n in self.nodes:
edge_src_n = edge_src[edge_dst == n]
if edge_src_n.shape[0] == 0:
continue
self.assertTrue(
edge_src_n.shape[0] == np.unique(edge_src_n).shape[0])
self.assertTrue(
edge_src_n.shape[0] == self.sample_sizes[0] or
edge_src_n.shape[0] == len(self.dst_src_dict[n]))
in_neighbors = np.isin(edge_src_n, self.dst_src_dict[n])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
def test_sample_result_static_without_eids(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data(
name="row", shape=self.row.shape, dtype=self.row.dtype)
colptr = paddle.static.data(
name="colptr", shape=self.colptr.shape, dtype=self.colptr.dtype)
nodes = paddle.static.data(
name="nodes", shape=self.nodes.shape, dtype=self.nodes.dtype)
edge_src, edge_dst, sample_index, reindex_nodes = \
paddle.incubate.graph_khop_sampler(row, colptr,
nodes, self.sample_sizes)
exe = paddle.static.Executor(paddle.CPUPlace())
ret = exe.run(feed={
'row': self.row,
'colptr': self.colptr,
'nodes': self.nodes
},
fetch_list=[edge_src, edge_dst, sample_index])
edge_src, edge_dst, sample_index = ret
edge_src = edge_src.reshape([-1])
edge_dst = edge_dst.reshape([-1])
sample_index = sample_index.reshape([-1])
for i in range(len(edge_src)):
edge_src[i] = sample_index[edge_src[i]]
edge_dst[i] = sample_index[edge_dst[i]]
for n in self.nodes:
edge_src_n = edge_src[edge_dst == n]
if edge_src_n.shape[0] == 0:
continue
self.assertTrue(
edge_src_n.shape[0] == np.unique(edge_src_n).shape[0])
self.assertTrue(
edge_src_n.shape[0] == self.sample_sizes[0] or
edge_src_n.shape[0] == len(self.dst_src_dict[n]))
in_neighbors = np.isin(edge_src_n, self.dst_src_dict[n])
self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0])
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,6 @@
import paddle
import unittest
import numpy as np
from paddle.fluid.core import LoDTensor as Tensor
class TestTensorCopyFrom(unittest.TestCase):
......@@ -28,5 +27,19 @@ class TestTensorCopyFrom(unittest.TestCase):
self.assertTrue(tensor.place.is_gpu_place())
class TestUVATensorFromNumpy(unittest.TestCase):
def test_uva_tensor_creation(self):
if paddle.fluid.core.is_compiled_with_cuda():
dtype_list = [
"int32", "int64", "float32", "float64", "float16", "int8",
"int16", "bool"
]
for dtype in dtype_list:
data = np.random.randint(10, size=[4, 5]).astype(dtype)
tensor = paddle.fluid.core.to_uva_tensor(data, 0)
self.assertTrue(tensor.place.is_gpu_place())
self.assertTrue(np.allclose(tensor.numpy(), data))
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,7 @@ from ..fluid.layer_helper import LayerHelper # noqa: F401
from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
from .operators import softmax_mask_fuse # noqa: F401
from .operators import graph_send_recv
from .operators import graph_khop_sampler
from .tensor import segment_sum
from .tensor import segment_mean
from .tensor import segment_max
......@@ -33,6 +34,7 @@ __all__ = [
'softmax_mask_fuse_upper_triangle',
'softmax_mask_fuse',
'graph_send_recv',
'graph_khop_sampler',
'segment_sum',
'segment_mean',
'segment_max',
......
......@@ -16,3 +16,4 @@ from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle
from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
from .resnet_unit import ResNetUnit #noqa: F401
from .graph_send_recv import graph_send_recv #noqa: F401
from .graph_khop_sampler import graph_khop_sampler #noqa: F401
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
from paddle import _C_ops
def graph_khop_sampler(row,
colptr,
input_nodes,
sample_sizes,
sorted_eids=None,
return_eids=False,
name=None):
"""
Graph Khop Sampler API.
This API is mainly used in the Graph Learning domain, and its main purpose is to
provide a high-performance graph k-hop sampling method with a subgraph reindexing step.
For example, we take the CSC (Compressed Sparse Column) format of the input graph
edges as `row` and `colptr`, so as to convert the graph data into a suitable format
for sampling. `input_nodes` specifies the nodes we need to sample neighbors for, and
`sample_sizes` specifies the number of neighbors and the number of layers we want
to sample.
**Note**:
Currently the API reindexes the output edges after finishing sampling. We will
add an option or a new API to control whether to reindex the edges in the near future.
Args:
row (Tensor): One of the components of the CSC format of the input graph, and
the shape should be [num_edges, 1] or [num_edges]. The available
data type is int32, int64.
colptr (Tensor): One of the components of the CSC format of the input graph,
and the shape should be [num_nodes + 1, 1] or [num_nodes + 1].
The data type should be the same with `row`.
input_nodes (Tensor): The input nodes we need to sample neighbors for, and the
data type should be the same with `row`.
sample_sizes (list|tuple): The number of neighbors and number of layers we want
to sample. The data type should be int, and the shape
should only have one dimension.
sorted_eids (Tensor): The sorted edge ids, should not be None when `return_eids`
is True. The shape should be [num_edges, 1], and the data
type should be the same with `row`.
return_eids (bool): Whether to return the id of the sample edges. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
edge_src (Tensor): The src index of the output edges, also means the first column of
the edges. The shape is [num_sample_edges, 1] currently.
edge_dst (Tensor): The dst index of the output edges, also means the second column
of the edges. The shape is [num_sample_edges, 1] currently.
sample_index (Tensor): The original id of the input nodes and sampled neighbor nodes.
reindex_nodes (Tensor): The reindex id of the input nodes.
edge_eids (Tensor): Return the id of the sample edges if `return_eids` is True.
Examples:
.. code-block:: python
import paddle
row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7]
colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13]
nodes = [0, 8, 1, 2]
sample_sizes = [2, 2]
row = paddle.to_tensor(row, dtype="int64")
colptr = paddle.to_tensor(colptr, dtype="int64")
nodes = paddle.to_tensor(nodes, dtype="int64")
edge_src, edge_dst, sample_index, reindex_nodes = \
paddle.incubate.graph_khop_sampler(row, colptr, nodes, sample_sizes, return_eids=False)
"""
if in_dygraph_mode():
if return_eids:
if sorted_eids is None:
raise ValueError(f"`sorted_eid` should not be None "
f"if return_eids is True.")
edge_src, edge_dst, sample_index, reindex_nodes, edge_eids = \
_C_ops.graph_khop_sampler(row, sorted_eids,
colptr, input_nodes,
"sample_sizes", sample_sizes,
"return_eids", True)
return edge_src, edge_dst, sample_index, reindex_nodes, edge_eids
else:
edge_src, edge_dst, sample_index, reindex_nodes, _ = \
_C_ops.graph_khop_sampler(row, None,
colptr, input_nodes,
"sample_sizes", sample_sizes,
"return_eids", False)
return edge_src, edge_dst, sample_index, reindex_nodes
check_variable_and_dtype(row, "Row", ("int32", "int64"),
"graph_khop_sampler")
if return_eids:
if sorted_eids is None:
raise ValueError(f"`sorted_eid` should not be None "
f"if return_eids is True.")
check_variable_and_dtype(sorted_eids, "Eids", ("int32", "int64"),
"graph_khop_sampler")
check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"),
"graph_khop_sampler")
check_variable_and_dtype(input_nodes, "X", ("int32", "int64"),
"graph_khop_sampler")
helper = LayerHelper("graph_khop_sampler", **locals())
edge_src = helper.create_variable_for_type_inference(dtype=row.dtype)
edge_dst = helper.create_variable_for_type_inference(dtype=row.dtype)
sample_index = helper.create_variable_for_type_inference(dtype=row.dtype)
reindex_nodes = helper.create_variable_for_type_inference(dtype=row.dtype)
edge_eids = helper.create_variable_for_type_inference(dtype=row.dtype)
helper.append_op(
type="graph_khop_sampler",
inputs={
"Row": row,
"Eids": sorted_eids,
"Col_Ptr": colptr,
"X": input_nodes
},
outputs={
"Out_Src": edge_src,
"Out_Dst": edge_dst,
"Sample_Index": sample_index,
"Reindex_X": reindex_nodes,
"Out_Eids": edge_eids
},
attrs={"sample_sizes": sample_sizes,
"return_eids": return_eids})
if return_eids:
return edge_src, edge_dst, sample_index, reindex_nodes, edge_eids
else:
return edge_src, edge_dst, sample_index, reindex_nodes
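As a reference for preparing the inputs, here is a small sketch that converts an edge list into the CSC `row`/`colptr` pair this API expects. It mirrors the unit test in this commit; the random graph is illustrative only.

import numpy as np
import paddle

num_nodes = 20
edges = np.unique(np.random.randint(num_nodes, size=(100, 2)), axis=0)

# Sort edges by dst so each node's neighbors are contiguous in `row`.
sorted_edges = edges[np.argsort(edges[:, 1])]
row = sorted_edges[:, 0].astype("int64")

# colptr[i]:colptr[i + 1] delimits the neighbors of node i inside `row`.
dst_count = np.bincount(sorted_edges[:, 1], minlength=num_nodes)
colptr = np.insert(np.cumsum(dst_count), 0, 0).astype("int64")

nodes = np.unique(np.random.randint(num_nodes, size=5)).astype("int64")
edge_src, edge_dst, sample_index, reindex_nodes = \
    paddle.incubate.graph_khop_sampler(paddle.to_tensor(row),
                                       paddle.to_tensor(colptr),
                                       paddle.to_tensor(nodes), [2, 2])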