From 30cefa2de430142e7510706cea3f188aba98cc93 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Tue, 15 Aug 2023 14:42:45 +0800 Subject: [PATCH] fix gpugraph cuda error;test=develop (#56133) --- paddle/fluid/framework/data_feed.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index b84529066fd..698ca3cd35d 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -1077,6 +1077,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, int cur_degree, int step, int *len_per_row) { + platform::CUDADeviceGuard guard(gpuid_); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id]; uint8_t edge_src_id = node_id >> 32; @@ -2349,6 +2350,7 @@ int GraphDataGenerator::FillWalkBuf() { break; } } + platform::CUDADeviceGuard guard2(gpuid_); buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); @@ -2584,6 +2586,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() { break; } } + platform::CUDADeviceGuard guard2(gpuid_); buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); -- GitLab