未验证 提交 c31dd04c 编写于 作者: S Siming Dai 提交者: GitHub

Optimization for graph_sample_neighbors API (#41447)

* add eids result for graph_sample_neighbors

* fix bug

* move fisher_yates sample to warp

* add cpu eid output

* delete comment

* delete comment

* change nullptr placeholder

* optimize sample kernel

* fix mutable_data
上级 9f06069d
...@@ -39,17 +39,42 @@ void SampleUniqueNeighbors( ...@@ -39,17 +39,42 @@ void SampleUniqueNeighbors(
} }
} }
template <class bidiiter>
void SampleUniqueNeighborsWithEids(
bidiiter src_begin,
bidiiter src_end,
bidiiter eid_begin,
bidiiter eid_end,
int num_samples,
std::mt19937& rng,
std::uniform_int_distribution<int>& dice_distribution) {
int left_num = std::distance(src_begin, src_end);
for (int i = 0; i < num_samples; i++) {
bidiiter r1 = src_begin, r2 = eid_begin;
int random_step = dice_distribution(rng) % left_num;
std::advance(r1, random_step);
std::advance(r2, random_step);
std::swap(*src_begin, *r1);
std::swap(*eid_begin, *r2);
++src_begin;
++eid_begin;
--left_num;
}
}
template <typename T> template <typename T>
void SampleNeighbors(const T* row, void SampleNeighbors(const T* row,
const T* col_ptr, const T* col_ptr,
const T* eids,
const T* input, const T* input,
std::vector<T>* output, std::vector<T>* output,
std::vector<int>* output_count, std::vector<int>* output_count,
std::vector<T>* output_eids,
int sample_size, int sample_size,
int bs) { int bs,
// Allocate the memory of output bool return_eids) {
// Collect the neighbors size
std::vector<std::vector<T>> out_src_vec; std::vector<std::vector<T>> out_src_vec;
std::vector<std::vector<T>> out_eids_vec;
// `sample_cumsum_sizes` record the start position and end position // `sample_cumsum_sizes` record the start position and end position
// after sampling. // after sampling.
std::vector<int> sample_cumsum_sizes(bs + 1); std::vector<int> sample_cumsum_sizes(bs + 1);
...@@ -65,10 +90,18 @@ void SampleNeighbors(const T* row, ...@@ -65,10 +90,18 @@ void SampleNeighbors(const T* row,
std::vector<T> out_src; std::vector<T> out_src;
out_src.resize(cap); out_src.resize(cap);
out_src_vec.emplace_back(out_src); out_src_vec.emplace_back(out_src);
if (return_eids) {
std::vector<T> out_eids;
out_eids.resize(cap);
out_eids_vec.emplace_back(out_eids);
}
} }
output_count->resize(bs); output_count->resize(bs);
output->resize(total_neighbors); output->resize(total_neighbors);
if (return_eids) {
output_eids->resize(total_neighbors);
}
std::random_device rd; std::random_device rd;
std::mt19937 rng{rd()}; std::mt19937 rng{rd()};
...@@ -85,15 +118,28 @@ void SampleNeighbors(const T* row, ...@@ -85,15 +118,28 @@ void SampleNeighbors(const T* row,
int cap = end - begin; int cap = end - begin;
if (sample_size < cap) { if (sample_size < cap) {
std::copy(row + begin, row + end, out_src_vec[i].begin()); std::copy(row + begin, row + end, out_src_vec[i].begin());
// TODO(daisiming): Check whether is correct. if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
SampleUniqueNeighborsWithEids(out_src_vec[i].begin(),
out_src_vec[i].end(),
out_eids_vec[i].begin(),
out_eids_vec[i].end(),
sample_size,
rng,
dice_distribution);
} else {
SampleUniqueNeighbors(out_src_vec[i].begin(), SampleUniqueNeighbors(out_src_vec[i].begin(),
out_src_vec[i].end(), out_src_vec[i].end(),
sample_size, sample_size,
rng, rng,
dice_distribution); dice_distribution);
}
*(output_count->data() + i) = sample_size; *(output_count->data() + i) = sample_size;
} else { } else {
std::copy(row + begin, row + end, out_src_vec[i].begin()); std::copy(row + begin, row + end, out_src_vec[i].begin());
if (return_eids) {
std::copy(eids + begin, eids + end, out_eids_vec[i].begin());
}
*(output_count->data() + i) = cap; *(output_count->data() + i) = cap;
} }
} }
...@@ -107,6 +153,11 @@ void SampleNeighbors(const T* row, ...@@ -107,6 +153,11 @@ void SampleNeighbors(const T* row,
std::copy(out_src_vec[i].begin(), std::copy(out_src_vec[i].begin(),
out_src_vec[i].begin() + k, out_src_vec[i].begin() + k,
output->data() + sample_cumsum_sizes[i]); output->data() + sample_cumsum_sizes[i]);
if (return_eids) {
std::copy(out_eids_vec[i].begin(),
out_eids_vec[i].begin() + k,
output_eids->data() + sample_cumsum_sizes[i]);
}
} }
} }
...@@ -131,8 +182,35 @@ void GraphSampleNeighborsKernel( ...@@ -131,8 +182,35 @@ void GraphSampleNeighborsKernel(
std::vector<T> output; std::vector<T> output;
std::vector<int> output_count; std::vector<int> output_count;
SampleNeighbors<T>(
row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs); if (return_eids) {
const T* eids_data = eids.get_ptr()->data<T>();
std::vector<T> output_eids;
SampleNeighbors<T>(row_data,
col_ptr_data,
eids_data,
x_data,
&output,
&output_count,
&output_eids,
sample_size,
bs,
return_eids);
out_eids->Resize({static_cast<int>(output_eids.size())});
T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
std::copy(output_eids.begin(), output_eids.end(), out_eids_data);
} else {
SampleNeighbors<T>(row_data,
col_ptr_data,
nullptr,
x_data,
&output,
&output_count,
nullptr,
sample_size,
bs,
return_eids);
}
out->Resize({static_cast<int>(output.size())}); out->Resize({static_cast<int>(output.size())});
T* out_data = dev_ctx.template Alloc<T>(out); T* out_data = dev_ctx.template Alloc<T>(out);
std::copy(output.begin(), output.end(), out_data); std::copy(output.begin(), output.end(), out_data);
......
...@@ -62,9 +62,11 @@ __global__ void SampleKernel(const uint64_t rand_seed, ...@@ -62,9 +62,11 @@ __global__ void SampleKernel(const uint64_t rand_seed,
const T* nodes, const T* nodes,
const T* row, const T* row,
const T* col_ptr, const T* col_ptr,
const T* eids,
T* output, T* output,
T* output_eids,
int* output_ptr, int* output_ptr,
int* output_idxs) { bool return_eids) {
assert(blockDim.x == WARP_SIZE); assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS); assert(blockDim.y == BLOCK_WARPS);
...@@ -94,10 +96,13 @@ __global__ void SampleKernel(const uint64_t rand_seed, ...@@ -94,10 +96,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
if (deg <= k) { if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
output[out_row_start + idx] = row[in_row_start + idx]; output[out_row_start + idx] = row[in_row_start + idx];
if (return_eids) {
output_eids[out_row_start + idx] = eids[in_row_start + idx];
}
} }
} else { } else {
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
output_idxs[out_row_start + idx] = idx; output[out_row_start + idx] = idx;
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
__syncwarp(); __syncwarp();
...@@ -111,7 +116,7 @@ __global__ void SampleKernel(const uint64_t rand_seed, ...@@ -111,7 +116,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
#endif #endif
if (num < k) { if (num < k) {
atomicMax(reinterpret_cast<unsigned int*>( // NOLINT atomicMax(reinterpret_cast<unsigned int*>( // NOLINT
output_idxs + out_row_start + num), output + out_row_start + num),
static_cast<unsigned int>(idx)); // NOLINT static_cast<unsigned int>(idx)); // NOLINT
} }
} }
...@@ -120,8 +125,11 @@ __global__ void SampleKernel(const uint64_t rand_seed, ...@@ -120,8 +125,11 @@ __global__ void SampleKernel(const uint64_t rand_seed,
#endif #endif
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
T perm_idx = output_idxs[out_row_start + idx] + in_row_start; T perm_idx = output[out_row_start + idx] + in_row_start;
output[out_row_start + idx] = row[perm_idx]; output[out_row_start + idx] = row[perm_idx];
if (return_eids) {
output_eids[out_row_start + idx] = eids[perm_idx];
}
} }
} }
...@@ -148,16 +156,17 @@ template <typename T, typename Context> ...@@ -148,16 +156,17 @@ template <typename T, typename Context>
void SampleNeighbors(const Context& dev_ctx, void SampleNeighbors(const Context& dev_ctx,
const T* row, const T* row,
const T* col_ptr, const T* col_ptr,
const T* eids,
const thrust::device_ptr<const T> input, const thrust::device_ptr<const T> input,
thrust::device_ptr<T> output, thrust::device_ptr<T> output,
thrust::device_ptr<int> output_count, thrust::device_ptr<int> output_count,
thrust::device_ptr<T> output_eids,
int sample_size, int sample_size,
int bs, int bs,
int total_sample_num) { int total_sample_num,
bool return_eids) {
thrust::device_vector<int> output_ptr; thrust::device_vector<int> output_ptr;
thrust::device_vector<int> output_idxs;
output_ptr.resize(bs); output_ptr.resize(bs);
output_idxs.resize(total_sample_num);
thrust::exclusive_scan( thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0); output_count, output_count + bs, output_ptr.begin(), 0);
...@@ -176,18 +185,26 @@ void SampleNeighbors(const Context& dev_ctx, ...@@ -176,18 +185,26 @@ void SampleNeighbors(const Context& dev_ctx,
thrust::raw_pointer_cast(input), thrust::raw_pointer_cast(input),
row, row,
col_ptr, col_ptr,
eids,
thrust::raw_pointer_cast(output), thrust::raw_pointer_cast(output),
thrust::raw_pointer_cast(output_eids),
thrust::raw_pointer_cast(output_ptr.data()), thrust::raw_pointer_cast(output_ptr.data()),
thrust::raw_pointer_cast(output_idxs.data())); return_eids);
} }
template <typename T> template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed, __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k, int k,
const int64_t num_rows, const int64_t num_rows,
const T* in_rows, const T* in_rows,
T* src, T* src,
const T* dst_count) { const T* dst_count) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hiprandState rng; hiprandState rng;
hiprand_init( hiprand_init(
...@@ -197,20 +214,19 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -197,20 +214,19 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
curand_init( curand_init(
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif #endif
CUDA_KERNEL_LOOP(out_row, num_rows) {
while (out_row < last_row) {
const T row = in_rows[out_row]; const T row = in_rows[out_row];
const T in_row_start = dst_count[row]; const T in_row_start = dst_count[row];
const int deg = dst_count[row + 1] - in_row_start; const int deg = dst_count[row + 1] - in_row_start;
int split; int split;
T tmp;
if (k < deg) { if (k < deg) {
if (deg < 2 * k) { if (deg < 2 * k) {
split = k; split = k;
} else { } else {
split = deg - k; split = deg - k;
} }
for (int idx = deg - 1; idx >= split; idx--) { for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1); const int num = hiprand(&rng) % (idx + 1);
#else #else
...@@ -222,7 +238,11 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -222,7 +238,11 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
static_cast<unsigned long long int>( // NOLINT static_cast<unsigned long long int>( // NOLINT
src[in_row_start + idx]))); src[in_row_start + idx])));
} }
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
} }
out_row += BLOCK_WARPS;
} }
} }
...@@ -232,9 +252,12 @@ __global__ void GatherEdge(int k, ...@@ -232,9 +252,12 @@ __global__ void GatherEdge(int k,
const T* in_rows, const T* in_rows,
const T* src, const T* src,
const T* dst_count, const T* dst_count,
const T* eids,
T* outputs, T* outputs,
T* output_eids,
int* output_ptr, int* output_ptr,
T* perm_data) { T* perm_data,
bool return_eids) {
assert(blockDim.x == WARP_SIZE); assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS); assert(blockDim.y == BLOCK_WARPS);
...@@ -250,8 +273,10 @@ __global__ void GatherEdge(int k, ...@@ -250,8 +273,10 @@ __global__ void GatherEdge(int k,
if (deg <= k) { if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
const T in_idx = in_row_start + idx; outputs[out_row_start + idx] = src[in_row_start + idx];
outputs[out_row_start + idx] = src[in_idx]; if (return_eids) {
output_eids[out_row_start + idx] = eids[in_row_start + idx];
}
} }
} else { } else {
int split = k; int split = k;
...@@ -267,6 +292,10 @@ __global__ void GatherEdge(int k, ...@@ -267,6 +292,10 @@ __global__ void GatherEdge(int k,
for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) { for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) {
outputs[out_row_start + idx - begin] = outputs[out_row_start + idx - begin] =
src[perm_data[in_row_start + idx]]; src[perm_data[in_row_start + idx]];
if (return_eids) {
output_eids[out_row_start + idx - begin] =
eids[perm_data[in_row_start + idx]];
}
} }
} }
out_row += BLOCK_WARPS; out_row += BLOCK_WARPS;
...@@ -277,49 +306,48 @@ template <typename T, typename Context> ...@@ -277,49 +306,48 @@ template <typename T, typename Context>
void FisherYatesSampleNeighbors(const Context& dev_ctx, void FisherYatesSampleNeighbors(const Context& dev_ctx,
const T* row, const T* row,
const T* col_ptr, const T* col_ptr,
const T* eids,
T* perm_data, T* perm_data,
const thrust::device_ptr<const T> input, const thrust::device_ptr<const T> input,
thrust::device_ptr<T> output, thrust::device_ptr<T> output,
thrust::device_ptr<int> output_count, thrust::device_ptr<int> output_count,
thrust::device_ptr<T> output_eids,
int sample_size, int sample_size,
int bs, int bs,
int total_sample_num) { int total_sample_num,
bool return_eids) {
thrust::device_vector<int> output_ptr; thrust::device_vector<int> output_ptr;
output_ptr.resize(bs); output_ptr.resize(bs);
thrust::exclusive_scan( thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0); output_count, output_count + bs, output_ptr.begin(), 0);
#ifdef PADDLE_WITH_HIP constexpr int WARP_SIZE = 32;
int block = 256; constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
#else constexpr int TILE_SIZE = BLOCK_WARPS * 16;
int block = 1024; const dim3 block(WARP_SIZE, BLOCK_WARPS);
#endif const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0];
int grid_tmp = (bs + block - 1) / block;
int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
FisherYatesSampleKernel<T><<<grid, block, 0, dev_ctx.stream()>>>( FisherYatesSampleKernel<T,
WARP_SIZE,
BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
0, sample_size, bs, thrust::raw_pointer_cast(input), perm_data, col_ptr); 0, sample_size, bs, thrust::raw_pointer_cast(input), perm_data, col_ptr);
constexpr int GATHER_WARP_SIZE = 32; GatherEdge<T,
constexpr int GATHER_BLOCK_WARPS = 128 / GATHER_WARP_SIZE; WARP_SIZE,
constexpr int GATHER_TILE_SIZE = GATHER_BLOCK_WARPS * 16; BLOCK_WARPS,
const dim3 gather_block(GATHER_WARP_SIZE, GATHER_BLOCK_WARPS); TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
const dim3 gather_grid((bs + GATHER_TILE_SIZE - 1) / GATHER_TILE_SIZE);
GatherEdge<
T,
GATHER_WARP_SIZE,
GATHER_BLOCK_WARPS,
GATHER_TILE_SIZE><<<gather_grid, gather_block, 0, dev_ctx.stream()>>>(
sample_size, sample_size,
bs, bs,
thrust::raw_pointer_cast(input), thrust::raw_pointer_cast(input),
row, row,
col_ptr, col_ptr,
eids,
thrust::raw_pointer_cast(output), thrust::raw_pointer_cast(output),
thrust::raw_pointer_cast(output_eids),
thrust::raw_pointer_cast(output_ptr.data()), thrust::raw_pointer_cast(output_ptr.data()),
perm_data); perm_data,
return_eids);
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -354,32 +382,78 @@ void GraphSampleNeighborsKernel( ...@@ -354,32 +382,78 @@ void GraphSampleNeighborsKernel(
T* out_data = dev_ctx.template Alloc<T>(out); T* out_data = dev_ctx.template Alloc<T>(out);
thrust::device_ptr<T> output(out_data); thrust::device_ptr<T> output(out_data);
if (return_eids) {
auto* eids_data = eids.get_ptr()->data<T>();
out_eids->Resize({static_cast<int>(total_sample_size)});
T* out_eids_data = dev_ctx.template Alloc<T>(out_eids);
thrust::device_ptr<T> output_eids(out_eids_data);
if (!flag_perm_buffer) {
SampleNeighbors<T, Context>(dev_ctx,
row_data,
col_ptr_data,
eids_data,
input,
output,
output_count,
output_eids,
sample_size,
bs,
total_sample_size,
return_eids);
} else {
DenseTensor perm_buffer_out(perm_buffer->type());
const auto* p_perm_buffer = perm_buffer.get_ptr();
perm_buffer_out.ShareDataWith(*p_perm_buffer);
T* perm_buffer_out_data = perm_buffer_out.template data<T>();
FisherYatesSampleNeighbors<T, Context>(dev_ctx,
row_data,
col_ptr_data,
eids_data,
perm_buffer_out_data,
input,
output,
output_count,
output_eids,
sample_size,
bs,
total_sample_size,
return_eids);
}
} else {
// How to set null value for output_eids(thrust::device_ptr<T>)?
// We use `output` to fill the position of unused output_eids.
if (!flag_perm_buffer) { if (!flag_perm_buffer) {
SampleNeighbors<T, Context>(dev_ctx, SampleNeighbors<T, Context>(dev_ctx,
row_data, row_data,
col_ptr_data, col_ptr_data,
nullptr,
input, input,
output, output,
output_count, output_count,
output,
sample_size, sample_size,
bs, bs,
total_sample_size); total_sample_size,
return_eids);
} else { } else {
DenseTensor perm_buffer_out(perm_buffer->type()); DenseTensor perm_buffer_out(perm_buffer->type());
const auto* p_perm_buffer = perm_buffer.get_ptr(); const auto* p_perm_buffer = perm_buffer.get_ptr();
perm_buffer_out.ShareDataWith(*p_perm_buffer); perm_buffer_out.ShareDataWith(*p_perm_buffer);
T* perm_buffer_out_data = T* perm_buffer_out_data = perm_buffer_out.template data<T>();
perm_buffer_out.mutable_data<T>(dev_ctx.GetPlace());
FisherYatesSampleNeighbors<T, Context>(dev_ctx, FisherYatesSampleNeighbors<T, Context>(dev_ctx,
row_data, row_data,
col_ptr_data, col_ptr_data,
nullptr,
perm_buffer_out_data, perm_buffer_out_data,
input, input,
output, output,
output_count, output_count,
output,
sample_size, sample_size,
bs, bs,
total_sample_size); total_sample_size,
return_eids);
}
} }
} }
......
...@@ -162,14 +162,14 @@ class TestGraphSampleNeighbors(unittest.TestCase): ...@@ -162,14 +162,14 @@ class TestGraphSampleNeighbors(unittest.TestCase):
self.assertRaises(ValueError, check_perm_buffer_error) self.assertRaises(ValueError, check_perm_buffer_error)
def test_sample_result_with_eids(self): def test_sample_result_with_eids(self):
# Note: Currently return eid results is not initialized.
paddle.disable_static() paddle.disable_static()
row = paddle.to_tensor(self.row) row = paddle.to_tensor(self.row)
colptr = paddle.to_tensor(self.colptr) colptr = paddle.to_tensor(self.colptr)
nodes = paddle.to_tensor(self.nodes) nodes = paddle.to_tensor(self.nodes)
eids = paddle.to_tensor(self.edges_id) eids = paddle.to_tensor(self.edges_id)
perm_buffer = paddle.to_tensor(self.edges_id)
out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors(
row, row,
colptr, colptr,
nodes, nodes,
...@@ -177,6 +177,16 @@ class TestGraphSampleNeighbors(unittest.TestCase): ...@@ -177,6 +177,16 @@ class TestGraphSampleNeighbors(unittest.TestCase):
sample_size=self.sample_size, sample_size=self.sample_size,
return_eids=True) return_eids=True)
out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors(
row,
colptr,
nodes,
eids=eids,
perm_buffer=perm_buffer,
sample_size=self.sample_size,
return_eids=True,
flag_perm_buffer=True)
paddle.enable_static() paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
row = paddle.static.data( row = paddle.static.data(
...@@ -188,7 +198,7 @@ class TestGraphSampleNeighbors(unittest.TestCase): ...@@ -188,7 +198,7 @@ class TestGraphSampleNeighbors(unittest.TestCase):
eids = paddle.static.data( eids = paddle.static.data(
name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype) name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype)
out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors(
row, row,
colptr, colptr,
nodes, nodes,
...@@ -202,7 +212,7 @@ class TestGraphSampleNeighbors(unittest.TestCase): ...@@ -202,7 +212,7 @@ class TestGraphSampleNeighbors(unittest.TestCase):
'nodes': self.nodes, 'nodes': self.nodes,
'eids': self.edges_id 'eids': self.edges_id
}, },
fetch_list=[out_neighbors, out_count]) fetch_list=[out_neighbors, out_count, out_eids])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册