未验证 提交 59be2f3b 编写于 作者: S Siming Dai 提交者: GitHub

[GNN] Fix graph sample and data type bug (#45001)

上级 125e48c3
......@@ -670,7 +670,7 @@ void BindImperative(py::module *m_ptr) {
.def("__init__",
[](imperative::VarBase &self,
framework::proto::VarType::Type dtype,
const std::vector<int> &dims,
const std::vector<int64_t> &dims,
const py::handle &name,
framework::proto::VarType::Type type,
bool persistable) {
......
......@@ -191,10 +191,10 @@ static void ParseIndexingSlice(framework::LoDTensor* tensor,
PyObject* slice_item = PyTuple_GetItem(index, i);
infer_flags->push_back(1);
int dim_len = shape[dim];
int64_t dim_len = shape[dim];
if (PyCheckInteger(slice_item) || IsNumpyType(slice_item)) {
// integer, PyLong_AsLong supports both int and long
int start = static_cast<int>(PyLong_AsLong(slice_item));
int64_t start = static_cast<int64_t>(PyLong_AsLong(slice_item));
auto s_t = start;
start = start < 0 ? start + dim_len : start;
......
......@@ -368,7 +368,7 @@ void SetTensorFromPyArrayT(
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
dims.push_back(static_cast<int64_t>(array.shape()[i]));
}
self->Resize(phi::make_ddim(dims));
......@@ -612,8 +612,8 @@ void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor,
dims.reserve(array.ndim());
int64_t numel = 1;
for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.emplace_back(static_cast<int>(array.shape()[i]));
numel *= static_cast<int>(array.shape()[i]);
dims.emplace_back(static_cast<int64_t>(array.shape()[i]));
numel *= static_cast<int64_t>(array.shape()[i]);
}
self_tensor->Resize(phi::make_ddim(dims));
......
......@@ -37,9 +37,13 @@ namespace phi {
template <typename T>
struct DegreeFunctor {
const T* col_ptr;
HOSTDEVICE explicit inline DegreeFunctor(const T* x) { this->col_ptr = x; }
int64_t len_col_ptr;
HOSTDEVICE explicit inline DegreeFunctor(const T* x, int64_t len_col_ptr) {
this->col_ptr = x;
this->len_col_ptr = len_col_ptr;
}
HOSTDEVICE inline int operator()(T i) const {
return col_ptr[i + 1] - col_ptr[i];
return i > len_col_ptr - 1 ? 0 : col_ptr[i + 1] - col_ptr[i];
}
};
......@@ -58,6 +62,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void SampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_nodes,
const int64_t len_col_ptr,
const T* nodes,
const T* row,
const T* col_ptr,
......@@ -88,6 +93,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
while (out_row < last_row) {
T node = nodes[out_row];
if (node > len_col_ptr - 1) {
out_row += BLOCK_WARPS;
continue;
}
T in_row_start = col_ptr[node];
int deg = col_ptr[node + 1] - in_row_start;
int out_row_start = output_ptr[out_row];
......@@ -139,10 +148,12 @@ __global__ void SampleKernel(const uint64_t rand_seed,
template <typename T, typename Context>
int GetTotalSampleNum(const thrust::device_ptr<const T> input,
const T* col_ptr,
int64_t len_col_ptr,
thrust::device_ptr<int> output_count,
int sample_size,
int bs) {
thrust::transform(input, input + bs, output_count, DegreeFunctor<T>(col_ptr));
thrust::transform(
input, input + bs, output_count, DegreeFunctor<T>(col_ptr, len_col_ptr));
if (sample_size >= 0) {
thrust::transform(
output_count, output_count + bs, output_count, MaxFunctor(sample_size));
......@@ -163,6 +174,7 @@ void SampleNeighbors(const Context& dev_ctx,
int sample_size,
int bs,
int total_sample_num,
int64_t len_col_ptr,
bool return_eids) {
thrust::device_vector<int> output_ptr;
output_ptr.resize(bs);
......@@ -179,6 +191,7 @@ void SampleNeighbors(const Context& dev_ctx,
0,
sample_size,
bs,
len_col_ptr,
thrust::raw_pointer_cast(input),
row,
col_ptr,
......@@ -193,6 +206,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_rows,
const int64_t len_col_ptr,
const T* in_rows,
T* src,
const T* dst_count) {
......@@ -214,6 +228,10 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
while (out_row < last_row) {
const T row = in_rows[out_row];
if (row > len_col_ptr - 1) {
out_row += BLOCK_WARPS;
continue;
}
const T in_row_start = dst_count[row];
const int deg = dst_count[row + 1] - in_row_start;
int split;
......@@ -312,6 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
int sample_size,
int bs,
int total_sample_num,
int64_t len_col_ptr,
bool return_eids) {
thrust::device_vector<int> output_ptr;
output_ptr.resize(bs);
......@@ -328,6 +347,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
<<<grid, block, 0, dev_ctx.stream()>>>(0,
sample_size,
bs,
len_col_ptr,
thrust::raw_pointer_cast(input),
perm_data,
col_ptr);
......@@ -365,6 +385,7 @@ void GraphSampleNeighborsKernel(
auto* col_ptr_data = col_ptr.data<T>();
auto* x_data = x.data<T>();
int bs = x.dims()[0];
int64_t len_col_ptr = col_ptr.dims()[0];
const thrust::device_ptr<const T> input(x_data);
......@@ -373,7 +394,7 @@ void GraphSampleNeighborsKernel(
thrust::device_ptr<int> output_count(out_count_data);
int total_sample_size = GetTotalSampleNum<T, Context>(
input, col_ptr_data, output_count, sample_size, bs);
input, col_ptr_data, len_col_ptr, output_count, sample_size, bs);
out->Resize({static_cast<int>(total_sample_size)});
T* out_data = dev_ctx.template Alloc<T>(out);
......@@ -396,6 +417,7 @@ void GraphSampleNeighborsKernel(
sample_size,
bs,
total_sample_size,
len_col_ptr,
return_eids);
} else {
DenseTensor perm_buffer_out(perm_buffer->type());
......@@ -414,6 +436,7 @@ void GraphSampleNeighborsKernel(
sample_size,
bs,
total_sample_size,
len_col_ptr,
return_eids);
}
} else {
......@@ -431,6 +454,7 @@ void GraphSampleNeighborsKernel(
sample_size,
bs,
total_sample_size,
len_col_ptr,
return_eids);
} else {
DenseTensor perm_buffer_out(perm_buffer->type());
......@@ -449,6 +473,7 @@ void GraphSampleNeighborsKernel(
sample_size,
bs,
total_sample_size,
len_col_ptr,
return_eids);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册