Unverified commit 78fad09b, authored by Siming Dai, committed by GitHub

[GpuPs] Update graph sampling method (#40085)

* gpu ps graph engine

* remove logs

* Add neighbor sampling method

* Add actual_sample_size and offset for sampling

* Delete Chinese comment

* Fix code style
Co-authored-by: seemingwang <zsasuke@qq.com>
Parent 564dcd52
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <thrust/host_vector.h>
#include "heter_comm.h" #include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
...@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> { ...@@ -40,11 +41,13 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int sample_size, int len); int sample_size, int len);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size); NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info(); void clear_graph_info();
  void move_neighbor_sample_result_to_source_gpu(
      int gpu_id, int gpu_num, int *h_left, int *h_right,
      int64_t *src_sample_res, thrust::host_vector<int> &total_sample_size);
  void move_neighbor_sample_size_to_source_gpu(int gpu_id, int gpu_num,
                                               int *h_left, int *h_right,
                                               int *actual_sample_size,
                                               int *total_sample_size);
  int init_cpu_table(const paddle::distributed::GraphParameter &graph);
  int load(const std::string &path, const std::string &param);
  virtual int32_t end_graph_sampling() {
...
@@ -13,10 +13,23 @@
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle {
namespace framework {
constexpr int WARP_SIZE = 32;
/*
comment 0
this kernel just serves as an example of how to sample nodes' neighbors.
@@ -29,20 +42,79 @@ sample_size;
*/
struct MaxFunctor {
  int sample_size;
  HOSTDEVICE explicit inline MaxFunctor(int sample_size) {
    this->sample_size = sample_size;
  }
  HOSTDEVICE inline int operator()(int x) const {
if (x > sample_size) {
return sample_size;
}
return x;
}
};
struct DegreeFunctor {
GpuPsCommGraph graph;
HOSTDEVICE explicit inline DegreeFunctor(GpuPsCommGraph graph) {
this->graph = graph;
}
HOSTDEVICE inline int operator()(int i) const {
return graph.node_list[i].neighbor_size;
}
};
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample(const uint64_t rand_seed, GpuPsCommGraph graph,
int sample_size, int* index, int len,
int64_t* sample_result, int* output_idx,
int* output_offset) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
curandState rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
while (i < last_idx) {
    auto node_index = index[i];
    int degree = graph.node_list[node_index].neighbor_size;
    const int offset = graph.node_list[node_index].neighbor_offset;
    int output_start = output_offset[i];

    if (degree <= sample_size) {
      // Just copy
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
sample_result[output_start + j] = graph.neighbor_list[offset + j];
}
} else {
for (int j = threadIdx.x; j < degree; j += WARP_SIZE) {
output_idx[output_start + j] = j;
}
__syncwarp();
for (int j = sample_size + threadIdx.x; j < degree; j += WARP_SIZE) {
const int num = curand(&rng) % (j + 1);
if (num < sample_size) {
atomicMax(
reinterpret_cast<unsigned int*>(output_idx + output_start + num),
static_cast<unsigned int>(j));
}
}
__syncwarp();
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
const int perm_idx = output_idx[output_start + j] + offset;
sample_result[output_start + j] = graph.neighbor_list[perm_idx];
}
    }
i += BLOCK_WARPS;
  }
}
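// [Editor's note, illustration only -- not part of this commit] When
// degree > sample_size, the kernel above performs a warp-parallel variant of
// reservoir sampling: output_idx is seeded with 0..degree-1, every candidate
// j >= sample_size draws a random slot in [0, j], and a winning candidate
// claims that slot. The device code uses atomicMax so that lanes racing for
// the same slot resolve the collision consistently for a fixed seed, which is
// where it departs from the textbook sequential algorithm sketched below
// (hypothetical helper name, std::rand standing in for curand).
#include <cstdlib>
#include <vector>
inline std::vector<int> reservoir_sample_indices_ref(int degree,
                                                     int sample_size) {
  std::vector<int> idx(degree);
  for (int j = 0; j < degree; ++j) idx[j] = j;  // mirrors the output_idx init
  for (int j = sample_size; j < degree; ++j) {
    int num = std::rand() % (j + 1);      // curand(&rng) % (j + 1) on the device
    if (num < sample_size) idx[num] = j;  // plain overwrite; device uses atomicMax
  }
  idx.resize(sample_size);  // the first sample_size entries index neighbor_list
  return idx;
}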
@@ -79,7 +151,7 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
  gpu i triggers a neighbor_sample task,
  when this task is done,
  this function is called to move the sample result on other gpu back
  to gpu i and aggregate the result.
  the sample_result is saved on src_sample_res and the actual sample size for
  each node is saved on actual_sample_size.
  the number of actual sample_result for
@@ -96,10 +168,50 @@ int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
  that's what fill_dvals does.
*/
void GpuPsGraphTable::move_neighbor_sample_size_to_source_gpu(
int gpu_id, int gpu_num, int* h_left, int* h_right, int* actual_sample_size,
int* total_sample_size) {
  // This function copies actual_sample_size back to the source gpu,
  // and calculates total_sample_size, i.e. the number of samples on each gpu.
for (int i = 0; i < gpu_num; i++) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto shard_len = h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.front();
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
}
for (int i = 0; i < gpu_num; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
total_sample_size[i] = 0;
continue;
}
auto& node = path_[gpu_id][i].nodes_.front();
cudaStreamSynchronize(node.out_stream);
auto shard_len = h_right[i] - h_left[i] + 1;
thrust::device_vector<int> t_actual_sample_size(shard_len);
thrust::copy(actual_sample_size + h_left[i],
actual_sample_size + h_left[i] + shard_len,
t_actual_sample_size.begin());
total_sample_size[i] = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
}
}
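// [Editor's note, illustration only] With gpu_num = 4 and hypothetical
// per-shard actual sample sizes of {1, 3}, {2, 3}, {1} and {2, 4} on the four
// remote gpus, the per-shard thrust::reduce calls above produce
// total_sample_size = [4, 5, 1, 6] -- exactly the input assumed by the
// exclusive_scan example in the next function.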
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
    int gpu_id, int gpu_num, int* h_left, int* h_right, int64_t* src_sample_res,
    thrust::host_vector<int>& total_sample_size) {
/*
if total_sample_size is [4, 5, 1, 6],
then cumsum_total_sample_size is [0, 4, 9, 10];
*/
thrust::host_vector<int> cumsum_total_sample_size(gpu_num, 0);
thrust::exclusive_scan(total_sample_size.begin(), total_sample_size.end(),
cumsum_total_sample_size.begin(), 0);
  for (int i = 0; i < gpu_num; i++) {
    if (h_left[i] == -1 || h_right[i] == -1) {
      continue;
@@ -109,14 +221,10 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
    // auto& node = path_[gpu_id][i].nodes_[cur_step];
    auto& node = path_[gpu_id][i].nodes_.front();
    cudaMemcpyAsync(
        reinterpret_cast<char*>(src_sample_res + cumsum_total_sample_size[i]),
        node.val_storage + sizeof(int64_t) * shard_len,
        sizeof(int64_t) * total_sample_size[i], cudaMemcpyDefault,
        node.out_stream);
cudaMemcpyAsync(reinterpret_cast<char*>(actual_sample_size + h_left[i]),
node.val_storage + sizeof(int) * shard_len,
sizeof(int) * shard_len, cudaMemcpyDefault,
node.out_stream);
  }
  for (int i = 0; i < gpu_num; ++i) {
    if (h_left[i] == -1 || h_right[i] == -1) {
@@ -131,17 +239,35 @@ void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
TODO:
how to optimize it to eliminate the for loop
*/
__global__ void fill_dvalues_actual_sample_size(int* d_shard_actual_sample_size,
                                                int* d_actual_sample_size,
                                                int* idx, int len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
  }
}
template <int BLOCK_WARPS, int TILE_SIZE>
__global__ void fill_dvalues_sample_result(int64_t* d_shard_vals,
int64_t* d_vals,
int* d_actual_sample_size, int* idx,
int* offset, int* d_offset,
int len) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, len);
while (i < last_idx) {
const int sample_size = d_actual_sample_size[idx[i]];
for (int j = threadIdx.x; j < sample_size; j += WARP_SIZE) {
d_vals[offset[idx[i]] + j] = d_shard_vals[d_offset[i] + j];
    }
#ifdef PADDLE_WITH_CUDA
__syncwarp();
#endif
i += BLOCK_WARPS;
  }
}
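// [Editor's note, illustration only -- not part of this commit] A host-side
// sketch of the scatter done by fill_dvalues_sample_result: idx maps a shard
// slot back to the original key position, d_offset[i] is where slot i's
// samples start in the shard buffer, and offset[k] is where key k's samples
// start in the final result buffer (all names below are hypothetical).
#include <cstdint>
#include <vector>
inline void scatter_sample_result_ref(const std::vector<int64_t>& shard_vals,
                                      const std::vector<int>& actual_size,
                                      const std::vector<int>& idx,
                                      const std::vector<int>& offset,
                                      const std::vector<int>& d_offset,
                                      std::vector<int64_t>* vals) {
  for (size_t i = 0; i < idx.size(); ++i) {
    const int key = idx[i];
    for (int j = 0; j < actual_size[key]; ++j) {
      (*vals)[offset[key] + j] = shard_vals[d_offset[i] + j];
    }
  }
}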
@@ -255,14 +381,12 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
  h_left = [0,5], h_right = [4,8]
*/
  NeighborSampleResult* result = new NeighborSampleResult(sample_size, len);
  if (len == 0) {
    return result;
  }
cudaMalloc((void**)&result->val, len * sample_size * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
  int total_gpu = resource_->total_gpu();
  int dev_id = resource_->dev_id(gpu_id);
  platform::CUDAPlace place = platform::CUDAPlace(dev_id);
@@ -287,11 +411,6 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
  auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
  int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
  split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
@@ -331,6 +450,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
    of alloc_mem_i, actual_sample_size_of_x equals ((int
    *)alloc_mem_i)[shard_len + x]
    */
    create_storage(gpu_id, i, shard_len * sizeof(int64_t),
                   shard_len * (1 + sample_size) * sizeof(int64_t));
  }
@@ -351,6 +471,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
                          h_right[i] - h_left[i] + 1,
                          resource_->remote_stream(i, gpu_id));
  }
  for (int i = 0; i < total_gpu; ++i) {
    if (h_left[i] == -1) {
      continue;
@@ -364,10 +485,42 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
    int* res_array = reinterpret_cast<int*>(node.val_storage);
    int* actual_size_array = res_array + shard_len;
    int64_t* sample_array = (int64_t*)(res_array + shard_len * 2);
    // 1. get actual_size_array.
    // 2. get sum of actual_size.
    // 3. get offset ptr
thrust::device_vector<int> t_res_array(shard_len);
thrust::copy(res_array, res_array + shard_len, t_res_array.begin());
thrust::device_vector<int> t_actual_size_array(shard_len);
thrust::transform(t_res_array.begin(), t_res_array.end(),
t_actual_size_array.begin(), DegreeFunctor(graph));
if (sample_size >= 0) {
thrust::transform(t_actual_size_array.begin(), t_actual_size_array.end(),
t_actual_size_array.begin(), MaxFunctor(sample_size));
}
thrust::copy(t_actual_size_array.begin(), t_actual_size_array.end(),
actual_size_array);
int total_sample_sum =
thrust::reduce(t_actual_size_array.begin(), t_actual_size_array.end());
thrust::device_vector<int> output_idx(total_sample_sum);
thrust::device_vector<int> output_offset(shard_len);
thrust::exclusive_scan(t_actual_size_array.begin(),
t_actual_size_array.end(), output_offset.begin(), 0);
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block_(WARP_SIZE, BLOCK_WARPS);
const dim3 grid_((shard_len + TILE_SIZE - 1) / TILE_SIZE);
neighbor_sample<
BLOCK_WARPS,
TILE_SIZE><<<grid_, block_, 0, resource_->remote_stream(i, gpu_id)>>>(
0, graph, sample_size, res_array, shard_len, sample_array,
thrust::raw_pointer_cast(output_idx.data()),
thrust::raw_pointer_cast(output_offset.data()));
  }
  for (int i = 0; i < total_gpu; ++i) {
@@ -378,13 +531,56 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
    tables_[i]->rwlock_->UNLock();
  }
  // walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);
  auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
  int* d_shard_actual_sample_size_ptr =
      reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
// Store total sample number of each gpu.
thrust::host_vector<int> d_shard_total_sample_size(total_gpu, 0);
move_neighbor_sample_size_to_source_gpu(
gpu_id, total_gpu, h_left, h_right, d_shard_actual_sample_size_ptr,
thrust::raw_pointer_cast(d_shard_total_sample_size.data()));
int allocate_sample_num = 0;
for (int i = 0; i < total_gpu; ++i) {
allocate_sample_num += d_shard_total_sample_size[i];
}
auto d_shard_vals =
memory::Alloc(place, allocate_sample_num * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, h_left, h_right,
d_shard_vals_ptr,
d_shard_total_sample_size);
cudaMalloc((void**)&result->val, allocate_sample_num * sizeof(int64_t));
cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int));
cudaMalloc((void**)&result->offset, len * sizeof(int));
int64_t* val = result->val;
int* actual_sample_size = result->actual_sample_size;
int* offset = result->offset;
fill_dvalues_actual_sample_size<<<grid_size, block_size_, 0, stream>>>(
d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, len);
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size, actual_sample_size + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), offset, 0);
int* d_offset;
cudaMalloc(&d_offset, len * sizeof(int));
thrust::copy(d_shard_actual_sample_size_ptr,
d_shard_actual_sample_size_ptr + len,
t_actual_sample_size.begin());
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(), d_offset, 0);
constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE;
constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
const dim3 block__(WARP_SIZE, BLOCK_WARPS_);
const dim3 grid__((len + TILE_SIZE_ - 1) / TILE_SIZE_);
fill_dvalues_sample_result<BLOCK_WARPS_,
TILE_SIZE_><<<grid__, block__, 0, stream>>>(
d_shard_vals_ptr, val, actual_sample_size, d_idx_ptr, offset, d_offset,
len);
  cudaStreamSynchronize(stream);
  for (int i = 0; i < total_gpu; ++i) {
    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
@@ -393,6 +589,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
    }
    destroy_storage(gpu_id, i);
  }
cudaFree(d_offset);
  return result;
}
...
@@ -94,19 +94,44 @@ TEST(TEST_FLEET, graph_comm) {
   0 --index--->0
   7 --index-->2
  */
  int64_t cpu_key[3] = {7, 0, 6};
  void *key;
  cudaMalloc((void **)&key, 3 * sizeof(int64_t));
  cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
  auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
  res = new int64_t[7];
  cudaMemcpy(res, neighbor_sample_res->val, 56, cudaMemcpyDeviceToHost);
  int *actual_sample_size = new int[3];
  cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, 12,
             cudaMemcpyDeviceToHost);  // 3, 1, 3
  int *cumsum_sample_size = new int[3];
cudaMemcpy(cumsum_sample_size, neighbor_sample_res->offset, 12,
cudaMemcpyDeviceToHost); // 0, 3, 4
std::vector<std::vector<int64_t>> neighbors_;
std::vector<int64_t> neighbors_7 = {28, 29, 30, 31, 32, 33, 34, 35};
std::vector<int64_t> neighbors_0 = {0};
std::vector<int64_t> neighbors_6 = {21, 22, 23, 24, 25, 26, 27};
neighbors_.push_back(neighbors_7);
neighbors_.push_back(neighbors_0);
neighbors_.push_back(neighbors_6);
for (int i = 0; i < 3; i++) {
for (int j = cumsum_sample_size[i];
j < cumsum_sample_size[i] + actual_sample_size[i]; j++) {
bool flag = false;
for (int k = 0; k < neighbors_[i].size(); k++) {
if (res[j] == neighbors_[i][k]) {
flag = true;
break;
}
}
ASSERT_EQ(flag, true);
    }
  }
  delete[] res;
delete[] actual_sample_size;
delete[] cumsum_sample_size;
  delete neighbor_sample_res;
}
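// [Editor's note, illustration only -- not part of the PR] The test above
// checks the new variable-length (CSR-like) layout: result->val holds all
// samples back to back, result->offset the per-key start positions, and
// result->actual_sample_size the per-key counts. Assuming those device
// buffers have already been copied to host arrays (names are hypothetical),
// a caller can walk the result like this:
#include <cstdint>
#include <cstdio>
inline void print_sampled_neighbors(const int64_t* h_val, const int* h_size,
                                    const int* h_offset, int len) {
  for (int i = 0; i < len; ++i) {
    std::printf("key %d sampled %d neighbors:", i, h_size[i]);
    for (int j = 0; j < h_size[i]; ++j) {
      std::printf(" %lld", static_cast<long long>(h_val[h_offset[i] + j]));
    }
    std::printf("\n");
  }
}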