diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 9da23ee29d7fd699f535394e2c8597553aa51a4a..3862b23e2d5569964e9610146099454257f5b423 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -547,6 +547,7 @@ class PSGPUWorker : public HogwildWorker {
   virtual ~PSGPUWorker() {}
   virtual void Initialize(const TrainerDesc& desc);
   virtual void TrainFiles();
+  virtual void TrainFilesWithProfiler();
   virtual void SetNeedDump(bool need_dump_field);
   virtual void SetChannelWriter(ChannelObject<std::string>* queue);
   virtual void SetWorkerNum(int num) { worker_num_ = num; }
@@ -556,7 +557,6 @@ class PSGPUWorker : public HogwildWorker {
   virtual void ProduceTasks() override;
   virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; }
   virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
-  virtual void TrainFilesWithProfiler() {}
   void ResetStat();
 
 protected:
@@ -618,6 +618,7 @@ class PSGPUWorker : public HogwildWorker {
   gpuStream_t copy_stream_;
   int batch_cnt_{0};
   std::atomic<int> done_cnt_{0};
+  platform::DeviceContext* dev_ctx_ = nullptr;
 
   double total_time_;
   double read_time_;
diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h
index fc987b523d559a2559050602c4b8e98692804c1c..a02931b3f5c28a6e8e09866f3352109b7fe91adb 100644
--- a/paddle/fluid/framework/fleet/heter_context.h
+++ b/paddle/fluid/framework/fleet/heter_context.h
@@ -66,18 +66,19 @@ class HeterContext {
       mutex_[i] = new std::mutex();
     }
   }
-  void batch_add_keys(const std::vector<std::vector<uint64_t>>& thread_keys) {
+  void batch_add_keys(
+      const std::vector<std::unordered_set<uint64_t>>& thread_keys) {
     assert(thread_keys.size() == feature_keys_.size());
 
     for (uint32_t i = 0; i < shard_num_; i++) {
       int idx = 0;
       idx = feature_keys_[i].size();
       feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size());
-      for (uint64_t j = 0; j < thread_keys[i].size(); j++) {
-        feature_keys_[i][idx + j] = thread_keys[i][j];
-      }
+      std::copy(thread_keys[i].begin(), thread_keys[i].end(),
+                feature_keys_[i].begin() + idx);
     }
   }
+
   void UniqueKeys() {
     std::vector<std::thread> threads;
     auto unique_func = [this](int i) {
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
index 2aa00e84e1599bfa09b013dfb00bbda1299fe9e6..e5c0972763bede000961e970390c64431ac3cb22 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -55,6 +55,8 @@ class HashTable {
   void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
               Sgd sgd, gpuStream_t stream);
 
+  int size() { return container_->size(); }
+
  private:
   TableContainer<KeyType, ValType>* container_;
   int BLOCK_SIZE_{256};
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
index 77591c6df22a5767f9c5dd3c58a01ad798a6e5d1..0e38ebbd7f4e7280d5571ffd216143b277f063d6 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -118,6 +118,12 @@ class HeterComm {
     std::vector<Node> nodes_;
   };
 
+  struct CopyTask {
+    Path* path;
+    int step;
+    CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
+  };
+
   struct LocalStorage {
     LocalStorage() {}
     void init(int size, int dev_id) {
@@ -160,9 +166,10 @@ class HeterComm {
   void create_storage(
       int start_index, int end_index, int keylen, int vallen,
       std::vector<std::shared_ptr<memory::Allocation>>& local_strorage);
-  void walk_to_src(int start_index, int end_index, char* src_val);
-  void walk_to_dest(int start_index, int end_index, char* src_key,
-                    char* src_val);
+  void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
+                    KeyType* src_key, GradType* src_val);
+  void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right,
+                   ValType* src_val);
 
  private:
   using Table = HashTable<KeyType, ValType>;
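The heter_comm_inl.h changes below implement these declarations. Instead of finishing one source-to-destination path hop by hop before starting the next, walk_to_dest now issues the first hop of every non-empty shard up front and then drains a FIFO of (path, step) tasks, so transfers that take different routes can overlap. A minimal CPU-only sketch of that scheduling pattern (the struct names mirror the diff, but std::string buffers and plain assignment stand in for the real key/val storage and cudaMemcpyAsync on per-node streams):

```cpp
// CPU-only sketch of the CopyTask scheduling idea: issue the first hop of
// every path, then drive the remaining hops from a FIFO queue so paths
// progress in an interleaved, pipelined fashion.
#include <iostream>
#include <queue>
#include <string>
#include <vector>

struct Node {
  std::string storage;  // stands in for key_storage / val_storage
};

struct Path {
  std::vector<Node> nodes_;
};

struct CopyTask {
  Path* path;
  int step;
  CopyTask(Path* path_, int step_) : path(path_), step(step_) {}
};

int main() {
  // Two paths of different lengths, e.g. a direct copy vs. one relay hop.
  std::vector<Path> paths(2);
  paths[0].nodes_.resize(1);
  paths[1].nodes_.resize(2);
  std::vector<std::string> src = {"keys_for_gpu0", "keys_for_gpu1"};

  std::queue<CopyTask> que;
  for (size_t i = 0; i < paths.size(); ++i) {
    // First hop of every path is issued before any second hop starts.
    paths[i].nodes_[0].storage = src[i];
    que.push(CopyTask(&paths[i], 0));
  }
  while (!que.empty()) {
    CopyTask cur = que.front();  // take a copy, then pop, so the task stays valid
    que.pop();
    if (cur.step + 1 < static_cast<int>(cur.path->nodes_.size())) {
      // Forward the data one more hop and requeue the task for later hops.
      cur.path->nodes_[cur.step + 1].storage = cur.path->nodes_[cur.step].storage;
      que.push(CopyTask(cur.path, cur.step + 1));
    }
  }
  std::cout << paths[1].nodes_.back().storage << std::endl;  // keys_for_gpu1
  return 0;
}
```

walk_to_src applies the same idea in reverse: it starts from the last hop of each path and copies values back toward the caller's buffer.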
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index 4e4563daa19faa0c7b9c484454fdb7e0c53d53a1..2f1c809c01eaadcad8c3406e882acafaadb09134 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <queue>
+
 #ifdef PADDLE_WITH_PSLIB
 namespace paddle {
 namespace framework {
@@ -182,53 +184,105 @@ void HeterComm<KeyType, ValType, GradType>::create_storage(
 }
 
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
-                                                         int end_index,
-                                                         char* src_key,
-                                                         char* src_val) {
+void HeterComm<KeyType, ValType, GradType>::walk_to_dest(
+    int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key,
+    GradType* src_val) {
   int need_copy_val = 0;
   if (src_val) {
     need_copy_val = 1;
   }
-  auto& nodes = path_[start_index][end_index].nodes_;
-  for (size_t i = 0; i < nodes.size(); ++i) {
-    cudaMemcpyAsync(nodes[i].key_storage, src_key, nodes[i].key_bytes_len,
-                    cudaMemcpyDefault, nodes[i].in_stream);
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
+      continue;
+    }
+    int size = path_[start_index][i].nodes_.size();
+    auto& node = path_[start_index][i].nodes_[0];
+    CopyTask t(&path_[start_index][i], 0);
+    que.push(t);
+    cudaMemcpyAsync(node.key_storage,
+                    reinterpret_cast<char*>(src_key + h_left[i]),
+                    node.key_bytes_len, cudaMemcpyDefault, node.in_stream);
     if (need_copy_val) {
-      cudaMemcpyAsync(nodes[i].val_storage, src_val, nodes[i].val_bytes_len,
-                      cudaMemcpyDefault, nodes[i].in_stream);
+      cudaMemcpyAsync(node.val_storage,
+                      reinterpret_cast<char*>(src_val + h_left[i]),
+                      node.val_bytes_len, cudaMemcpyDefault, node.in_stream);
+    }
+  }
+  while (!que.empty()) {
+    CopyTask cur_task = que.front();
+    que.pop();
+    if (cur_task.path->nodes_[cur_task.step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_task.step].in_stream);
     }
-    if (nodes[i].sync) {
-      cudaStreamSynchronize(nodes[i].in_stream);
+    if (cur_task.step != cur_task.path->nodes_.size() - 1) {
+      int cur_step = cur_task.step;
+      CopyTask c(cur_task.path, cur_step + 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].key_storage,
+                      cur_task.path->nodes_[cur_step].key_storage,
+                      cur_task.path->nodes_[cur_step + 1].key_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step + 1].in_stream);
+      if (need_copy_val) {
+        cudaMemcpyAsync(cur_task.path->nodes_[cur_step + 1].val_storage,
+                        cur_task.path->nodes_[cur_step].val_storage,
+                        cur_task.path->nodes_[cur_step + 1].val_bytes_len,
+                        cudaMemcpyDefault,
+                        cur_task.path->nodes_[cur_step + 1].in_stream);
+      }
     }
-    // cudaStreamSynchronize(nodes[i].in_stream);
-    src_key = nodes[i].key_storage;
-    src_val = nodes[i].val_storage;
   }
 }
 
 template <typename KeyType, typename ValType, typename GradType>
-void HeterComm<KeyType, ValType, GradType>::walk_to_src(int start_index,
-                                                        int end_index,
-                                                        char* src_val) {
-  auto& nodes = path_[start_index][end_index].nodes_;
-  int len = nodes.size();
-  char* start = NULL;
-  for (int i = len - 1; i >= 0; --i) {
-    if (start == NULL) {
-      start = nodes[i].val_storage;
+void HeterComm<KeyType, ValType, GradType>::walk_to_src(
+    int start_index, int gpu_num, int* h_left, int* h_right, ValType* src_val) {
+  std::queue<CopyTask> que;
+  for (int i = 0; i < gpu_num; i++) {
+    if (h_left[i] == -1 || h_right[i] == -1) {
       continue;
     }
-    cudaMemcpyAsync(nodes[i].val_storage, start, nodes[i].val_bytes_len,
-                    cudaMemcpyDefault, nodes[i].out_stream);
-    if (nodes[i].sync) {
-      cudaStreamSynchronize(nodes[i].out_stream);
+    int cur_step = path_[start_index][i].nodes_.size() - 1;
+    auto& node = path_[start_index][i].nodes_[cur_step];
+    if (cur_step == 0) {
+      cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[i]),
+                      node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
+                      node.out_stream);
+    } else {
+      CopyTask t(&path_[start_index][i], cur_step - 1);
+      que.push(t);
+      cudaMemcpyAsync(path_[start_index][i].nodes_[cur_step - 1].val_storage,
+                      node.val_storage,
+                      path_[start_index][i].nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      path_[start_index][i].nodes_[cur_step - 1].out_stream);
+    }
+  }
+  while (!que.empty()) {
+    CopyTask cur_task = que.front();
+    que.pop();
+    int cur_step = cur_task.step;
+    if (cur_task.path->nodes_[cur_step].sync) {
+      cudaStreamSynchronize(cur_task.path->nodes_[cur_step].out_stream);
+    }
+    if (cur_step > 0) {
+      CopyTask c(cur_task.path, cur_step - 1);
+      que.push(c);
+      cudaMemcpyAsync(cur_task.path->nodes_[cur_step - 1].val_storage,
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step - 1].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step - 1].out_stream);
+    } else if (cur_step == 0) {
+      int end_index = cur_task.path->nodes_.back().gpu_num;
+      cudaMemcpyAsync(reinterpret_cast<char*>(src_val + h_left[end_index]),
+                      cur_task.path->nodes_[cur_step].val_storage,
+                      cur_task.path->nodes_[cur_step].val_bytes_len,
+                      cudaMemcpyDefault,
+                      cur_task.path->nodes_[cur_step].out_stream);
     }
-    start = nodes[i].val_storage;
   }
-  cudaMemcpyAsync(src_val, nodes[0].val_storage, nodes[0].val_bytes_len,
-                  cudaMemcpyDefault, nodes[0].out_stream);
-  // cudaStreamSynchronize(nodes[0].out_stream);
 }
 
 template <typename KeyType, typename ValType, typename GradType>
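Both walk functions now take per-GPU index ranges instead of a single destination: h_left[i] and h_right[i] bound the slice of the partitioned key/value array that belongs to GPU i, with -1 marking an empty shard, so one call covers every destination. In the real code these ranges are produced on the device before pull_sparse/push_sparse call into the walk; the sketch below is a CPU-only illustration of the convention, with made-up GPU ids and key counts:

```cpp
// CPU-only illustration of the h_left/h_right convention consumed by the new
// walk_to_dest/walk_to_src: keys are assumed to be already grouped by
// destination GPU; h_left[i]..h_right[i] is the inclusive slice for GPU i,
// and -1/-1 means GPU i receives nothing.
#include <iostream>
#include <vector>

int main() {
  const int gpu_num = 4;
  // Destination GPU of each key after partitioning (already sorted by GPU id).
  std::vector<int> dest_gpu = {0, 0, 2, 2, 2, 3};
  std::vector<int> h_left(gpu_num, -1), h_right(gpu_num, -1);
  for (int idx = 0; idx < static_cast<int>(dest_gpu.size()); ++idx) {
    int gpu = dest_gpu[idx];
    if (h_left[gpu] == -1) h_left[gpu] = idx;  // first key for this GPU
    h_right[gpu] = idx;                        // last key seen so far
  }
  for (int i = 0; i < gpu_num; ++i) {
    if (h_left[i] == -1 || h_right[i] == -1) {
      std::cout << "GPU " << i << ": empty shard, skipped\n";
    } else {
      std::cout << "GPU " << i << ": keys [" << h_left[i] << ", " << h_right[i]
                << "] (" << h_right[i] - h_left[i] + 1 << " keys)\n";
    }
  }
  return 0;
}
```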
@@ -462,14 +516,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
                    shard_len * sizeof(ValType), local_storage);
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    walk_to_dest(num, i, reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
-                 NULL);
-  }
+  walk_to_dest(num, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);
 
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1) {
@@ -486,14 +533,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     cudaStreamSynchronize(resource_->remote_stream(i));
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    walk_to_src(num, i, reinterpret_cast<char*>(d_shard_vals_ptr + h_left[i]));
-  }
+  walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
 
   for (int i = 0; i < total_gpu; ++i) {
     auto& node = path_[num][i].nodes_.front();
@@ -561,7 +601,6 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
                  cudaMemcpyDeviceToHost);
 
   std::vector<std::shared_ptr<memory::Allocation>> local_storage;
-
   for (int i = 0; i < total_gpu; ++i) {
     int shard_len = h_right[i] - h_left[i] + 1;
     if (h_left[i] == -1 || h_right[i] == -1) {
@@ -571,15 +610,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
                    shard_len * sizeof(GradType), local_storage);
   }
 
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_right[i] - h_left[i] + 1;
-    if (h_left[i] == -1 || h_right[i] == -1) {
-      continue;
-    }
-    walk_to_dest(gpu_num, i,
-                 reinterpret_cast<char*>(d_shard_keys_ptr + h_left[i]),
-                 reinterpret_cast<char*>(d_shard_grads_ptr + h_left[i]));
-  }
+  walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr,
+               d_shard_grads_ptr);
 
   for (int i = 0; i < total_gpu; ++i) {
     if (h_left[i] == -1 || h_right[i] == -1) {
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 516f09a9ef26e50dfe7ce5a00855b3dae383ff46..4274876c9975e5b16824af4c799bf81228659d92 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -65,9 +65,6 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
   for (int i = 0; i < thread_keys_thread_num_; i++) {
     thread_keys_[i].resize(thread_keys_shard_num_);
     for (int j = 0; j < thread_keys_shard_num_; j++) {
-      thread_keys_[i][j].reserve(2 * max_fea_num_per_pass_ /
-                                 thread_keys_shard_num_ /
-                                 thread_keys_thread_num_);
     }
   }
   const std::deque<Record>& vec_data = input_channel->GetData();
@@ -84,7 +81,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
       for (const auto feasign : feasign_v) {
         uint64_t cur_key = feasign.sign().uint64_feasign_;
         int shard_id = cur_key % thread_keys_shard_num_;
-        this->thread_keys_[i][shard_id].push_back(cur_key);
+        this->thread_keys_[i][shard_id].insert(cur_key);
       }
     }
   };
@@ -123,7 +120,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
     VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
     local_ptr[i].resize(local_keys[i].size());
   }
-
+  timeline.Start();
   auto ptl_func = [this, &local_keys, &local_ptr, &table_id,
                    &fleet_ptr](int i) {
     size_t key_size = local_keys[i].size();
@@ -149,7 +146,8 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
     t.join();
   }
   timeline.Pause();
-  VLOG(1) << "GpuPs pull sparse cost " << timeline.ElapsedSec() << " seconds.";
+  VLOG(1) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
+          << " seconds.";
   timeline.Start();
   auto build_func = [device_num, &local_keys, &local_ptr, &device_keys,
@@ -225,6 +223,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
   size_t size_max = 0;
   for (int i = 0; i < device_num; i++) {
     feature_keys_count[i] = gpu_task->device_keys_[i].size();
+    VLOG(1) << i << " card contains feasign nums: " << feature_keys_count[i];
     size_max = std::max(size_max, feature_keys_count[i]);
   }
   if (HeterPs_) {
@@ -314,7 +313,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
         "GpuPs: PullSparse Only Support CUDAPlace Now."));
   }
   all_timer.Pause();
-  VLOG(1) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
+  VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
           << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
           << " s";
   VLOG(3) << "End PullSparse";
@@ -360,7 +359,7 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
         "GPUPS: PushSparseGrad Only Support CUDAPlace Now."));
   }
   all_timer.Pause();
-  VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
+  VLOG(3) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
           << " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec()
           << " s";
   VLOG(3) << "End PushSparseGrad";
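The BuildTask change above switches key collection from push_back into a per-shard vector to insert into a per-shard set (the matching container type change is in ps_gpu_wrapper.h below), so duplicate feasigns are dropped while keys are gathered instead of being copied around and deduplicated later. An illustrative-only sketch of the effect, with made-up shard counts and key values:

```cpp
// Sketch of per-shard key collection into std::unordered_set: duplicates are
// discarded at insert time, so later merge/unique passes handle fewer keys.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  const int shard_num = 4;
  std::vector<std::unordered_set<uint64_t>> shard_keys(shard_num);
  std::vector<uint64_t> feasigns = {101, 205, 101, 308, 205, 101};  // duplicates
  for (uint64_t key : feasigns) {
    shard_keys[key % shard_num].insert(key);  // dedup happens here
  }
  size_t total = 0;
  for (const auto& s : shard_keys) total += s.size();
  std::cout << "unique keys kept: " << total << std::endl;  // 3, not 6
  return 0;
}
```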
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index fd3323d9d4764e2cdd22f8a7bc6699fa328968dd..ef586b41fe05d2f21e1469dcd7bcce3d77fc9651 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -227,7 +227,7 @@ class PSGPUWrapper {
   std::vector<int> heter_devices_;
   std::unordered_set<std::string> gpu_ps_config_keys_;
   HeterObjectPool<HeterContext> gpu_task_pool_;
-  std::vector<std::vector<std::vector<uint64_t>>> thread_keys_;
+  std::vector<std::vector<std::unordered_set<uint64_t>>> thread_keys_;
   int thread_keys_thread_num_ = 37;
   int thread_keys_shard_num_ = 37;
   uint64_t max_fea_num_per_pass_ = 5000000000;
diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc
index 962f666478cf0aa7418be11c1f17391038db531d..e77932fa5f226518f7be4177488d6cc55f2fce06 100644
--- a/paddle/fluid/framework/ps_gpu_trainer.cc
+++ b/paddle/fluid/framework/ps_gpu_trainer.cc
@@ -131,8 +131,13 @@ void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
 
 void PSGPUTrainer::Run() {
   for (size_t thidx = 0; thidx < places_.size(); ++thidx) {
-    threads_.push_back(
-        std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
+    if (!debug_) {
+      threads_.push_back(
+          std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
+    } else {
+      threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
+                                     workers_[thidx].get()));
+    }
   }
 }
 
diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc
index 1540679e00c97dc9384d754ff915d82aa621ca5b..2597901d91f36bab7e6a1e3553d8c43bb7a686f8 100644
--- a/paddle/fluid/framework/ps_gpu_worker.cc
+++ b/paddle/fluid/framework/ps_gpu_worker.cc
@@ -33,6 +33,7 @@ namespace framework {
 
 void PSGPUWorker::Initialize(const TrainerDesc& desc) {
   param_ = desc.downpour_param();
+  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
   mpi_rank_ = desc.mpi_rank();
   trainer_desc_ = desc;
   /*
@@ -177,6 +178,81 @@ void PSGPUWorker::TrainFiles() {
   return;
 }
 
+void PSGPUWorker::TrainFilesWithProfiler() {
+  platform::SetNumThreads(1);
+  VLOG(1) << "Begin to train files with profiler";
+  device_reader_->Start();
+  std::vector<double> op_total_time;
+  std::vector<std::string> op_name;
+  for (auto& op : ops_) {
+    bool need_skip = false;
+    for (auto t = 0u; t < skip_ops_.size(); ++t) {
+      if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+        need_skip = true;
+        break;
+      }
+    }
+    if (!need_skip) {
+      op_name.push_back(op->Type());
+    }
+  }
+
+  VLOG(3) << "op name size: " << op_name.size();
+  op_total_time.resize(op_name.size());
+  for (size_t i = 0; i < op_total_time.size(); ++i) {
+    op_total_time[i] = 0.0;
+  }
+  platform::Timer timeline;
+  double total_time = 0.0;
+  double read_time = 0.0;
+  int total_ins_num = 0;
+  int cur_batch;
+  timeline.Start();
+  while ((cur_batch = device_reader_->Next()) > 0) {
+    total_ins_num += cur_batch;
+    timeline.Pause();
+    read_time += timeline.ElapsedSec();
+    total_time += timeline.ElapsedSec();
+
+    int run_op_idx = 0;
+    dev_ctx_->Wait();
+    for (auto& op : ops_) {
+      bool need_skip = false;
+      for (auto t = 0u; t < skip_ops_.size(); ++t) {
+        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+          need_skip = true;
+          break;
+        }
+      }
+      if (!need_skip) {
+        timeline.Start();
+        VLOG(3) << "Going to run op " << op_name[run_op_idx];
+        op->Run(*thread_scope_, place_);
+        dev_ctx_->Wait();
+        VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
+        timeline.Pause();
+        op_total_time[run_op_idx++] += timeline.ElapsedSec();
+        total_time += timeline.ElapsedSec();
+      }
+    }
+    timeline.Start();
+    PrintFetchVars();
+    thread_scope_->DropKids();
+    dev_ctx_->Wait();
+    timeline.Pause();
+    total_time += timeline.ElapsedSec();
+    timeline.Start();
+  }
+  VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time
+          << " seconds, ins_num: " << total_ins_num;
+  for (size_t i = 0; i < op_name.size(); ++i) {
+    VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
+            << ", mean time: " << op_total_time[i] / total_ins_num
+            << "s, total time:" << op_total_time[i] << "sec";
+  }
+  return;
+}
+
 void PSGPUWorker::ResetStat() {
   total_time_ = 0;
   read_time_ = 0;
diff --git a/python/paddle/distributed/fleet/utils/fs.py b/python/paddle/distributed/fleet/utils/fs.py
index 221f09a796a6f388429fb75367c330abffb4c8b0..7e62e551fe8d53eb354505bdc86d7c12d9d06726 100644
--- a/python/paddle/distributed/fleet/utils/fs.py
+++ b/python/paddle/distributed/fleet/utils/fs.py
@@ -447,9 +447,6 @@ class HDFSClient(FS):
                  configs,
                  time_out=5 * 60 * 1000,  # ms
                  sleep_inter=1000):  # ms
-        # Raise exception if JAVA_HOME not exists.
-        java_home = os.environ["JAVA_HOME"]
-
         self.pre_commands = []
         hadoop_bin = '%s/bin/hadoop' % hadoop_home
         self.pre_commands.append(hadoop_bin)