未验证 提交 133d63fa 编写于 作者: T Thunderbrook 提交者: GitHub

Fix graph hang (#42768)

* fix device_free

* fix hang
上级 fa8c755a
......@@ -1441,7 +1441,7 @@ std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id, int idx,
}
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id);
for (auto &id : ids) res[(uint64_t)(id) % slice_num].push_back(id);
}
return res;
}
......
......@@ -23,10 +23,10 @@
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
class GpuPsGraphTable : public HeterComm<int64_t, int64_t, int> {
class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
: HeterComm<int64_t, int64_t, int>(1, resource) {
: HeterComm<uint64_t, int64_t, int>(1, resource) {
load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
gpu_num = resource_->total_device();
......
......@@ -499,7 +499,7 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) {
keys.push_back(g.node_list[j].node_id);
offset.push_back(j);
}
build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8);
build_ps(i, (uint64_t*)keys.data(), offset.data(), keys.size(), 1024, 8);
gpu_graph_list[i].node_size = g.node_size;
} else {
build_ps(i, NULL, NULL, 0, 1024, 8);
......@@ -572,7 +572,8 @@ void GpuPsGraphTable::build_graph_from_cpu(
keys.push_back(cpu_graph_list[i].node_list[j].node_id);
offset.push_back(j);
}
build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8);
build_ps(i, (uint64_t*)(keys.data()), offset.data(), keys.size(), 1024,
8);
gpu_graph_list[i].node_size = cpu_graph_list[i].node_size;
} else {
build_ps(i, NULL, NULL, 0, 1024, 8);
......@@ -665,7 +666,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
d_right_ptr, gpu_id);
heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
stream);
......@@ -708,7 +710,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
sizeof(int) * (shard_len + shard_len % 2));
// auto& node = path_[gpu_id][i].nodes_[0];
}
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
walk_to_dest(gpu_id, total_gpu, h_left, h_right,
(uint64_t*)(d_shard_keys_ptr), NULL);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
......@@ -720,7 +723,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
node.in_stream);
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<int64_t*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id));
......@@ -805,7 +808,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);
split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
d_right_ptr, gpu_id);
heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
stream);
......@@ -824,7 +830,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
shard_len * (1 + sample_size) * sizeof(int64_t) +
sizeof(int) * (shard_len + shard_len % 2));
}
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
walk_to_dest(gpu_id, total_gpu, h_left, h_right,
(uint64_t*)(d_shard_keys_ptr), NULL);
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
......@@ -837,7 +844,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
// If not found, val is -1.
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
reinterpret_cast<int64_t*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id));
......
......@@ -320,6 +320,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
// Explicit instantiations of the HashTable class template, one per
// (KeyType, ValType) pair listed below.
// NOTE(review): presumably every key/value combination used by callers in
// other translation units must appear here or linking fails -- confirm
// against the HashTable users before adding or removing an entry.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;
......@@ -333,6 +335,8 @@ template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
int* d_vals, size_t len,
cudaStream_t stream);
// Explicit instantiation of the member template get() with
// StreamType = cudaStream_t for HashTable<unsigned long, int>.
template void HashTable<unsigned long, int>::get<cudaStream_t>(
const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
// Same member template, instantiated for HashTable<long, unsigned long>.
template void HashTable<long, unsigned long>::get<cudaStream_t>(
const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
......@@ -359,6 +363,9 @@ template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
size_t len,
cudaStream_t stream);
// Explicit instantiation of the member template insert() with
// StreamType = cudaStream_t for HashTable<unsigned long, int>.
template void HashTable<unsigned long, int>::insert<cudaStream_t>(
const unsigned long* d_keys, const int* d_vals, size_t len,
cudaStream_t stream);
// Same member template, instantiated for HashTable<long, unsigned long>.
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
const long* d_keys, const unsigned long* d_vals, size_t len,
cudaStream_t stream);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册