diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index b53044b7493e0449df1b6fa75eb2f2151f6749fe..88f0211160003bb6ec05f48a54abd8a23cefaa85 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -1441,7 +1441,7 @@ std::vector> GraphTable::get_all_id(int type_id, int idx, } for (size_t i = 0; i < tasks.size(); i++) { auto ids = tasks[i].get(); - for (auto &id : ids) res[id % slice_num].push_back(id); + for (auto &id : ids) res[(uint64_t)(id) % slice_num].push_back(id); } return res; } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 9e7ee80edcd0c175dc1a401d6a16d8fc96311bc6..ae57c2ebe932f85d0559c18800a0f2e869f3d210 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -23,10 +23,10 @@ #ifdef PADDLE_WITH_HETERPS namespace paddle { namespace framework { -class GpuPsGraphTable : public HeterComm { +class GpuPsGraphTable : public HeterComm { public: GpuPsGraphTable(std::shared_ptr resource, int topo_aware) - : HeterComm(1, resource) { + : HeterComm(1, resource) { load_factor_ = 0.25; rw_lock.reset(new pthread_rwlock_t()); gpu_num = resource_->total_device(); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index 631ca962fae9c12a0fbae2dd1d2d67fa031f6a6e..72b9cae41c0fdfb2807ffc3d90bc3bca1377b059 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -499,7 +499,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) { keys.push_back(g.node_list[j].node_id); offset.push_back(j); } - build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8); + build_ps(i, (uint64_t*)keys.data(), offset.data(), keys.size(), 1024, 8); gpu_graph_list[i].node_size = g.node_size; } else { build_ps(i, NULL, NULL, 0, 1024, 8); @@ -572,7 +572,8 @@ void GpuPsGraphTable::build_graph_from_cpu( keys.push_back(cpu_graph_list[i].node_list[j].node_id); offset.push_back(j); } - build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8); + build_ps(i, (uint64_t*)(keys.data()), offset.data(), keys.size(), 1024, + 8); gpu_graph_list[i].node_size = cpu_graph_list[i].node_size; } else { build_ps(i, NULL, NULL, 0, 1024, 8); @@ -665,7 +666,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int* d_shard_actual_sample_size_ptr = reinterpret_cast(d_shard_actual_sample_size->ptr()); - split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr, + d_right_ptr, gpu_id); heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, stream); @@ -708,7 +710,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, sizeof(int) * (shard_len + shard_len % 2)); // auto& node = path_[gpu_id][i].nodes_[0]; } - walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); + walk_to_dest(gpu_id, total_gpu, h_left, h_right, + (uint64_t*)(d_shard_keys_ptr), NULL); for (int i = 0; i < total_gpu; ++i) { if (h_left[i] == -1) { @@ -720,7 +723,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, node.in_stream); cudaStreamSynchronize(node.in_stream); platform::CUDADeviceGuard guard(resource_->dev_id(i)); - tables_[i]->get(reinterpret_cast(node.key_storage), + tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, resource_->remote_stream(i, gpu_id)); @@ -805,7 +808,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); int* d_shard_actual_sample_size_ptr = reinterpret_cast(d_shard_actual_sample_size->ptr()); - split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); + + split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr, + d_right_ptr, gpu_id); + heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, stream); @@ -824,7 +830,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( shard_len * (1 + sample_size) * sizeof(int64_t) + sizeof(int) * (shard_len + shard_len % 2)); } - walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); + walk_to_dest(gpu_id, total_gpu, h_left, h_right, + (uint64_t*)(d_shard_keys_ptr), NULL); for (int i = 0; i < total_gpu; ++i) { if (h_left[i] == -1) { @@ -837,7 +844,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( cudaStreamSynchronize(node.in_stream); platform::CUDADeviceGuard guard(resource_->dev_id(i)); // If not found, val is -1. - tables_[i]->get(reinterpret_cast(node.key_storage), + tables_[i]->get(reinterpret_cast(node.key_storage), reinterpret_cast(node.val_storage), h_right[i] - h_left[i] + 1, resource_->remote_stream(i, gpu_id)); diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 5a29159aa12a83a2d44d820f0e50f85106fe31cf..5edc218796ef8a3c3052d3aec9cad1c101f67191 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -320,6 +320,8 @@ void HashTable::update(const KeyType* d_keys, template class HashTable; template class HashTable; +template class HashTable; +template class HashTable; template class HashTable; template class HashTable; template class HashTable; @@ -333,6 +335,8 @@ template void HashTable::get(const long* d_keys, int* d_vals, size_t len, cudaStream_t stream); +template void HashTable::get( + const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream); template void HashTable::get( const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get(const long* d_keys, @@ -359,6 +363,9 @@ template void HashTable::insert(const long* d_keys, size_t len, cudaStream_t stream); +template void HashTable::insert( + const unsigned long* d_keys, const int* d_vals, size_t len, + cudaStream_t stream); template void HashTable::insert( const long* d_keys, const unsigned long* d_vals, size_t len, cudaStream_t stream);