Unverified commit 59be2f3b authored by Siming Dai, committed by GitHub

[GNN] Fix graph sample and data type bug (#45001)

Parent 125e48c3
@@ -670,7 +670,7 @@ void BindImperative(py::module *m_ptr) {
       .def("__init__",
            [](imperative::VarBase &self,
               framework::proto::VarType::Type dtype,
-              const std::vector<int> &dims,
+              const std::vector<int64_t> &dims,
               const py::handle &name,
               framework::proto::VarType::Type type,
               bool persistable) {
...
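Note: the hunk above widens the eager-mode VarBase constructor's dims argument from std::vector<int> to std::vector<int64_t>, so shapes with a dimension beyond INT_MAX survive the Python-to-C++ boundary. A minimal standalone sketch (not Paddle code) of the truncation the old signature allowed:

#include <cstdint>
#include <iostream>

int main() {
  // A single dimension larger than INT_MAX, e.g. a flat tensor
  // with 3 billion elements.
  int64_t big_dim = 3'000'000'000LL;

  // Narrowing to int typically wraps around, turning the shape
  // into a negative value.
  int truncated = static_cast<int>(big_dim);
  std::cout << "int64_t dim: " << big_dim << "\n";    // 3000000000
  std::cout << "int dim:     " << truncated << "\n";  // -1294967296
}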
@@ -191,10 +191,10 @@ static void ParseIndexingSlice(framework::LoDTensor* tensor,
     PyObject* slice_item = PyTuple_GetItem(index, i);
     infer_flags->push_back(1);
-    int dim_len = shape[dim];
+    int64_t dim_len = shape[dim];
     if (PyCheckInteger(slice_item) || IsNumpyType(slice_item)) {
       // integer, PyLong_AsLong supports both int and long
-      int start = static_cast<int>(PyLong_AsLong(slice_item));
+      int64_t start = static_cast<int64_t>(PyLong_AsLong(slice_item));
       auto s_t = start;
       start = start < 0 ? start + dim_len : start;
...
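Note: ParseIndexingSlice now keeps both the dimension length and the parsed index in 64 bits; PyLong_AsLong already returns a long (64-bit on LP64 platforms), so the old static_cast<int> could truncate indices into very long dimensions. A hedged sketch of the negative-index normalization this protects:

#include <cstdint>
#include <iostream>

int main() {
  // Python-style negative index against a dimension longer than INT_MAX.
  int64_t dim_len = 5'000'000'000LL;  // would not fit in a 32-bit int
  int64_t start = -1;                 // "last element"
  if (start < 0) start += dim_len;    // 4999999999, representable only in 64 bits
  std::cout << start << "\n";
}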
@@ -368,7 +368,7 @@ void SetTensorFromPyArrayT(
   std::vector<int64_t> dims;
   dims.reserve(array.ndim());
   for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
-    dims.push_back(static_cast<int>(array.shape()[i]));
+    dims.push_back(static_cast<int64_t>(array.shape()[i]));
   }
   self->Resize(phi::make_ddim(dims));
@@ -612,8 +612,8 @@ void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor,
   dims.reserve(array.ndim());
   int64_t numel = 1;
   for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
-    dims.emplace_back(static_cast<int>(array.shape()[i]));
-    numel *= static_cast<int>(array.shape()[i]);
+    dims.emplace_back(static_cast<int64_t>(array.shape()[i]));
+    numel *= static_cast<int64_t>(array.shape()[i]);
   }
   self_tensor->Resize(phi::make_ddim(dims));
...
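Note: in SetUVATensorFromPyArrayImpl both the per-dimension sizes and the running element count are now accumulated in int64_t. Even when every individual dimension fits in 32 bits, their product can overflow. A small standalone illustration:

#include <cstdint>
#include <iostream>

int main() {
  // Each dimension fits comfortably in 32 bits...
  int64_t dims[] = {65536, 65536};
  int64_t numel64 = 1;
  int32_t numel32 = 1;
  for (int64_t d : dims) {
    numel64 *= d;
    // ...but the product does not: the low 32 bits of 2^32 are all zero.
    numel32 = static_cast<int32_t>(numel32 * d);
  }
  std::cout << numel64 << " vs " << numel32 << "\n";  // 4294967296 vs 0
}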
@@ -37,9 +37,13 @@ namespace phi {
 template <typename T>
 struct DegreeFunctor {
   const T* col_ptr;
-  HOSTDEVICE explicit inline DegreeFunctor(const T* x) { this->col_ptr = x; }
+  int64_t len_col_ptr;
+  HOSTDEVICE explicit inline DegreeFunctor(const T* x, int64_t len_col_ptr) {
+    this->col_ptr = x;
+    this->len_col_ptr = len_col_ptr;
+  }
   HOSTDEVICE inline int operator()(T i) const {
-    return col_ptr[i + 1] - col_ptr[i];
+    return i > len_col_ptr - 1 ? 0 : col_ptr[i + 1] - col_ptr[i];
   }
 };
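Note: the reworked DegreeFunctor is the heart of the sampling fix: a node id that falls outside the CSC col_ptr array now reports degree 0 instead of indexing past the buffer. A host-side C++ sketch that mirrors the guarded lookup (the real functor runs on the device through thrust::transform):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the functor's guard: ids beyond the col_ptr range yield 0.
int64_t Degree(const std::vector<int64_t>& col_ptr, int64_t i) {
  int64_t len_col_ptr = static_cast<int64_t>(col_ptr.size());
  return i > len_col_ptr - 1 ? 0 : col_ptr[i + 1] - col_ptr[i];
}

int main() {
  std::vector<int64_t> col_ptr = {0, 2, 5, 6};  // CSC offsets for 3 nodes
  std::cout << Degree(col_ptr, 1) << "\n";   // 3 neighbors
  std::cout << Degree(col_ptr, 10) << "\n";  // 0: out-of-range id, no OOB read
}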
@@ -58,6 +62,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void SampleKernel(const uint64_t rand_seed,
                              int k,
                              const int64_t num_nodes,
+                             const int64_t len_col_ptr,
                              const T* nodes,
                              const T* row,
                              const T* col_ptr,
@@ -88,6 +93,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
   while (out_row < last_row) {
     T node = nodes[out_row];
+    if (node > len_col_ptr - 1) {
+      out_row += BLOCK_WARPS;
+      continue;
+    }
     T in_row_start = col_ptr[node];
     int deg = col_ptr[node + 1] - in_row_start;
     int out_row_start = output_ptr[out_row];
@@ -139,10 +148,12 @@ __global__ void SampleKernel(const uint64_t rand_seed,
 template <typename T, typename Context>
 int GetTotalSampleNum(const thrust::device_ptr<const T> input,
                       const T* col_ptr,
+                      int64_t len_col_ptr,
                       thrust::device_ptr<int> output_count,
                       int sample_size,
                       int bs) {
-  thrust::transform(input, input + bs, output_count, DegreeFunctor<T>(col_ptr));
+  thrust::transform(
+      input, input + bs, output_count, DegreeFunctor<T>(col_ptr, len_col_ptr));
   if (sample_size >= 0) {
     thrust::transform(
         output_count, output_count + bs, output_count, MaxFunctor(sample_size));
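Note: GetTotalSampleNum sizes the output buffer by mapping every input node to its neighbor count and summing. Assuming MaxFunctor caps each count at sample_size, as the call site suggests, a CPU analogue would look like this:

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Per-node degrees, capped at sample_size, then summed to get the
  // total number of sampled edges.
  std::vector<int> degrees = {5, 0, 12, 3};
  int sample_size = 4;  // a negative sample_size would mean "take all neighbors"
  std::vector<int> counts(degrees.size());
  std::transform(degrees.begin(), degrees.end(), counts.begin(),
                 [&](int d) { return std::min(d, sample_size); });
  std::cout << std::accumulate(counts.begin(), counts.end(), 0) << "\n";  // 11
}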
@@ -163,6 +174,7 @@ void SampleNeighbors(const Context& dev_ctx,
                      int sample_size,
                      int bs,
                      int total_sample_num,
+                     int64_t len_col_ptr,
                      bool return_eids) {
   thrust::device_vector<int> output_ptr;
   output_ptr.resize(bs);
@@ -179,6 +191,7 @@ void SampleNeighbors(const Context& dev_ctx,
           0,
           sample_size,
           bs,
+          len_col_ptr,
           thrust::raw_pointer_cast(input),
           row,
           col_ptr,
@@ -193,6 +206,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
 __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                         int k,
                                         const int64_t num_rows,
+                                        const int64_t len_col_ptr,
                                         const T* in_rows,
                                         T* src,
                                         const T* dst_count) {
@@ -214,6 +228,10 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
   while (out_row < last_row) {
     const T row = in_rows[out_row];
+    if (row > len_col_ptr - 1) {
+      out_row += BLOCK_WARPS;
+      continue;
+    }
     const T in_row_start = dst_count[row];
     const int deg = dst_count[row + 1] - in_row_start;
     int split;
@@ -312,6 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
                                 int sample_size,
                                 int bs,
                                 int total_sample_num,
+                                int64_t len_col_ptr,
                                 bool return_eids) {
   thrust::device_vector<int> output_ptr;
   output_ptr.resize(bs);
@@ -328,6 +347,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
       <<<grid, block, 0, dev_ctx.stream()>>>(0,
                                              sample_size,
                                              bs,
+                                             len_col_ptr,
                                              thrust::raw_pointer_cast(input),
                                              perm_data,
                                              col_ptr);
@@ -365,6 +385,7 @@ void GraphSampleNeighborsKernel(
   auto* col_ptr_data = col_ptr.data<T>();
   auto* x_data = x.data<T>();
   int bs = x.dims()[0];
+  int64_t len_col_ptr = col_ptr.dims()[0];
   const thrust::device_ptr<const T> input(x_data);
@@ -373,7 +394,7 @@ void GraphSampleNeighborsKernel(
   thrust::device_ptr<int> output_count(out_count_data);
   int total_sample_size = GetTotalSampleNum<T, Context>(
-      input, col_ptr_data, output_count, sample_size, bs);
+      input, col_ptr_data, len_col_ptr, output_count, sample_size, bs);
   out->Resize({static_cast<int>(total_sample_size)});
   T* out_data = dev_ctx.template Alloc<T>(out);
@@ -396,6 +417,7 @@ void GraphSampleNeighborsKernel(
           sample_size,
           bs,
           total_sample_size,
+          len_col_ptr,
           return_eids);
     } else {
       DenseTensor perm_buffer_out(perm_buffer->type());
@@ -414,6 +436,7 @@ void GraphSampleNeighborsKernel(
           sample_size,
           bs,
           total_sample_size,
+          len_col_ptr,
           return_eids);
     }
   } else {
@@ -431,6 +454,7 @@ void GraphSampleNeighborsKernel(
           sample_size,
           bs,
           total_sample_size,
+          len_col_ptr,
           return_eids);
     } else {
       DenseTensor perm_buffer_out(perm_buffer->type());
@@ -449,6 +473,7 @@ void GraphSampleNeighborsKernel(
           sample_size,
           bs,
           total_sample_size,
+          len_col_ptr,
           return_eids);
     }
   }
...