From 9f6aaa3cd107e7a5841f18772475070587810b27 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Tue, 11 Aug 2020 03:58:01 +0000 Subject: [PATCH] reformat code, test=develop --- paddle/fluid/framework/pipeline_trainer.cc | 2 +- paddle/fluid/framework/section_worker.cc | 112 +++++++++++++-------- python/paddle/fluid/optimizer.py | 17 ++-- 3 files changed, 79 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index cbed021ee82..b827435508f 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -254,7 +254,7 @@ void PipelineTrainer::Finalize() { const LoDTensor& minibatch_tensor = minibatch_ptr->Get(); TensorCopy(*static_cast(&minibatch_tensor), places_[0], static_cast(root_tensor)); - VLOG(4) << "Copy persitable var " << var->Name() << " to root scope"; + VLOG(3) << "Copy persitable var " << var->Name() << " to root scope"; } } } diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index bbf35933900..068ed73759e 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -154,6 +154,7 @@ void SectionWorker::TrainFiles() { } } dev_ctx_->Wait(); + VLOG(0) << "real_microbatch_num for thread 0 " << real_microbatch_num; // backward pass for (int i = 0; i < real_microbatch_num; ++i) { @@ -391,11 +392,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << i << "]:OP[" << op->Type() + << "]:START[" << start.tv_sec * 1e6 + start.tv_usec + << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" + << std::endl; } } op_idx++; @@ -403,11 +406,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; } } catch (platform::EOFException&) { std::unique_lock lk(thread_mutex); @@ -467,11 +472,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << i << "]:OP[" << op->Type() + << "]:START[" << start.tv_sec * 1e6 + start.tv_usec + << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" + << std::endl; } } op_idx++; @@ -479,11 +486,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; } } dev_ctx_->Wait(); @@ -522,11 +531,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_ - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << num_microbatches_ << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } } op_idx++; @@ -534,11 +545,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" + << std::endl; } dev_ctx_->Wait(); batch_timer.Pause(); @@ -582,7 +595,8 @@ void SectionWorker::TrainFilesWithProfiler() { lk.unlock(); VLOG(0) << "============timeline============"; for (size_t i = 0; i < ops_.size(); ++i) { - VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i] + VLOG(0) << "op: " << op_name[i] + << ", max_time: " << op_max_time[i] << ", min_time: " << op_min_time[i] << ", mean_time: " << op_total_time[i] / op_count[i]; } @@ -632,11 +646,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" + << thread_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } } op_idx++; @@ -644,11 +660,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; } } dev_ctx_->Wait(); @@ -685,11 +703,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" + << thread_id_ << "]:SCOPE[" << i << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } } op_idx++; @@ -697,11 +717,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec + << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec + << "]" << std::endl; } } dev_ctx_->Wait(); @@ -738,11 +760,13 @@ void SectionWorker::TrainFilesWithProfiler() { op_total_time[op_idx] += time; { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); - std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_ - << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec - << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; + std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ + << "]:SCOPE[" << num_microbatches_ << "]:OP[" + << op->Type() << "]:START[" + << start.tv_sec * 1e6 + start.tv_usec << "]:END[" + << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl; } } op_idx++; @@ -750,11 +774,13 @@ void SectionWorker::TrainFilesWithProfiler() { gettimeofday(µ_end, NULL); { std::unique_lock lk(cout_mutex); - std::cout << std::fixed; + std::cout << std::fixed; std::cout.precision(0); std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ - << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec - << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; + << "]:START[" + << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" + << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" + << std::endl; } dev_ctx_->Wait(); if (local_completed) { diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 39252b0e05a..757d2189904 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -1142,7 +1142,7 @@ class MomentumOptimizer(Optimizer): class DGCMomentumOptimizer(Optimizer): """ - :api_attr: Static Graph + :api_attr: Static Graph DGC (Deep Gradient Compression) Momentum Optimizer. Original paper is https://arxiv.org/abs/1712.01887 @@ -3068,7 +3068,7 @@ Lamb = LambOptimizer class ModelAverage(Optimizer): """ - :api_attr: Static Graph + :api_attr: Static Graph The ModelAverage optimizer accumulates specific continuous historical parameters during training. The accumulated historical range can be controlled by the passed @@ -3377,7 +3377,7 @@ class ModelAverage(Optimizer): class ExponentialMovingAverage(object): """ - :api_attr: Static Graph + :api_attr: Static Graph Compute the moving average of parameters with exponential decay. Given a parameter :math:`\\theta`, its exponential moving average (EMA) @@ -3627,7 +3627,7 @@ class ExponentialMovingAverage(object): class PipelineOptimizer(object): """ - :api_attr: Static Graph + :api_attr: Static Graph Pipeline Optimizer: Make a program to run as pipeline, that is splitting a program into multiple sections (sub-programs) and each section run on a @@ -4132,7 +4132,7 @@ class PipelineOptimizer(object): index=0, type='fill_constant', inputs={}, - outputs={'Out':[grad_var]}, + outputs={'Out': [grad_var]}, attrs={ 'shape': grad_var.shape, 'dtype': grad_var.dtype, @@ -4326,6 +4326,7 @@ class PipelineOptimizer(object): # attribute have not been set yet. Then check all ops have the # op_device attribute. self._add_default_opdevice_attr(main_block) + device_specs = self._check_validation(main_block) # Step3: add enqueue and dequeue ops between section boundaries @@ -4388,7 +4389,7 @@ class PipelineOptimizer(object): class RecomputeOptimizer(Optimizer): """ - :api_attr: Static Graph + :api_attr: Static Graph Recompute Optimizer Wrapper @@ -4473,7 +4474,7 @@ class RecomputeOptimizer(Optimizer): def load(self, stat_dict): """ - :api_attr: Static Graph + :api_attr: Static Graph load function is not supported by Recompute Optimizer for now. :return: None @@ -4697,7 +4698,7 @@ class RecomputeOptimizer(Optimizer): class LookaheadOptimizer(object): """ - :api_attr: Static Graph + :api_attr: Static Graph This implements the Lookahead optimizer of the paper : https://arxiv.org/abs/1907.08610. -- GitLab