diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index e5c04db8369dd4748893b2b568b7cbd3bcb119ac..bbf359339008105516fdcd648b685292e5cb2711 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -194,10 +194,17 @@ void SectionWorker::TrainFiles() {
       dev_ctx_->Wait();
       batch_timer.Pause();
       VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+      {
+        std::unique_lock lk(thread_mutex);
+        if (threads_completed) {
+          return;
+        }
+      }
     }
   } else {
     while (true) {
       // forward pass:
+      bool local_completed = false;
       int real_microbatch_num = 0;
       for (int i = 0; i < num_microbatches_; ++i) {
         {
@@ -217,6 +224,7 @@
             VLOG(3) << "thread " << thread_id_ << " completed.";
             lk.unlock();
             threads_completed = false;
+            local_completed = true;
             break;
           }
           lk.unlock();
@@ -282,6 +290,9 @@
         }
       }
       dev_ctx_->Wait();
+      if (local_completed) {
+        return;
+      }
     }
   }
 }
@@ -479,6 +490,7 @@ void SectionWorker::TrainFilesWithProfiler() {
       if (real_microbatch_num == 0) {
         batch_timer.Pause();
         VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+        return;
       }
       // update pass
       int op_idx = 0;
@@ -528,14 +540,15 @@
                 << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
       }
-      struct timeval wait_start;
-      struct timeval wait_end;
-      gettimeofday(&wait_start, NULL);
       dev_ctx_->Wait();
-      gettimeofday(&wait_end, NULL);
-      VLOG(0) << "device wait: " << wait_end.tv_sec * 1e6 + wait_end.tv_usec - wait_start.tv_sec * 1e6 - wait_start.tv_usec;
       batch_timer.Pause();
       VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+      {
+        std::unique_lock lk(thread_mutex);
+        if (threads_completed) {
+          return;
+        }
+      }
     }
   } else {
     struct timeval start;
     struct timeval end;
@@ -545,6 +558,7 @@
     cudaEvent_t cu_start, cu_stop;
     cudaEventCreate(&cu_start);
     cudaEventCreate(&cu_stop);
+    bool local_completed = false;
     while (true) {
       // forward pass:
       int real_microbatch_num = 0;
@@ -563,6 +577,7 @@
           VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                   << local_batch_id_ << " batch_id_ " << batch_id_;
           if (threads_completed) {
+            local_completed = true;
             VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            VLOG(0) << "============timeline============";
@@ -742,6 +757,9 @@
                 << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
       }
       dev_ctx_->Wait();
+      if (local_completed) {
+        return;
+      }
     }
   }
 }
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 5ecda645474573fd3ff3926e2708ebe4652db47f..cb5e5c815052420cfb4afa0297fb26cff0c9c4ee 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -3772,30 +3772,31 @@ class PipelineOptimizer(object):
         return programs

-    #def _find_post_op(self, ops, cur_op, var_name):
-    #    """
-    #    Find the real post op that has variable named var_name as input.
-
-    #    Args:
-    #        ops (list): A list of ops.
-    #        cur_op (Operator): Current operator which has variable named
-    #            var_name as output.
-    #        var_name (string): Variable name.
-    #    """
-    #    post_op = []
-    #    before = True
-    #    for op in ops:
-    #        if op == cur_op:
-    #            before = False
-    #            continue
-    #        if before:
-    #            continue
-    #        for in_var_name in op.input_arg_names:
-    #            if in_var_name == var_name:
-    #                post_op.append(op)
-    #    if post_op:
-    #        return post_op[0]
-    #    return None
+    def _find_post_op(self, ops, cur_op, var_name):
+        """
+        Find the real post op that has variable named var_name as input.
+
+        Args:
+            ops (list): A list of ops.
+            cur_op (Operator): Current operator which has variable named
+                var_name as output.
+            var_name (string): Variable name.
+        """
+        post_op = []
+        before = True
+        for op in ops:
+            if op == cur_op:
+                before = False
+                continue
+            if before:
+                continue
+            for in_var_name in op.input_arg_names:
+                if in_var_name == var_name:
+                    post_op.append(op)
+                    break
+        if post_op:
+            return post_op[0]
+        return None

     def _find_real_prev_op(self, ops, cur_op, var_name):
         """
@@ -4009,12 +4010,8 @@ class PipelineOptimizer(object):
                 assert '@RENAME@' in name
                 assert len(op.desc.output_arg_names()) == 1
                 out_name = op.desc.output_arg_names()[0]
-                assert core.grad_var_suffix() in out_name
-                param_name = self._strip_grad_suffix(out_name)
-                assert param_name in self._param_device_map
-                device = self._param_device_map[param_name]
-                #post_op = self._find_post_op(block.ops, op, out_name)
-                #device = post_op.attr(self._op_device_key)
+                post_op = self._find_post_op(block.ops, op, out_name)
+                device = post_op.attr(self._op_device_key)
                 assert device
                 op._set_attr(self._op_device_key, device)
                 continue