未验证 提交 133d63fa 编写于 作者: T Thunderbrook 提交者: GitHub

Fix graph hang (#42768)

* fix device_free

* fix hang
上级 fa8c755a
...@@ -1441,7 +1441,7 @@ std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id, int idx, ...@@ -1441,7 +1441,7 @@ std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id, int idx,
} }
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get(); auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id); for (auto &id : ids) res[(uint64_t)(id) % slice_num].push_back(id);
} }
return res; return res;
} }
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class GpuPsGraphTable : public HeterComm<int64_t, int64_t, int> { class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
public: public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware) GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
: HeterComm<int64_t, int64_t, int>(1, resource) { : HeterComm<uint64_t, int64_t, int>(1, resource) {
load_factor_ = 0.25; load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t()); rw_lock.reset(new pthread_rwlock_t());
gpu_num = resource_->total_device(); gpu_num = resource_->total_device();
......
...@@ -499,7 +499,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) { ...@@ -499,7 +499,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) {
keys.push_back(g.node_list[j].node_id); keys.push_back(g.node_list[j].node_id);
offset.push_back(j); offset.push_back(j);
} }
build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8); build_ps(i, (uint64_t*)keys.data(), offset.data(), keys.size(), 1024, 8);
gpu_graph_list[i].node_size = g.node_size; gpu_graph_list[i].node_size = g.node_size;
} else { } else {
build_ps(i, NULL, NULL, 0, 1024, 8); build_ps(i, NULL, NULL, 0, 1024, 8);
...@@ -572,7 +572,8 @@ void GpuPsGraphTable::build_graph_from_cpu( ...@@ -572,7 +572,8 @@ void GpuPsGraphTable::build_graph_from_cpu(
keys.push_back(cpu_graph_list[i].node_list[j].node_id); keys.push_back(cpu_graph_list[i].node_list[j].node_id);
offset.push_back(j); offset.push_back(j);
} }
build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8); build_ps(i, (uint64_t*)(keys.data()), offset.data(), keys.size(), 1024,
8);
gpu_graph_list[i].node_size = cpu_graph_list[i].node_size; gpu_graph_list[i].node_size = cpu_graph_list[i].node_size;
} else { } else {
build_ps(i, NULL, NULL, 0, 1024, 8); build_ps(i, NULL, NULL, 0, 1024, 8);
...@@ -665,7 +666,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -665,7 +666,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int* d_shard_actual_sample_size_ptr = int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr()); reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
d_right_ptr, gpu_id);
heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
stream); stream);
...@@ -708,7 +710,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -708,7 +710,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
sizeof(int) * (shard_len + shard_len % 2)); sizeof(int) * (shard_len + shard_len % 2));
// auto& node = path_[gpu_id][i].nodes_[0]; // auto& node = path_[gpu_id][i].nodes_[0];
} }
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); walk_to_dest(gpu_id, total_gpu, h_left, h_right,
(uint64_t*)(d_shard_keys_ptr), NULL);
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
...@@ -720,7 +723,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, ...@@ -720,7 +723,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
node.in_stream); node.in_stream);
cudaStreamSynchronize(node.in_stream); cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i)); platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage), tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<int64_t*>(node.val_storage), reinterpret_cast<int64_t*>(node.val_storage),
h_right[i] - h_left[i] + 1, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id)); resource_->remote_stream(i, gpu_id));
...@@ -805,7 +808,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( ...@@ -805,7 +808,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int)); auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr = int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr()); reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
d_right_ptr, gpu_id);
heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
stream); stream);
...@@ -824,7 +830,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( ...@@ -824,7 +830,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
shard_len * (1 + sample_size) * sizeof(int64_t) + shard_len * (1 + sample_size) * sizeof(int64_t) +
sizeof(int) * (shard_len + shard_len % 2)); sizeof(int) * (shard_len + shard_len % 2));
} }
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); walk_to_dest(gpu_id, total_gpu, h_left, h_right,
(uint64_t*)(d_shard_keys_ptr), NULL);
for (int i = 0; i < total_gpu; ++i) { for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) { if (h_left[i] == -1) {
...@@ -837,7 +844,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( ...@@ -837,7 +844,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
cudaStreamSynchronize(node.in_stream); cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i)); platform::CUDADeviceGuard guard(resource_->dev_id(i));
// If not found, val is -1. // If not found, val is -1.
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage), tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<int64_t*>(node.val_storage), reinterpret_cast<int64_t*>(node.val_storage),
h_right[i] - h_left[i] + 1, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id)); resource_->remote_stream(i, gpu_id));
......
...@@ -320,6 +320,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys, ...@@ -320,6 +320,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
template class HashTable<unsigned long, paddle::framework::FeatureValue>; template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<long, int>; template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<long, long>; template class HashTable<long, long>;
template class HashTable<long, unsigned long>; template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>; template class HashTable<long, unsigned int>;
...@@ -333,6 +335,8 @@ template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys, ...@@ -333,6 +335,8 @@ template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
int* d_vals, size_t len, int* d_vals, size_t len,
cudaStream_t stream); cudaStream_t stream);
template void HashTable<unsigned long, int>::get<cudaStream_t>(
const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>( template void HashTable<long, unsigned long>::get<cudaStream_t>(
const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys, template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
...@@ -359,6 +363,9 @@ template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys, ...@@ -359,6 +363,9 @@ template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
size_t len, size_t len,
cudaStream_t stream); cudaStream_t stream);
template void HashTable<unsigned long, int>::insert<cudaStream_t>(
const unsigned long* d_keys, const int* d_vals, size_t len,
cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>( template void HashTable<long, unsigned long>::insert<cudaStream_t>(
const long* d_keys, const unsigned long* d_vals, size_t len, const long* d_keys, const unsigned long* d_vals, size_t len,
cudaStream_t stream); cudaStream_t stream);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册