提交 0360e583 编写于 作者: S sandyhouse

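bug fix: SectionWorker train loops now return once threads_completed / local_completed is observed, and PipelineOptimizer restores _find_post_op to pick the op device, test=develop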

上级 63b58dc2
...@@ -194,10 +194,17 @@ void SectionWorker::TrainFiles() { ...@@ -194,10 +194,17 @@ void SectionWorker::TrainFiles() {
dev_ctx_->Wait(); dev_ctx_->Wait();
batch_timer.Pause(); batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
{
std::unique_lock<std::mutex> lk(thread_mutex);
if (threads_completed) {
return;
}
}
} }
} else { } else {
while (true) { while (true) {
// forward pass: // forward pass:
bool local_completed = false;
int real_microbatch_num = 0; int real_microbatch_num = 0;
for (int i = 0; i < num_microbatches_; ++i) { for (int i = 0; i < num_microbatches_; ++i) {
{ {
...@@ -217,6 +224,7 @@ void SectionWorker::TrainFiles() { ...@@ -217,6 +224,7 @@ void SectionWorker::TrainFiles() {
VLOG(3) << "thread " << thread_id_ << " completed."; VLOG(3) << "thread " << thread_id_ << " completed.";
lk.unlock(); lk.unlock();
threads_completed = false; threads_completed = false;
local_completed = true;
break; break;
} }
lk.unlock(); lk.unlock();
...@@ -282,6 +290,9 @@ void SectionWorker::TrainFiles() { ...@@ -282,6 +290,9 @@ void SectionWorker::TrainFiles() {
} }
} }
dev_ctx_->Wait(); dev_ctx_->Wait();
if (local_completed) {
return;
}
} }
} }
} }
...@@ -479,6 +490,7 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -479,6 +490,7 @@ void SectionWorker::TrainFilesWithProfiler() {
if (real_microbatch_num == 0) { if (real_microbatch_num == 0) {
batch_timer.Pause(); batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
return;
} }
// update pass // update pass
int op_idx = 0; int op_idx = 0;
...@@ -528,14 +540,15 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -528,14 +540,15 @@ void SectionWorker::TrainFilesWithProfiler() {
<< "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
<< "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
} }
struct timeval wait_start;
struct timeval wait_end;
gettimeofday(&wait_start, NULL);
dev_ctx_->Wait(); dev_ctx_->Wait();
gettimeofday(&wait_end, NULL);
VLOG(0) << "device wait: " << wait_end.tv_sec * 1e6 + wait_end.tv_usec - wait_start.tv_sec * 1e6 - wait_start.tv_usec;
batch_timer.Pause(); batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
{
std::unique_lock<std::mutex> lk(thread_mutex);
if (threads_completed) {
return;
}
}
} }
} else { } else {
struct timeval start; struct timeval start;
...@@ -545,6 +558,7 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -545,6 +558,7 @@ void SectionWorker::TrainFilesWithProfiler() {
cudaEvent_t cu_start, cu_stop; cudaEvent_t cu_start, cu_stop;
cudaEventCreate(&cu_start); cudaEventCreate(&cu_start);
cudaEventCreate(&cu_stop); cudaEventCreate(&cu_stop);
bool local_completed = false;
while (true) { while (true) {
// forward pass: // forward pass:
int real_microbatch_num = 0; int real_microbatch_num = 0;
...@@ -563,6 +577,7 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -563,6 +577,7 @@ void SectionWorker::TrainFilesWithProfiler() {
VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
<< local_batch_id_ << " batch_id_ " << batch_id_; << local_batch_id_ << " batch_id_ " << batch_id_;
if (threads_completed) { if (threads_completed) {
local_completed = true;
VLOG(3) << "thread " << thread_id_ << " completed."; VLOG(3) << "thread " << thread_id_ << " completed.";
lk.unlock(); lk.unlock();
VLOG(0) << "============timeline============"; VLOG(0) << "============timeline============";
...@@ -742,6 +757,9 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -742,6 +757,9 @@ void SectionWorker::TrainFilesWithProfiler() {
<< "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl; << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
} }
dev_ctx_->Wait(); dev_ctx_->Wait();
if (local_completed) {
return;
}
} }
} }
} }
......
...@@ -3772,30 +3772,31 @@ class PipelineOptimizer(object): ...@@ -3772,30 +3772,31 @@ class PipelineOptimizer(object):
return programs return programs
#def _find_post_op(self, ops, cur_op, var_name): def _find_post_op(self, ops, cur_op, var_name):
# """ """
# Find the real post op that has variable named var_name as input. Find the real post op that has variable named var_name as input.
# Args: Args:
# ops (list): A list of ops. ops (list): A list of ops.
# cur_op (Operator): Current operator which has variable named cur_op (Operator): Current operator which has variable named
# var_name as output. var_name as output.
# var_name (string): Variable name. var_name (string): Variable name.
# """ """
# post_op = [] post_op = []
# before = True before = True
# for op in ops: for op in ops:
# if op == cur_op: if op == cur_op:
# before = False before = False
# continue continue
# if before: if before:
# continue continue
# for in_var_name in op.input_arg_names: for in_var_name in op.input_arg_names:
# if in_var_name == var_name: if in_var_name == var_name:
# post_op.append(op) post_op.append(op)
# if post_op: break
# return post_op[0] if post_op:
# return None return post_op[0]
return None
def _find_real_prev_op(self, ops, cur_op, var_name): def _find_real_prev_op(self, ops, cur_op, var_name):
""" """
...@@ -4009,12 +4010,8 @@ class PipelineOptimizer(object): ...@@ -4009,12 +4010,8 @@ class PipelineOptimizer(object):
assert '@RENAME@' in name assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1 assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0] out_name = op.desc.output_arg_names()[0]
assert core.grad_var_suffix() in out_name post_op = self._find_post_op(block.ops, op, out_name)
param_name = self._strip_grad_suffix(out_name) device = post_op.attr(self._op_device_key)
assert param_name in self._param_device_map
device = self._param_device_map[param_name]
#post_op = self._find_post_op(block.ops, op, out_name)
#device = post_op.attr(self._op_device_key)
assert device assert device
op._set_attr(self._op_device_key, device) op._set_attr(self._op_device_key, device)
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册