diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index df55fe93be3d84a46ddbff1caa9deb97909aef7d..d7ceb4a18ea19eb733816deef03ef079c2113510 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -430,8 +430,9 @@ int32_t GraphTable::add_comm_edge(int64_t src_id, int64_t dst_id) { return -1; } size_t index = src_shard_id - shard_start; - extra_shards[index]->add_graph_node(src_id)->build_edges(false); - extra_shards[index]->add_neighbor(src_id, dst_id, 1.0); + VLOG(0) << "index add edge " << src_id << " " << dst_id; + shards[index]->add_graph_node(src_id)->build_edges(false); + shards[index]->add_neighbor(src_id, dst_id, 1.0); return 0; } int32_t GraphTable::add_graph_node(std::vector &id_list, diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index edbeac9047997489f77f0d3a28f95f0a3d175a73..70b067b0494f13536c4f84e30fa5a657517707a8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -13,14 +13,15 @@ IF(WITH_GPU) nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) if(WITH_PSCORE) - nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table) + nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table hashtable_kernel) nv_library(graph_sampler SRCS graph_sampler_inl.h DEPS graph_gpu_ps) - #nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps) - #nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps) - #nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS graph_gpu_ps) - # ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu) - # target_link_libraries(test_sample_rate graph_gpu_ps graph_sampler) - # nv_test(test_graph_xx SRCS test_xx.cu DEPS graph_gpu_ps graph_sampler) + + nv_test(test_cpu_query SRCS test_cpu_query.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) + #ADD_EXECUTABLE(test_sample_rate test_sample_rate.cu) + #target_link_libraries(test_sample_rate heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) + #nv_test(test_sample_rate SRCS test_sample_rate.cu DEPS heter_comm table heter_comm_kernel hashtable_kernel heter_ps ${HETERPS_DEPS}) + #ADD_EXECUTABLE(test_cpu_query test_cpu_query.cu) + #target_link_libraries(test_cpu_query graph_gpu_ps) endif() ENDIF() IF(WITH_XPU_KP) diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 27f14e8726d9cac3f7dce5af535785a7353e6d6c..5b8a20f7b9970acb3dbf85f8d7364e81e1b122c8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -16,6 +16,7 @@ #ifdef PADDLE_WITH_HETERPS #include #include +#include #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/cuda_device_guard.h" @@ -41,6 +42,24 @@ struct GpuPsCommGraph { node_list(node_list_), neighbor_size(neighbor_size_), node_size(node_size_) {} + void display_on_cpu() { + VLOG(0) << "neighbor_size = " << neighbor_size; + VLOG(0) << "node_size = " << node_size; + for (int i = 0; i < neighbor_size; i++) { + VLOG(0) << "neighbor " << i << " " << neighbor_list[i]; + } + for (int i = 0; i < node_size; i++) { + VLOG(0) << "node i " << node_list[i].node_id + << " neighbor_size = " << node_list[i].neighbor_size; + std::string str; + int offset = node_list[i].neighbor_offset; + for (int j = 0; j < node_list[i].neighbor_size; j++) { + if (j > 0) str += ","; + str += std::to_string(neighbor_list[j + offset]); + } + VLOG(0) << str; + } + } }; /* diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index cd55d09608f5409ad31e04e94988de0cb930c1cb..4eb42d80a00b51c797b5f1d3822008dc1f4964f7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -18,6 +18,7 @@ #include "heter_comm.h" #include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/enforce.h" #ifdef PADDLE_WITH_HETERPS namespace paddle { @@ -28,10 +29,10 @@ class GpuPsGraphTable : public HeterComm { : HeterComm(1, resource) { load_factor_ = 0.25; rw_lock.reset(new pthread_rwlock_t()); - gpu_num = resource_->total_gpu(); + gpu_num = resource_->total_device(); cpu_table_status = -1; if (topo_aware) { - int total_gpu = resource_->total_gpu(); + int total_gpu = resource_->total_device(); std::map device_map; for (int i = 0; i < total_gpu; i++) { device_map[resource_->dev_id(i)] = i; @@ -62,7 +63,7 @@ class GpuPsGraphTable : public HeterComm { node.key_storage = NULL; node.val_storage = NULL; node.sync = 0; - node.gpu_num = transfer_id; + node.dev_num = transfer_id; } nodes.push_back(Node()); Node &node = nodes.back(); @@ -71,7 +72,7 @@ class GpuPsGraphTable : public HeterComm { node.key_storage = NULL; node.val_storage = NULL; node.sync = 0; - node.gpu_num = j; + node.dev_num = j; } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h index c235378def51f222148e82ba98bfef72e35730f7..37067dc36543c9778503119a49a26960f1ed8246 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h @@ -28,14 +28,16 @@ sample_result is to save the neighbor sampling result, its size is len * sample_size; */ - __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* node_index, int* actual_size, int64_t* res, int sample_len, int* sample_status, int n, int from) { - // printf("%d %d %d\n",blockIdx.x,threadIdx.x,threadIdx.y); int id = blockIdx.x * blockDim.y + threadIdx.y; if (id < n) { + if (node_index[id] == -1) { + actual_size[id] = 0; + return; + } curandState rng; curand_init(blockIdx.x, threadIdx.x, threadIdx.y, &rng); int index = threadIdx.x; @@ -305,7 +307,6 @@ __global__ void fill_dvalues(int64_t* d_shard_vals, int64_t* d_vals, const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i]; - // d_vals[idx[i]] = d_shard_vals[i]; for (int j = 0; j < sample_size; j++) { d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j]; } @@ -351,7 +352,7 @@ void GpuPsGraphTable::build_graph_from_cpu( VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = " << cpu_graph_list.size(); PADDLE_ENFORCE_EQ( - cpu_graph_list.size(), resource_->total_gpu(), + cpu_graph_list.size(), resource_->total_device(), platform::errors::InvalidArgument("the cpu node list size doesn't match " "the number of gpu on your machine.")); clear_graph_info(); @@ -378,6 +379,7 @@ void GpuPsGraphTable::build_graph_from_cpu( build_ps(i, keys.data(), offset.data(), keys.size(), 1024, 8); gpu_graph_list[i].node_size = cpu_graph_list[i].node_size; } else { + build_ps(i, NULL, NULL, 0, 1024, 8); gpu_graph_list[i].node_list = NULL; gpu_graph_list[i].node_size = 0; } @@ -442,7 +444,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, // cudaMalloc((void**)&result->actual_sample_size, len * sizeof(int)); int* actual_sample_size = result->actual_sample_size; int64_t* val = result->val; - int total_gpu = resource_->total_gpu(); + int total_gpu = resource_->total_device(); // int dev_id = resource_->dev_id(gpu_id); auto stream = resource_->local_stream(gpu_id, 0); @@ -472,9 +474,11 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id); - fill_shard_key<<>>(d_shard_keys_ptr, key, - d_idx_ptr, len); - + // fill_shard_key<<>>(d_shard_keys_ptr, + // key, + // d_idx_ptr, len); + heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len, + stream); cudaStreamSynchronize(stream); cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), @@ -510,6 +514,9 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, */ create_storage(gpu_id, i, shard_len * sizeof(int64_t), shard_len * (1 + sample_size) * sizeof(int64_t)); + auto& node = path_[gpu_id][i].nodes_[0]; + cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int), + node.in_stream); } // auto end1 = std::chrono::steady_clock::now(); // auto tt = std::chrono::duration_cast(end1 - @@ -532,7 +539,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, h_right[i] - h_left[i] + 1, resource_->remote_stream(i, gpu_id)); // node.in_stream); - auto shard_len = h_right[i] - h_left[i] + 1; + int shard_len = h_right[i] - h_left[i] + 1; auto graph = gpu_graph_list[i]; int* id_array = reinterpret_cast(node.val_storage); int* actual_size_array = id_array + shard_len; @@ -595,20 +602,13 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id, // auto& node = path_[gpu_id][i].nodes_.back(); // cudaStreamSynchronize(node.in_stream); cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); - // tables_[i]->rwlock_->UNLock(); } - // walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr); move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size, h_left, h_right, d_shard_vals_ptr, d_shard_actual_sample_size_ptr); - fill_dvalues<<>>( d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size, d_idx_ptr, sample_size, len); - // cudaStreamSynchronize(stream); - // auto end2 = std::chrono::steady_clock::now(); - // tt = std::chrono::duration_cast(end2 - end1); - // VLOG(0)<< "sample graph time " << tt.count() << " us"; for (int i = 0; i < total_gpu; ++i) { int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; if (shard_len == 0) { diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index cac1b9c17e077f3dd94a1dd405abdd09be355a62..fc54be447fe1719a434a5e8896f903a04dc749ae 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -297,12 +297,17 @@ void HashTable::update(const KeyType* d_keys, } template class HashTable; +template class HashTable; template void HashTable::get< cudaStream_t>(const unsigned long* d_keys, paddle::framework::FeatureValue* d_vals, size_t len, cudaStream_t stream); +template void HashTable::get(const long* d_keys, + int* d_vals, size_t len, + cudaStream_t stream); + // template void // HashTable::get( // const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t @@ -313,6 +318,11 @@ template void HashTable::insert< const paddle::framework::FeatureValue* d_vals, size_t len, cudaStream_t stream); +template void HashTable::insert(const long* d_keys, + const int* d_vals, + size_t len, + cudaStream_t stream); + // template void HashTable::insert< // cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index b5b1c22f304543f638a65e4a8931048e82c790f3..338009250bc4fc4c074bb607e853ac5d601dd4b2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -212,10 +212,10 @@ class HeterComm { std::vector> path_; float load_factor_{0.75}; int block_size_{256}; - int topo_aware_{0}; + std::unique_ptr heter_comm_kernel_; private: - std::unique_ptr heter_comm_kernel_; + int topo_aware_{0}; std::vector storage_; int feanum_{1800 * 2048}; int multi_node_{0}; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 694bdb8d563f5726bfc40509f3e58c8c5553f047..bdeb696a92bcef6592d43d4d3050f6838f6760a6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -218,6 +218,14 @@ template void HeterCommKernel::calc_shard_index< int* shard_index, int total_devs, const cudaStream_t& stream); +template void HeterCommKernel::calc_shard_index( + long* d_keys, long long len, int* shard_index, int total_devs, + const cudaStream_t& stream); + +template void HeterCommKernel::fill_shard_key( + long* d_shard_keys, long* d_keys, int* idx, long long len, + const cudaStream_t& stream); + template void HeterCommKernel::fill_shard_key( unsigned long* d_shard_keys, unsigned long* d_keys, int* idx, long long len, const cudaStream_t& stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu new file mode 100644 index 0000000000000000000000000000000000000000..d812542f17ba0d1428a1c67f44bbe232127f783f --- /dev/null +++ b/paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h" +#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h" +#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" +#include "paddle/fluid/platform/cuda_device_guard.h" + +using namespace paddle::framework; +namespace platform = paddle::platform; +// paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph +// paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph( +// std::vector ids) +TEST(TEST_FLEET, test_cpu_cache) { + int gpu_num = 0; + int st = 0, u = 0; + std::vector device_id_mapping; + for (int i = 0; i < 2; i++) device_id_mapping.push_back(i); + gpu_num = device_id_mapping.size(); + ::paddle::distributed::GraphParameter table_proto; + table_proto.set_shard_num(24); + std::shared_ptr resource = + std::make_shared(device_id_mapping); + resource->enable_p2p(); + int use_nv = 1; + GpuPsGraphTable g(resource, use_nv); + g.init_cpu_table(table_proto); + std::vector vec; + int n = 10; + std::vector ids0, ids1; + for (int i = 0; i < n; i++) { + g.cpu_graph_table->add_comm_edge(i, (i + 1) % n); + g.cpu_graph_table->add_comm_edge(i, (i - 1 + n) % n); + if (i % 2 == 0) ids0.push_back(i); + } + ids1.push_back(5); + vec.push_back(g.cpu_graph_table->make_gpu_ps_graph(ids0)); + vec.push_back(g.cpu_graph_table->make_gpu_ps_graph(ids1)); + vec[0].display_on_cpu(); + vec[1].display_on_cpu(); + g.build_graph_from_cpu(vec); + int64_t cpu_key[3] = {0, 1, 2}; + void *key; + platform::CUDADeviceGuard guard(0); + cudaMalloc((void **)&key, 3 * sizeof(int64_t)); + cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice); + auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 2, 3); + int64_t *res = new int64_t[7]; + cudaMemcpy(res, neighbor_sample_res->val, 3 * 2 * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int *actual_sample_size = new int[3]; + cudaMemcpy(actual_sample_size, neighbor_sample_res->actual_sample_size, + 3 * sizeof(int), + cudaMemcpyDeviceToHost); // 3, 1, 3 + + //{0,9} or {9,0} is expected for key 0 + //{0,2} or {2,0} is expected for key 1 + //{1,3} or {3,1} is expected for key 2 + for (int i = 0; i < 3; i++) { + VLOG(0) << "actual sample size for " << i << " is " + << actual_sample_size[i]; + for (int j = 0; j < actual_sample_size[i]; j++) { + VLOG(0) << "sampled an neighbor for node" << i << " : " << res[i * 2 + j]; + } + } +} diff --git a/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu b/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu index 887bda4be4a893ba7c9c04fb5d01ee80c0a56760..07e561fb3b050628babf9b20eebf0b24e3bfe484 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu @@ -86,6 +86,7 @@ void testSampleRate() { int start = 0; pthread_rwlock_t rwlock; pthread_rwlock_init(&rwlock, NULL); + { ::paddle::distributed::GraphParameter table_proto; // table_proto.set_gpups_mode(false); @@ -93,9 +94,9 @@ void testSampleRate() { table_proto.set_task_pool_size(24); std::cerr << "initializing begin"; distributed::GraphTable graph_table; - graph_table.initialize(table_proto); + graph_table.Initialize(table_proto); std::cerr << "initializing done"; - graph_table.load(input_file, std::string("e>")); + graph_table.Load(input_file, std::string("e>")); int sample_actual_size = -1; int step = fixed_key_size, cur = 0; while (sample_actual_size != 0) {