diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index b84529066fd8c4e5b93f2bd44523846ee5487533..698ca3cd35d54017f7b7f2be491853fd6201bdaa 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -1077,6 +1077,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, int cur_degree, int step, int *len_per_row) { + platform::CUDADeviceGuard guard(gpuid_); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id]; uint8_t edge_src_id = node_id >> 32; @@ -2349,6 +2350,7 @@ int GraphDataGenerator::FillWalkBuf() { break; } } + platform::CUDADeviceGuard guard2(gpuid_); buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); @@ -2584,6 +2586,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() { break; } } + platform::CUDADeviceGuard guard2(gpuid_); buf_state_.Reset(total_row_); int *d_random_row = reinterpret_cast(d_random_row_->ptr());