diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 04befbe1ca01d4bfec5872a63565f21d110a6c67..3336b5783a8cf14358db5b74ae799ba470a640fa 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -455,6 +455,7 @@ class SectionWorker : public DeviceWorker {
   std::vector<std::unique_ptr<OperatorBase>> ops_;
   static std::mutex thread_mutex;
+  static std::mutex cout_mutex;
   static std::condition_variable thread_condition;
   static bool threads_completed;
   std::shared_ptr<framework::ProgramDesc> program_;
diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc
index 758b728fd9cffff6867a46a6c22c86e496103b84..b827435508fa51c734b2f8fbebb86aa74469fdfe 100644
--- a/paddle/fluid/framework/pipeline_trainer.cc
+++ b/paddle/fluid/framework/pipeline_trainer.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #if defined(PADDLE_WITH_NCCL)
+#include <map>
 #include "paddle/fluid/framework/data_feed_factory.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
@@ -44,7 +45,6 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
                         "must be 1 now, but the value you give is %d.",
                         num_readers));
   auto* reader = readers[0];
-  feed_var_names_ = reader->GetUseSlotAlias();
 
   workers_.resize(section_num_);
   for (int i = 0; i < section_num_; ++i) {
@@ -123,26 +123,36 @@ void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
                                      const ProgramDesc& program,
                                      const platform::Place& place) {
   auto& global_block = program.Block(0);
+  std::map<std::string, int> param_map;
   for (auto& var : global_block.AllVars()) {
-    int is_feed_var =
-        std::count(feed_var_names_.begin(), feed_var_names_.end(), var->Name());
-    if ((var->Persistable() || is_feed_var) && microbatch_id == 0) {
-      if (is_feed_var) {
-        auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
-        VLOG(3) << "data name: " << var->Name() << ", ptr: " << new_ptr;
-        InitializeVariable(new_ptr, var->GetType());
-      } else {
-        auto* ptr = root_scope_->FindVar(var->Name());
-        auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
-        VLOG(3) << "Create persistable var " << var->Name() << " for minibatch "
-                << section_id << ", which pointer is " << new_ptr;
-        InitializeVariable(new_ptr, var->GetType());
-        const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
-        LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
-        TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
-                   static_cast<Tensor*>(minibatch_tensor));
+    if (var->Persistable()) {
+      param_map[var->Name()] = 1;
+    }
+  }
+  for (auto& var : global_block.AllVars()) {
+    bool is_param_grad = false;
+    size_t pos = 0;
+    if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
+      auto prefix_name = var->Name().substr(0, pos);
+      if (param_map.find(prefix_name) != param_map.end()) {
+        is_param_grad = true;
+      }
+    }
+    VLOG(3) << "Var name: " << var->Name();
+    if ((var->Persistable() || is_param_grad) && microbatch_id == 0) {
+      auto* ptr = root_scope_->FindVar(var->Name());
+      auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
+      VLOG(3) << "Create persistable var " << var->Name() << " for minibatch "
+              << section_id << ", which pointer is " << new_ptr;
+      InitializeVariable(new_ptr, var->GetType());
+      if (is_param_grad) {
+        continue;
       }
-    } else if (!var->Persistable() && !is_feed_var) {
+      const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
+      LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
+      TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
+                 static_cast<Tensor*>(minibatch_tensor));
+    } else if (!var->Persistable() &&
!is_param_grad) { auto* ptr = microbatch_scopes_[section_id][microbatch_id]->Var(var->Name()); VLOG(3) << "Create variable " << var->Name() << " for section " @@ -244,7 +254,7 @@ void PipelineTrainer::Finalize() { const LoDTensor& minibatch_tensor = minibatch_ptr->Get(); TensorCopy(*static_cast(&minibatch_tensor), places_[0], static_cast(root_tensor)); - VLOG(4) << "Copy persitable var " << var->Name() << " to root scope"; + VLOG(3) << "Copy persitable var " << var->Name() << " to root scope"; } } } diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 03b7afbb8771fadbe07a352497fa69a299928cf7..068ed73759eac52a2dfcc1d3dd3cd003efb869ae 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -32,6 +32,7 @@ namespace framework { std::atomic SectionWorker::cpu_id_(0); std::mutex SectionWorker::thread_mutex; +std::mutex SectionWorker::cout_mutex; std::condition_variable SectionWorker::thread_condition; bool SectionWorker::threads_completed = false; uint64_t SectionWorker::batch_id_(0); @@ -103,9 +104,14 @@ void SectionWorker::TrainFiles() { } #endif + platform::Timer batch_timer; + if (thread_id_ == 0) { while (true) { // Start a minibatch. + // real number of microbatches run + int real_microbatch_num = 0; + batch_timer.Start(); for (int i = 0; i < num_microbatches_; ++i) { try { for (auto& op : ops_) { @@ -137,17 +143,21 @@ void SectionWorker::TrainFiles() { VLOG(3) << "called notify all"; thread_condition.notify_all(); VLOG(0) << "EOF encountered"; - return; + break; } - if (i == 0) { + { + real_microbatch_num += 1; + batch_id_ += 1; VLOG(3) << "called notify all"; std::unique_lock lk(thread_mutex); - batch_id_ += 1; thread_condition.notify_all(); } } + dev_ctx_->Wait(); + + VLOG(0) << "real_microbatch_num for thread 0 " << real_microbatch_num; // backward pass - for (int i = 0; i < num_microbatches_; ++i) { + for (int i = 0; i < real_microbatch_num; ++i) { for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kBackward) || @@ -163,6 +173,12 @@ void SectionWorker::TrainFiles() { } } } + dev_ctx_->Wait(); + if (real_microbatch_num == 0) { + batch_timer.Pause(); + VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); + return; + } // update pass for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); @@ -177,33 +193,45 @@ void SectionWorker::TrainFiles() { } } dev_ctx_->Wait(); - } - } else { - while (true) { + batch_timer.Pause(); + VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); { - PADDLE_ENFORCE_LE( - local_batch_id_, batch_id_, - platform::errors::InvalidArgument( - "local_batch_id_ (%d) must be less than or equal to " - "batch_id_ (%d)", - local_batch_id_, batch_id_)); std::unique_lock lk(thread_mutex); - if (local_batch_id_ == batch_id_ && !threads_completed) { - thread_condition.wait(lk); - } - VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " - << local_batch_id_ << " batch_id_ " << batch_id_; if (threads_completed) { - VLOG(3) << "thread " << thread_id_ << " completed."; - lk.unlock(); - threads_completed = false; return; } - lk.unlock(); - local_batch_id_ += 1; } + } + } else { + while (true) { // forward pass: + bool local_completed = false; + int real_microbatch_num = 0; for (int i = 0; i < num_microbatches_; ++i) { + { + PADDLE_ENFORCE_LE( + local_batch_id_, batch_id_, + platform::errors::InvalidArgument( + "local_batch_id_ (%d) must be less than or equal to " + "batch_id_ (%d)", + local_batch_id_, 
batch_id_)); + std::unique_lock lk(thread_mutex); + if (local_batch_id_ == batch_id_ && !threads_completed) { + thread_condition.wait(lk); + } + VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " + << local_batch_id_ << " batch_id_ " << batch_id_; + if (threads_completed) { + VLOG(3) << "thread " << thread_id_ << " completed."; + lk.unlock(); + threads_completed = false; + local_completed = true; + break; + } + lk.unlock(); + local_batch_id_ += 1; + real_microbatch_num += 1; + } for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); // We run op with op_role = kLRSched only for the first microbatch @@ -227,8 +255,9 @@ void SectionWorker::TrainFiles() { } } } + dev_ctx_->Wait(); // backward pass - for (int i = 0; i < num_microbatches_; ++i) { + for (int i = 0; i < real_microbatch_num; ++i) { for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kBackward) || @@ -244,7 +273,11 @@ void SectionWorker::TrainFiles() { } } } + dev_ctx_->Wait(); // update pass + if (real_microbatch_num == 0) { + return; + } for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kOptimize)) { @@ -258,6 +291,9 @@ void SectionWorker::TrainFiles() { } } dev_ctx_->Wait(); + if (local_completed) { + return; + } } } } @@ -307,14 +343,20 @@ void SectionWorker::TrainFilesWithProfiler() { #endif if (thread_id_ == 0) { + struct timeval start; + struct timeval end; + struct timeval micro_start; + struct timeval micro_end; while (true) { // Start a minibatch. - // int batch_size = 0; batch_timer.Start(); + int real_microbatch_num = 0; for (int i = 0; i < num_microbatches_; ++i) { try { int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); // We run op with op_role = kLRSched only for the first microbatch // to avoid increasing the @LR_DECAY_STEP@ multiple times. 
@@ -335,7 +377,9 @@ void SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); timeline.Pause(); + gettimeofday(&end, NULL); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; if (time > op_max_time[op_idx]) { @@ -346,9 +390,30 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << i << "]:OP[" << op->Type() + << "]:START[" << start.tv_sec * 1e6 + start.tv_usec + << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" + << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; + } } catch (platform::EOFException&) { std::unique_lock lk(thread_mutex); threads_completed = true; @@ -363,19 +428,23 @@ void SectionWorker::TrainFilesWithProfiler() { << ", mean_time: " << op_total_time[i] / op_count[i]; } VLOG(0) << "================================"; - return; + break; } - if (i == 0) { + { VLOG(3) << "called notify all"; std::unique_lock lk(thread_mutex); + real_microbatch_num += 1; batch_id_ += 1; thread_condition.notify_all(); } } + dev_ctx_->Wait(); // backward pass - for (int i = 0; i < num_microbatches_; ++i) { + for (int i = 0; i < real_microbatch_num; ++i) { int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kBackward) || op_role == (static_cast(OpRole::kBackward) | @@ -388,6 +457,8 @@ void SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); timeline.Pause(); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; @@ -399,13 +470,42 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << i << "]:OP[" << op->Type() + << "]:START[" << start.tv_sec * 1e6 + start.tv_usec + << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" + << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; + } + } + dev_ctx_->Wait(); + if (real_microbatch_num == 0) { + batch_timer.Pause(); + VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); + return; } // update pass int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kOptimize)) { VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ @@ -416,6 +516,8 @@ void 
SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); timeline.Pause(); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; @@ -427,48 +529,88 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << num_microbatches_ << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" + << std::endl; + } dev_ctx_->Wait(); batch_timer.Pause(); VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); - } - } else { - while (true) { { - PADDLE_ENFORCE_LE( - local_batch_id_, batch_id_, - platform::errors::InvalidArgument( - "local_batch_id_ (%d) must be less than or equal to " - "batch_id_ (%d)", - local_batch_id_, batch_id_)); std::unique_lock lk(thread_mutex); - if (local_batch_id_ == batch_id_ && !threads_completed) { - thread_condition.wait(lk); - } - VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " - << local_batch_id_ << " batch_id_ " << batch_id_; if (threads_completed) { - VLOG(3) << "thread " << thread_id_ << " completed."; - lk.unlock(); - VLOG(0) << "============timeline============"; - for (size_t i = 0; i < ops_.size(); ++i) { - VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i] - << ", min_time: " << op_min_time[i] - << ", mean_time: " << op_total_time[i] / op_count[i]; - } - VLOG(0) << "================================"; - threads_completed = false; return; } - lk.unlock(); - local_batch_id_ += 1; } + } + } else { + struct timeval start; + struct timeval end; + struct timeval micro_start; + struct timeval micro_end; + cudaEvent_t cu_start, cu_stop; + cudaEventCreate(&cu_start); + cudaEventCreate(&cu_stop); + bool local_completed = false; + while (true) { // forward pass: + int real_microbatch_num = 0; for (int i = 0; i < num_microbatches_; ++i) { + { + PADDLE_ENFORCE_LE( + local_batch_id_, batch_id_, + platform::errors::InvalidArgument( + "local_batch_id_ (%d) must be less than or equal to " + "batch_id_ (%d)", + local_batch_id_, batch_id_)); + std::unique_lock lk(thread_mutex); + if (local_batch_id_ == batch_id_ && !threads_completed) { + thread_condition.wait(lk); + } + VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " + << local_batch_id_ << " batch_id_ " << batch_id_; + if (threads_completed) { + local_completed = true; + VLOG(3) << "thread " << thread_id_ << " completed."; + lk.unlock(); + VLOG(0) << "============timeline============"; + for (size_t i = 0; i < ops_.size(); ++i) { + VLOG(0) << "op: " << op_name[i] + << ", max_time: " << op_max_time[i] + << ", min_time: " << op_min_time[i] + << ", mean_time: " << op_total_time[i] / op_count[i]; + } + VLOG(0) << "================================"; + break; + } + lk.unlock(); + real_microbatch_num += 1; + local_batch_id_ += 1; + } int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + 
gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); // We run op with op_role = kLRSched only for the first microbatch // to avoid increasing the @LR_DECAY_STEP@ multiple times. @@ -489,6 +631,8 @@ void SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); timeline.Pause(); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; @@ -500,14 +644,38 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" + << thread_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; + } } + dev_ctx_->Wait(); // backward pass - for (int i = 0; i < num_microbatches_; ++i) { + for (int i = 0; i < real_microbatch_num; ++i) { int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kBackward) || op_role == (static_cast(OpRole::kBackward) | @@ -520,6 +688,8 @@ void SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); timeline.Pause(); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; @@ -531,13 +701,40 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" + << thread_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; + } + } + dev_ctx_->Wait(); + if (real_microbatch_num == 0) { + return; } // update pass int op_idx = 0; + gettimeofday(µ_start, NULL); for (auto& op : ops_) { + gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kOptimize)) { VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ @@ -548,6 +745,8 @@ void SectionWorker::TrainFilesWithProfiler() { DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op.get(), unused_vars_, gc.get()); } + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); timeline.Pause(); auto time = timeline.ElapsedUS(); op_total_time[op_idx] += time; @@ -559,10 +758,34 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; 
op_total_time[op_idx] += time; + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << num_microbatches_ << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + } } op_idx++; } + gettimeofday(µ_end, NULL); + { + std::unique_lock lk(cout_mutex); + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" + << std::endl; + } dev_ctx_->Wait(); + if (local_completed) { + return; + } } } } diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 077fe75172022c8fe501bd1143895115298417bf..1f97024d97059e2924e896cb0584e40be61e4ddc 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -223,7 +223,6 @@ class PipelineTrainer : public TrainerBase { int section_num_; int num_microbatches_; int start_cpu_core_id_; - std::vector feed_var_names_; std::vector places_; std::vector> skip_vars_; TrainerDesc trainer_desc_; diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 9e2d77df777d761b6904d8916c7a35fb8e6bfaba..2dd654c35c35b2b8212e7285d7830f82e9b53a66 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -48,8 +48,9 @@ __all__ = [ 'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum', - 'LarsMomentumOptimizer', 'LambOptimizer', 'ExponentialMovingAverage', - 'PipelineOptimizer', 'LookaheadOptimizer', 'RecomputeOptimizer' + 'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer', + 'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer', + 'RecomputeOptimizer' ] @@ -3709,15 +3710,9 @@ class PipelineOptimizer(object): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) batch_size = 1 - filelist = [] # you should set your own filelist, e.g. 
filelist = ["dataA.txt"] - dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset") - dataset.set_use_var([x,y]) - dataset.set_batch_size(batch_size) - dataset.set_filelist(filelist) data_loader.start() exe.train_from_dataset( - fluid.default_main_program(), - dataset) + fluid.default_main_program()) data_loader.reset() """ @@ -3735,7 +3730,7 @@ class PipelineOptimizer(object): "num_microbatches must be a positive value.") self._num_microbatches = num_microbatches assert start_cpu_core_id >= 0, ( - "start_cpu_core_id must be greater than or equal to 0.") + "start_cpu_core_id must be a non negative integer.") self._start_cpu_core_id = start_cpu_core_id self._place_list = None op_maker = core.op_proto_and_checker_maker @@ -3743,7 +3738,7 @@ class PipelineOptimizer(object): self._op_role_key = op_maker.kOpRoleAttrName() self._op_role_var_key = op_maker.kOpRoleVarAttrName() self._op_device_key = op_maker.kOpDeviceAttrName() - self._param_device_map = dict() + self._param_device_map = None def _create_vars(self, block, main_program): # Create vars for block, copied from main_program's global block @@ -3782,9 +3777,10 @@ class PipelineOptimizer(object): return 'Param' in op.input_names and 'Grad' in op.input_names and ( "LearningRate" in op.input_names) - def _split_program(self, main_program): + def _split_program(self, main_program, devices): """ Split a program into sections according to devices that ops run on. + The ops of the role LRSched are copied to all sections. Args: main_program (Program): the main program @@ -3792,18 +3788,27 @@ class PipelineOptimizer(object): programs = [] # Map from device to its corresponding section program info device_program_map = dict() - block = main_program.block(0) + for device in devices: + p = {'program': Program()} + device_program_map[device] = p + block = main_program.block(0) for op in block.ops: device = op.attr(self._op_device_key) - - if device not in device_program_map: - program = {"program": Program()} - device_program_map[device] = program - program = device_program_map[device] - op_desc = op.desc - ap_op = program["program"].block(0).desc.append_op() - ap_op.copy_from(op_desc) + op_role = op.attr(self._op_role_key) + if int(op_role) & int(self._op_role.LRSched): + # Copy ops of the role LRSched to all sections. + for device in device_program_map.keys(): + program = device_program_map[device] + op_desc = op.desc + ap_op = program["program"].block(0).desc.append_op() + ap_op.copy_from(op_desc) + ap_op._set_attr(self._op_device_key, device) + else: + program = device_program_map[device] + op_desc = op.desc + ap_op = program["program"].block(0).desc.append_op() + ap_op.copy_from(op_desc) for key in sorted(device_program_map.keys()): program = device_program_map[key] @@ -3833,9 +3838,8 @@ class PipelineOptimizer(object): for in_var_name in op.input_arg_names: if in_var_name == var_name: post_op.append(op) + break if post_op: - if not len(post_op) == 1: - raise ValueError("Each op can only have one post op.") return post_op[0] return None @@ -3890,60 +3894,26 @@ class PipelineOptimizer(object): def _get_data_var_info(self, block): """ Get all vars whose is_data attribute are true and then rename them. - - For PipelineTrainer, all data vars are binded to - minibatch scope, so we have to feed them to the microbatch - to avoid conflicts. The vars feeded to microbatch have to - be renamed. """ - # A map from var name to the renamed name. 
- raw_name_new_name_map = dict() - # Because we will create vars in block, it is more safe - # to get all var_names before iteration. - var_names = list(block.vars.keys()) - for var_name in var_names: - var = block.var(var_name) - if not var.is_data: - continue - assert var_name not in raw_name_new_name_map, ( - "{} has already been processed.".format(var_name)) - new_name = unique_name.generate(var_name) - raw_name_new_name_map[var_name] = new_name - new_var = self._create_var(block, var, new_name) - new_var.is_data = False - - # map of data to devices that that data on + # map of data vars to devices that that data on data_devices_map = dict() for op in block.ops: dev_spec = op.attr(self._op_device_key) for var_name in op.input_arg_names: - if var_name not in raw_name_new_name_map: + if "blocking_queue" in var_name: continue + var = block.var(var_name) + if not var.is_data: continue if not var_name in data_devices_map: data_devices_map[var_name] = [] if not dev_spec in data_devices_map[var_name]: data_devices_map[var_name].append(dev_spec) - new_name = raw_name_new_name_map[var_name] - #self._rename_arg(op, var_name, new_name) - return data_devices_map, raw_name_new_name_map - - def _rename_var_in_block(self, block, raw_name_new_name_map): - """ - Rename vars whose names in raw_name_new_name_map to the corresponding - new names. - """ - for op in block.ops: - if op.type == "enqueue" or op.type == "dequeue": - continue - for var_name in op.input_arg_names: - if var_name in raw_name_new_name_map: - new_name = raw_name_new_name_map[var_name] - self._rename_arg(op, var_name, new_name) + return data_devices_map def _insert_enq_deq_for_data_var(self, main_block, programs, startup, devices): """ - Insert enqueue and dequeue ops for data var + Insert enqueue and dequeue ops for data var that on other devices. 
Args: main_block (Block): Global block for main program @@ -3952,22 +3922,19 @@ class PipelineOptimizer(object): devices (list): List of devices in the format (dev:dev_index) """ main_program = main_block.program - data_devices_map, raw_name_new_name_map = self._get_data_var_info( - main_block) + data_devices_map = self._get_data_var_info(main_block) first_prog = programs[0]['program'] first_block = first_prog.block(0) enqueue_index = 0 - if first_block.ops[0].type == "create_py_reader" or ( - first_block.ops[1].type == "create_py_reader"): - for op in first_block.ops: - if op.type == "read": - enqueue_index += 1 - break - enqueue_index += 1 + for op in first_block.ops: + enqueue_index += 1 + if op.type == "read": + break first_dev_spec = devices[0] for var_name in data_devices_map.keys(): for device in data_devices_map[var_name]: + if device == first_dev_spec: continue # step1: generate queue for each pair of data var and device # that that data on queue_name = var_name + "_blocking_queue" @@ -4001,13 +3968,10 @@ class PipelineOptimizer(object): prog = programs[prog_index]['program'] block = prog.block(0) index = 0 - if device == first_dev_spec: - index = enqueue_index + 1 - new_name = raw_name_new_name_map[var_name] source_var = main_program.block(0).var(var_name) - new_var = self._create_var(block, source_var, new_name) + new_var = self._create_var(block, source_var, var_name) block._insert_op( - index=index, + index=0, type='dequeue', outputs={'Out': [new_var]}, attrs={ @@ -4015,7 +3979,6 @@ class PipelineOptimizer(object): self._op_role_key: self._op_role.Forward, 'queue_name': queue_name, }) - self._rename_var_in_block(block, raw_name_new_name_map) def _strip_grad_suffix(self, name): """ @@ -4030,18 +3993,6 @@ class PipelineOptimizer(object): """ return name + core.grad_var_suffix() - def _update_param_device_map(self, params_grads, block): - for param_grad in params_grads: - if not param_grad[0].trainable: continue - param_name = param_grad[0].name - ops = block.ops - for op in ops: - input_arg_names = op.input_arg_names - if param_name in input_arg_names: - self._param_device_map[param_name] = op.attr( - self._op_device_key) - break - def _add_opdevice_attr_for_regularization_clip(self, block): """ Add op_device attribute for regulization and clip ops. @@ -4056,7 +4007,7 @@ class PipelineOptimizer(object): assert self._op_role_var_key in op.attr_names op_role_var = op.all_attrs()[self._op_role_var_key] assert len(op_role_var) == 2 - param_name = block.vars[op_role_var[0]].name + param_name = op_role_var[0] device = self._param_device_map[param_name] op._set_attr(self._op_device_key, device) @@ -4125,6 +4076,8 @@ class PipelineOptimizer(object): "{} has not been set.".format(op.type)) if not dev_spec in device_specs: device_specs.append(dev_spec) + sorted_device_specs = sorted(device_specs) + assert sorted_device_specs == device_specs return device_specs def _insert_enq_deq_ops_for_boundaries(self, block, origin_block, @@ -4141,6 +4094,11 @@ class PipelineOptimizer(object): var_devspec = dict() for index, op in list(enumerate(origin_block.ops)): + # skips lr-related op and vars, as we will process them later. 
+ if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched): + continue + if self._is_update_op(op): continue + cur_device_spec = op.attr(self._op_device_key) for var_name in op.input_arg_names: # i.e., lod_tensor_blocking_queue created by DataLoader, @@ -4196,82 +4154,32 @@ class PipelineOptimizer(object): }) extra_index += 1 - def _add_dequeue_ops_for_optimize(self, block, startup_program): - startup_block = startup_program.global_block() - grad_queue_map = dict() - grad_device_map = dict() - optimize_index = None - grad_names_to_dequeue = [] - - for index, op in reversed(list(enumerate(block.ops))): - device = op.attr(self._op_device_key) - # Optimizer pass - if not self._is_optimize_op(op): - optimize_index = index + 1 - break - if not self._is_update_op(op): continue - assert self._op_role_var_key in op.attr_names - op_role_var = op.all_attrs()[self._op_role_var_key] - assert len(op_role_var) == 2 - grad_name = op_role_var[1] - assert grad_name not in grad_device_map - assert grad_name not in grad_names_to_dequeue - grad_device_map[grad_name] = device - grad_names_to_dequeue.append(grad_name) - - for grad_name in grad_names_to_dequeue: - device = grad_device_map[grad_name] - grad_names = [] - grads = [] - queue_name = grad_name + "_blocking_queue" - queue_name = unique_name.generate(queue_name) - grad_queue_map[grad_name] = queue_name - ref_var = block.vars[grad_name] - queue_var = startup_block.create_var( - name=queue_name, - persistable=True, - type=core.VarDesc.VarType.RAW) - startup_block.append_op( - type='queue_generator', - attrs={ - 'names': [queue_name], - 'capacity': self._num_microbatches - }) - orig_var_name = self._strip_grad_suffix(grad_name) - for _ in range(self._num_microbatches): - u_name = unique_name.generate(orig_var_name) - u_grad_name = self._append_grad_suffix(u_name) - grad_var = self._create_var(block, ref_var, u_grad_name) - grad_names.append(u_grad_name) - grads.append(grad_var) - block._insert_op( - index=optimize_index, - type='dequeue', - outputs={'Out': grads}, - attrs={ - self._op_device_key: device, - 'queue_name': queue_name, - self._op_role_key: self._op_role.Optimize - }) - block._insert_op( - index=optimize_index + 1, - type='sum', - inputs={'X': grad_names}, - outputs={'Out': ref_var}, + def _clear_gradients(self, main_block): + """ + Clear gradients at the begining of each run of a minibatch. + """ + for param_name in self._param_device_map: + grad_name = self._append_grad_suffix(param_name) + param_var = main_block.vars[param_name] + grad_var = main_block.vars[grad_name] + device = self._param_device_map[param_name] + main_block._insert_op( + index=0, + type='fill_constant', + inputs={}, + outputs={'Out': [grad_var]}, attrs={ + 'shape': grad_var.shape, + 'dtype': grad_var.dtype, + 'value': float(0), self._op_device_key: device, - self._op_role_key: self._op_role.Optimize + self._op_role_key: self._op_role.Optimize.LRSched, }) - return grad_queue_map - def _insert_enq_deq_ops_for_update(self, block, startup_program): + def _accumulate_gradients(self, block): """ - Insert enqueue and dequeue ops for gradients of parameters. + Accumulate the graident generated in microbatch to the one in mini-batch. 
""" - startup_block = startup_program.global_block() - grad_queue_map = self._add_dequeue_ops_for_optimize(block, - startup_program) - for index, op in reversed(list(enumerate(block.ops))): offset = index device = op.attr(self._op_device_key) @@ -4298,19 +4206,25 @@ class PipelineOptimizer(object): if len(op_role_var) == 0: continue assert len(op_role_var) % 2 == 0 + offset = index for i in range(0, len(op_role_var), 2): grad_name = op_role_var[i + 1] grad_var = block.vars[grad_name] - assert grad_name in grad_queue_map - queue_name = grad_queue_map[grad_name] + param_name = op_role_var[i] + param_var = block.vars[param_name] + new_var_name = unique_name.generate(param_name) + new_var_name = self._append_grad_suffix(new_var_name) + new_var = self._create_var(block, grad_var, new_var_name) + self._rename_arg(op, grad_name, new_var_name) block._insert_op( index=offset + 1, - type='enqueue', - inputs={'X': block.vars[grad_name]}, + type='sum', + inputs={'X': [grad_var, new_var]}, + outputs={'Out': grad_var}, attrs={ - 'queue_name': queue_name, self._op_device_key: device, - self._op_role_key: self._op_role.Backward + self._op_role_key: self._op_role.Backward, + self._op_role_var_key: op_role_var }) offset += 1 @@ -4333,6 +4247,7 @@ class PipelineOptimizer(object): def _get_device_info(self, block): for op in block.ops: + if not op._has_kernel(op.type): continue op_device = op.attr(self._op_device_key) return op_device @@ -4438,14 +4353,16 @@ class PipelineOptimizer(object): startup_program = default_startup_program() optimize_ops, params_grads = self._optimizer.minimize( loss, startup_program, parameter_list, no_grad_set) - self._update_param_device_map(params_grads, main_block) + self._param_device_map = self._optimizer._param_device_map # Step1: add default op_device attribute for regulization and clip ops self._add_opdevice_attr_for_regularization_clip(main_block) # Step2: add default op_device attribute for ops whose op_device - # attribute have not been set yet. + # attribute have not been set yet. Then check all ops have the + # op_device attribute. self._add_default_opdevice_attr(main_block) + device_specs = self._check_validation(main_block) # Step3: add enqueue and dequeue ops between section boundaries @@ -4454,8 +4371,10 @@ class PipelineOptimizer(object): self._insert_enq_deq_ops_for_boundaries(main_block, origin_main_block, startup_program) - # Step4: add a pair of enqueue and dequeueN for parameter gradients - self._insert_enq_deq_ops_for_update(main_block, startup_program) + # Step4: accumulate gradients during backward + # and clear them after update + self._clear_gradients(main_block) + self._accumulate_gradients(main_block) main_program = main_block.program @@ -4474,18 +4393,11 @@ class PipelineOptimizer(object): # Step5: split program into sections and add pairs of # enqueue and dequeue ops for data var. - if len(place_list) == 0: - program_list = [] - ptmp = { - "program": main_program, - "input_set": set(), - "output_set": set() - } - program_list.append(ptmp) - else: - program_list = self._split_program(main_program) - for p in program_list: - self._create_vars(p["program"].block(0), main_program) + if len(place_list) <= 1: + raise ValueError("Run on one device, do not use pipeline.") + program_list = self._split_program(main_program, device_specs) + for p in program_list: + self._create_vars(p["program"].block(0), main_program) self._insert_enq_deq_for_data_var(main_block, program_list, startup_program, device_specs)