From 5f3aaafcfae127c62e755bcd5e1923161855eea8 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Wed, 23 Sep 2020 13:03:11 +0000 Subject: [PATCH] update, test=develop --- paddle/fluid/framework/device_worker.h | 8 - paddle/fluid/framework/pipeline_trainer.cc | 215 +----------------- paddle/fluid/framework/section_worker.cc | 186 ++++++--------- paddle/fluid/framework/trainer.h | 13 -- .../operators/collective/c_recv_op.cu.cc | 5 +- 5 files changed, 75 insertions(+), 352 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 2f047eb6de9..e2465848681 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -414,7 +414,6 @@ class HeterCpuWorker : public HogwildWorker { #if defined(PADDLE_WITH_NCCL) class SectionWorker : public DeviceWorker { public: - // SectionWorker() { local_batch_id_ = 0; } SectionWorker() {} ~SectionWorker() override {} @@ -430,7 +429,6 @@ class SectionWorker : public DeviceWorker { const platform::Place& place() const { return place_; } - // void SetSectionIndex(int section_id) { section_id_ = section_id; } void SetDeviceIndex(int tid) override {} void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } @@ -442,7 +440,6 @@ class SectionWorker : public DeviceWorker { skip_vars_ = skip_vars; } void SetStartCpuCoreId(int id) { cpu_id_ = id; } - // static void ResetBatchId() { batch_id_ = 0; } protected: void AutoSetCPUAffinity(bool reuse); @@ -455,13 +452,8 @@ class SectionWorker : public DeviceWorker { const Scope* minibatch_scope_; std::vector> ops_; - // static std::mutex thread_mutex; - // static std::mutex cout_mutex; - // static std::condition_variable thread_condition; - // static bool threads_completed; std::shared_ptr program_; static uint64_t batch_id_; - // uint64_t local_batch_id_; platform::DeviceContext* dev_ctx_ = nullptr; }; diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 62429b7bee1..e335526a14a 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -28,71 +28,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, num_microbatches_ = section_params.num_microbatches(); VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; trainer_desc_ = trainer_desc; - start_cpu_core_id_ = section_params.start_cpu_core_id(); + auto cpu_core_id = section_params.start_cpu_core_id(); - // SetDataset(dataset); ParseDumpConfig(trainer_desc); - // get filelist from trainer_desc here - // const std::vector readers = - // dataset->GetReaders(); - // VLOG(3) << "Number of program sections: " << section_num_; - // VLOG(3) << "readers num: " << readers.size(); - // int num_readers = readers.size(); - // PADDLE_ENFORCE_EQ(num_readers, 1, - // platform::errors::InvalidArgument( - // "Number of dataset readers for pipeline " - // "must be 1 now, but the value you give is %d.", - // num_readers)); - // auto* reader = readers[0]; - - // workers_.resize(section_num_); - // for (int i = 0; i < section_num_; ++i) { - // const auto& section_config = section_params.section_config(i); - // platform::Place place; - // int place_id = section_config.place_id(); - // switch (section_config.place()) { - // case SectionConfig::CPUPlace: - // place = platform::CPUPlace(); - // break; - // case SectionConfig::CUDAPlace: - // // Note that one section has at most one GPU place in one pipeline - // PADDLE_ENFORCE_GE( - // place_id, 0, - // platform::errors::InvalidArgument( - // "The place_id value for CUDAPlace shoud be greater " - // "than or equal to 0, but the value you give is %d.", - // place_id)); - // place = platform::CUDAPlace(place_id); - // break; - // case SectionConfig::CUDAPinnedPlace: - // place = platform::CUDAPinnedPlace(); - // break; - // default: - // PADDLE_ENFORCE_NOT_NULL(nullptr, - // platform::errors::InvalidArgument( - // "Unkown place type in SectionConfig: %d", - // section_config.place())); - // } - // places_.emplace_back(place); - // VLOG(3) << "Device worker place: " << place << ", device id: " << place_id - // << ", section: " << i; - - // workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( - // trainer_desc.device_worker_name()); - // auto this_worker = - // std::dynamic_pointer_cast( - // workers_[i]); - // if (i == 0) { - // // we only set reader for the first section - // this_worker->SetDataFeed(reader); - // this_worker->SetReaderPlace(place); - // } - // this_worker->SetThreadIndex(i); - // this_worker->SetSectionIndex(i); - // this_worker->SetPlace(place); - // this_worker->Initialize(trainer_desc); - // this_worker->SetMicrobatchNum(num_microbatches_); - //} const auto& section_config = section_params.section_config(); int place_id = section_config.place_id(); PADDLE_ENFORCE_GE(place_id, 0, @@ -108,7 +46,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, this_worker->SetPlace(place_); this_worker->Initialize(trainer_desc); this_worker->SetMicrobatchNum(num_microbatches_); - this_worker->SetStartCpuCoreId(start_cpu_core_id_); + this_worker->SetStartCpuCoreId(cpu_core_id); // set debug here SetDebug(trainer_desc.debug()); @@ -118,7 +56,6 @@ void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { if (need_dump_field_) { InitDumpEnv(); } - VLOG(3) << "init other env done."; } std::string PipelineTrainer::GetDumpPath(int tid) { @@ -135,51 +72,6 @@ void PipelineTrainer::InitDumpEnv() { } } -// void PipelineTrainer::CopyParameters(int section_id, int microbatch_id, -// const ProgramDesc& program, -// const platform::Place& place) { -// auto& global_block = program.Block(0); -// std::map param_map; -// for (auto& var : global_block.AllVars()) { -// if (var->Persistable()) { -// param_map[var->Name()] = 1; -// } -// } -// for (auto& var : global_block.AllVars()) { -// bool is_param_grad = false; -// size_t pos = 0; -// if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) { -// auto prefix_name = var->Name().substr(0, pos); -// if (param_map.find(prefix_name) != param_map.end()) { -// is_param_grad = true; -// } -// } -// VLOG(3) << "Var name: " << var->Name(); -// if ((var->Persistable() || is_param_grad) && microbatch_id == 0) { -// auto* ptr = root_scope_->FindVar(var->Name()); -// auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name()); -// VLOG(3) << "Create persistable var " << var->Name() << " for minibatch -// " -// << section_id << ", which pointer is " << new_ptr; -// InitializeVariable(new_ptr, var->GetType()); -// if (is_param_grad) { -// continue; -// } -// const LoDTensor& root_tensor = ptr->Get(); -// LoDTensor* minibatch_tensor = new_ptr->GetMutable(); -// TensorCopy(*static_cast(&root_tensor), place, -// static_cast(minibatch_tensor)); -// } else if (!var->Persistable() && !is_param_grad) { -// auto* ptr = -// microbatch_scopes_[section_id][microbatch_id]->Var(var->Name()); -// VLOG(3) << "Create variable " << var->Name() << " for section " -// << section_id << " microbatch " << microbatch_id -// << ", which pointer is " << ptr; -// InitializeVariable(ptr, var->GetType()); -// } -// } -// } - void PipelineTrainer::CopyParameters(int microbatch_id, const ProgramDesc& program, const platform::Place& place) { @@ -190,6 +82,7 @@ void PipelineTrainer::CopyParameters(int microbatch_id, param_map[var->Name()] = 1; } } + for (auto& var : global_block.AllVars()) { bool is_param_grad = false; size_t pos = 0; @@ -199,7 +92,6 @@ void PipelineTrainer::CopyParameters(int microbatch_id, is_param_grad = true; } } - VLOG(3) << "Var name: " << var->Name(); if (is_param_grad && microbatch_id == 0) { auto* ptr = minibatch_scope_->Var(var->Name()); InitializeVariable(ptr, var->GetType()); @@ -207,149 +99,52 @@ void PipelineTrainer::CopyParameters(int microbatch_id, << ", which pointer is " << ptr; } else if (!var->Persistable() && !is_param_grad) { auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name()); - VLOG(3) << "Create variable " << var->Name() << " microbatch " + VLOG(3) << "Create variable " << var->Name() << " for microbatch " << microbatch_id << ", which pointer is " << ptr; InitializeVariable(ptr, var->GetType()); } } } -// void PipelineTrainer::GetSkipVars(int section_id, const ProgramDesc& program) -// { -// auto& global_block = program.Block(0); -// for (auto& op : global_block.AllOps()) { -// if (op->Type() != "enqueue") { -// continue; -// } -// auto input_arg_names = op->InputArgumentNames(); -// PADDLE_ENFORCE_EQ(input_arg_names.size(), 1, -// platform::errors::InvalidArgument( -// "Number of input arguments for enqueue op must be -// 1, " -// "but the value is %d.", -// input_arg_names.size())); -// std::string input_arg_name = input_arg_names[0]; -// if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) { -// skip_vars_[section_id].emplace_back(input_arg_name); -// VLOG(3) << "add skip var name: " << input_arg_name; -// } -// } -// } - -// void PipelineTrainer::GetSkipVars(const ProgramDesc& program) { -// auto& global_block = program.Block(0); -// for (auto& op : global_block.AllOps()) { -// if (op->Type() != "c_send") { -// continue; -// } -// auto input_arg_names = op->InputArgumentNames(); -// PADDLE_ENFORCE_EQ(input_arg_names.size(), 1, -// platform::errors::InvalidArgument( -// "Number of input arguments for c_send op must be 1, -// " -// "but the value given is %d.", -// input_arg_names.size())); -// std::string input_arg_name = input_arg_names[0]; -// if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) { -// skip_vars_.emplace_back(input_arg_name); -// VLOG(3) << "add skip var name: " << input_arg_name; -// } -// } -// } - void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) { PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( "root_scope_ can not be nullptr")); - // auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id(); - // SectionWorker::cpu_id_.store(start_cpu_id); - // minibatch_scopes_.resize(section_num_); - // microbatch_scopes_.resize(section_num_); - // minibatch_scopes_.resize(1); microbatch_scopes_.resize(num_microbatches_); - // skip_vars_.resize(section_num_); VLOG(3) << "Create minibatch and microbatch scopes..."; - // for (int i = 0; i < section_num_; ++i) { minibatch_scope_ = &root_scope_->NewScope(); std::shared_ptr program; program.reset(new ProgramDesc( trainer_desc_.section_param().section_config().program_desc())); - // trainer_desc_.section_param().section_config(i).program_desc())); - // microbatch_scopes_[i].resize(num_microbatches_); for (int j = 0; j < num_microbatches_; ++j) { - // microbatch_scopes_[j] = &minibatch_scopes_[i]->NewScope(); microbatch_scopes_[j] = &minibatch_scope_->NewScope(); - // CopyParameters(i, j, *program, places_[i]); CopyParameters(j, *program, place_); } - // GetSkipVars(i, *program); - // GetSkipVars(*program); - // } - // for (int i = 0; i < section_num_; ++i) { auto this_worker = std::dynamic_pointer_cast(worker_); - // workers_[i]); this_worker->SetRootScope(root_scope_); this_worker->SetMinibatchScope(minibatch_scope_); - // this_worker->SetMicrobatchScopes(microbatch_scopes_[i]); this_worker->SetMicrobatchScopes(microbatch_scopes_); - // this_worker->SetSkipVars(skip_vars_[i]); - //} } void PipelineTrainer::Run() { - VLOG(3) << "Going to run"; - // for (int i = 0; i < section_num_; ++i) { + VLOG(5) << "Going to run PipelineTrainer::Run()"; if (!debug_) { section_thread_ = std::thread(&DeviceWorker::TrainFiles, worker_.get()); - // section_threads_.push_back( - // std::thread(&DeviceWorker::TrainFiles, workers_.get())); - // std::thread(&DeviceWorker::TrainFiles, workers_[i].get())); } else { section_thread_ = std::thread(&DeviceWorker::TrainFilesWithProfiler, worker_.get()); - // section_threads_.push_back(std::thread( - // &DeviceWorker::TrainFilesWithProfiler, workers_.get())); - // &DeviceWorker::TrainFilesWithProfiler, workers_[i].get())); } - //} } void PipelineTrainer::Finalize() { - // for (auto& th : section_threads_) { - // th.join(); - //} section_thread_.join(); if (need_dump_field_) { FinalizeDumpEnv(); } - // VLOG(3) << "copying back parameters. "; - // for (int i = 0; i < section_num_; ++i) { - // std::shared_ptr program; - // program.reset(new ProgramDesc( - // trainer_desc_.section_param().section_config(i).program_desc())); - // for (int j = 0; j < num_microbatches_; ++j) { - // auto& global_block = program->Block(0); - // for (auto& var : global_block.AllVars()) { - // if (var->Persistable()) { - // auto* ptr = root_scope_->FindVar(var->Name()); - // LoDTensor* root_tensor = ptr->GetMutable(); - // auto* minibatch_ptr = minibatch_scopes_[i]->Var(var->Name()); - // const LoDTensor& minibatch_tensor = - // minibatch_ptr->Get(); - // TensorCopy(*static_cast(&minibatch_tensor), - // places_[0], - // static_cast(root_tensor)); - // VLOG(3) << "Copy persitable var " << var->Name() << " to root - // scope"; - // } - // } - // } - // } root_scope_->DropKids(); - // SectionWorker::ResetBatchId(); } Scope* PipelineTrainer::GetWorkerScope(int thread_id) { diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 6d342092f86..365b29e30f6 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -30,28 +30,19 @@ limitations under the License. */ namespace paddle { namespace framework { -// std::atomic SectionWorker::cpu_id_(0); -// std::mutex SectionWorker::thread_mutex; -// std::mutex SectionWorker::cout_mutex; -// std::condition_variable SectionWorker::thread_condition; -// bool SectionWorker::threads_completed = false; uint64_t SectionWorker::batch_id_(0); void SectionWorker::Initialize(const TrainerDesc& desc) { dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); program_.reset( new ProgramDesc(desc.section_param().section_config().program_desc())); - // desc.section_param().section_config(section_id_).program_desc())); for (auto& op_desc : program_->Block(0).AllOps()) { ops_.push_back(OpRegistry::CreateOp(*op_desc)); } } void SectionWorker::AutoSetCPUAffinity(bool reuse) { - // int thread_cpu_id = cpu_id_.fetch_add(1); - unsigned concurrency_cap = std::thread::hardware_concurrency(); - // unsigned proc = thread_cpu_id; unsigned proc = cpu_id_; if (proc >= concurrency_cap) { @@ -61,7 +52,6 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) { LOG(INFO) << "All " << concurrency_cap << " CPUs have been set affinities. Fail to set " << cpu_id_ << "th thread."; - // << thread_cpu_id << "th thread"; return; } } @@ -80,13 +70,12 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) { (0 == CPU_ISSET(proc, &mask))) { LOG(WARNING) << "Fail to set thread affinity to CPU " << proc; } - // VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc; VLOG(3) << "Set " << cpu_id_ << "th thread affinity to CPU " << proc; } void SectionWorker::TrainFiles() { - VLOG(3) << "begin section_worker TrainFiles"; - // AutoSetCPUAffinity(true); + VLOG(5) << "begin section_worker TrainFiles"; + AutoSetCPUAffinity(true); int64_t max_memory_size = 0; std::unique_ptr gc; @@ -109,12 +98,6 @@ void SectionWorker::TrainFiles() { #endif platform::Timer batch_timer; - - // if (thread_id_ == 0) { - // while (true) { - // Start a minibatch. - // real number of microbatches run - // int real_microbatch_num = 0; batch_timer.Start(); for (int i = 0; i < num_microbatches_; ++i) { try { @@ -130,7 +113,8 @@ void SectionWorker::TrainFiles() { op_role == (static_cast(OpRole::kForward) | static_cast(OpRole::kLoss)); if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { - VLOG(3) << "running an op " << op->Type() << " for scope " << i; + VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " + << i; op->Run(*microbatch_scopes_[i], place_); if (gc) { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, @@ -139,19 +123,10 @@ void SectionWorker::TrainFiles() { } } } catch (platform::EOFException& e) { - // std::unique_lock lk(thread_mutex); - // threads_completed = true; - VLOG(3) << "thread completed."; - // VLOG(3) << "called notify all"; - // thread_condition.notify_all(); - VLOG(3) << "EOF encountered"; - // throw platform::EOFException(); - // throw e; - PADDLE_THROW_EOF(); - break; + VLOG(3) << "EOF encountered and completed."; + throw; } } - dev_ctx_->Wait(); // backward pass for (int i = 0; i < num_microbatches_; ++i) { @@ -160,7 +135,8 @@ void SectionWorker::TrainFiles() { if (op_role == static_cast(OpRole::kBackward) || op_role == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))) { - VLOG(3) << "running an op " << op->Type() << " for scope " << i; + VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " + << i; op->Run(*microbatch_scopes_[i], place_); if (gc) { DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, @@ -169,30 +145,28 @@ void SectionWorker::TrainFiles() { } } } - dev_ctx_->Wait(); + // update pass for (auto& op : ops_) { int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kOptimize)) { - VLOG(3) << "running an op " << op->Type() << " for minibatch scope"; + VLOG(3) << "Update: running op " << op->Type(); op->Run(*microbatch_scopes_[0], place_); if (gc) { - for (int i = 0; i < num_microbatches_; ++i) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); - } + DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_, + gc.get()); } } } dev_ctx_->Wait(); batch_timer.Pause(); - VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); + VLOG(0) << "batch: " << batch_id_ << ", time: " << batch_timer.ElapsedUS(); ++batch_id_; } void SectionWorker::TrainFilesWithProfiler() { - VLOG(3) << "begin section_worker TrainFiles with profiler"; - // AutoSetCPUAffinity(true); + VLOG(5) << "begin section_worker TrainFiles with profiler"; + AutoSetCPUAffinity(true); platform::Timer batch_timer; platform::Timer timeline; @@ -216,7 +190,6 @@ void SectionWorker::TrainFilesWithProfiler() { int64_t max_memory_size = 0; std::unique_ptr gc; - // const std::vector keep_vars; auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { @@ -235,14 +208,13 @@ void SectionWorker::TrainFilesWithProfiler() { } #endif - // if (thread_id_ == 0) { struct timeval start; struct timeval end; struct timeval micro_start; struct timeval micro_end; + // Start a minibatch. batch_timer.Start(); - // int real_microbatch_num = 0; for (int i = 0; i < num_microbatches_; ++i) { try { int op_idx = 0; @@ -260,9 +232,8 @@ void SectionWorker::TrainFilesWithProfiler() { op_role == (static_cast(OpRole::kForward) | static_cast(OpRole::kLoss)); if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { - // VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ - // << " for scope " << i; - VLOG(3) << "running an op " << op->Type() << " for scope " << i; + VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " + << i; timeline.Start(); op->Run(*microbatch_scopes_[i], place_); if (gc) { @@ -282,32 +253,26 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; - { - // std::unique_lock lk(cout_mutex); - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START[" - << start.tv_sec * 1e6 + start.tv_usec << "]:END[" - << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; - } + + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::FWD:B[" << batch_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } op_idx++; } + gettimeofday(µ_end, NULL); - { - // std::unique_lock lk(cout_mutex); - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" - << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" - << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" - << std::endl; - } + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!FWD:B[" << batch_id_ << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" + << std::endl; } catch (platform::EOFException& e) { - VLOG(3) << "thread completed."; - VLOG(0) << "EOF encountered"; + VLOG(0) << "EOF encountered, and completed"; VLOG(0) << "============timeline============"; for (size_t i = 0; i < ops_.size(); ++i) { VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i] @@ -315,11 +280,10 @@ void SectionWorker::TrainFilesWithProfiler() { << ", mean_time: " << op_total_time[i] / op_count[i]; } VLOG(0) << "================================"; - throw e; - break; + throw; } } - dev_ctx_->Wait(); + // backward pass for (int i = 0; i < num_microbatches_; ++i) { int op_idx = 0; @@ -330,7 +294,8 @@ void SectionWorker::TrainFilesWithProfiler() { if (op_role == static_cast(OpRole::kBackward) || op_role == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))) { - VLOG(3) << "running an op " << op->Type() << " for scope " << i; + VLOG(3) << "Backward: running an op " << op->Type() + << " for micro-batch " << i; timeline.Start(); op->Run(*microbatch_scopes_[i], place_); if (gc) { @@ -350,35 +315,25 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; - { - // std::unique_lock lk(cout_mutex); - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START[" - << start.tv_sec * 1e6 + start.tv_usec << "]:END[" - << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; - } + + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::BWD:B[" << batch_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } op_idx++; } + gettimeofday(µ_end, NULL); - { - // std::unique_lock lk(cout_mutex); - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" - << std::endl; - } + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!BWD:B[" << batch_id_ << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; } - dev_ctx_->Wait(); - // if (real_microbatch_num == 0) { - // batch_timer.Pause(); - // VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); - // return; - // } + // update pass int op_idx = 0; gettimeofday(µ_start, NULL); @@ -386,15 +341,12 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(&start, NULL); int op_role = op->Attr(std::string("op_role")); if (op_role == static_cast(OpRole::kOptimize)) { - VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ - << " for minibatch scope"; + VLOG(3) << "Update: running op " << op->Type(); timeline.Start(); op->Run(*microbatch_scopes_[0], place_); if (gc) { - for (int i = 0; i < num_microbatches_; ++i) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); - } + DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_, + gc.get()); } cudaDeviceSynchronize(); gettimeofday(&end, NULL); @@ -409,31 +361,27 @@ void SectionWorker::TrainFilesWithProfiler() { } op_count[op_idx] += 1; op_total_time[op_idx] += time; - { - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:SCOPE[" << num_microbatches_ << "]:OP[" << op->Type() - << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" - << std::endl; - } + + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "::UPD:B[" << batch_id_ << "]:OP[" << op->Type() + << "]:START[" << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } op_idx++; } gettimeofday(µ_end, NULL); - { - std::cout << std::fixed; - std::cout.precision(0); - std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:START[" - << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" - << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; - } + std::cout << std::fixed; + std::cout.precision(0); + std::cout << "!!UPD:B[" << batch_id_ << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; dev_ctx_->Wait(); batch_timer.Pause(); - VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); + VLOG(0) << "batch: " << batch_id_ << ", time: " << batch_timer.ElapsedUS(); ++batch_id_; } + } // namespace framework } // namespace paddle #endif diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index b66ab9e4131..84c6bd3768f 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -221,32 +221,19 @@ class PipelineTrainer : public TrainerBase { void GetSkipVars(const ProgramDesc& main_program); protected: - // int section_num_; int num_microbatches_; - int start_cpu_core_id_; - // std::vector places_; platform::Place place_; - // std::vector> skip_vars_; std::vector skip_vars_; TrainerDesc trainer_desc_; - // std::vector section_threads_; std::thread section_thread_; - // worker: [section_id] - // std::vector> workers_; std::shared_ptr worker_; - // minibatch_scopes_: [section_id] - // std::vector minibatch_scopes_; Scope* minibatch_scope_; - // microbatch_scopes_: [section_id][microbatch_id] - // std::vector> microbatch_scopes_; // microbatch_scopes_: [microbatch_id] std::vector microbatch_scopes_; void CopyParameters(int microbatch_id, const ProgramDesc& program, const platform::Place& place); - // bool isPersistableVarGrad(std::string name); - // bool isPersistable(VarDesc* var); }; #endif diff --git a/paddle/fluid/operators/collective/c_recv_op.cu.cc b/paddle/fluid/operators/collective/c_recv_op.cu.cc index 14e7934cc4c..e6272446156 100644 --- a/paddle/fluid/operators/collective/c_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/c_recv_op.cu.cc @@ -73,8 +73,9 @@ class CRecvOpCUDAKernel : public framework::OpKernel { } else { stream = comm->stream(); } - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( - numel_ptr, 1, ncclInt, peer, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclRecv(static_cast(numel_ptr), 1, ncclInt, + peer, comm->comm(), stream)); PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost)); VLOG(0) << "numel:" << numel; -- GitLab