未验证 提交 30cefa2d 编写于 作者: D danleifeng 提交者: GitHub

fix gpugraph cuda error;test=develop (#56133)

上级 425b96a3
...@@ -1077,6 +1077,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, ...@@ -1077,6 +1077,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
int cur_degree, int cur_degree,
int step, int step,
int *len_per_row) { int *len_per_row) {
platform::CUDADeviceGuard guard(gpuid_);
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id]; uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id];
uint8_t edge_src_id = node_id >> 32; uint8_t edge_src_id = node_id >> 32;
...@@ -2349,6 +2350,7 @@ int GraphDataGenerator::FillWalkBuf() { ...@@ -2349,6 +2350,7 @@ int GraphDataGenerator::FillWalkBuf() {
break; break;
} }
} }
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_); buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr()); int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());
...@@ -2584,6 +2586,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() { ...@@ -2584,6 +2586,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
break; break;
} }
} }
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_); buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr()); int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册