提交 0360e583 编写于 作者: S sandyhouse

bug fix, test=develop

上级 63b58dc2
......@@ -194,10 +194,17 @@ void SectionWorker::TrainFiles() {
dev_ctx_->Wait();
batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
{
std::unique_lock<std::mutex> lk(thread_mutex);
if (threads_completed) {
return;
}
}
}
} else {
while (true) {
// forward pass:
bool local_completed = false;
int real_microbatch_num = 0;
for (int i = 0; i < num_microbatches_; ++i) {
{
......@@ -217,6 +224,7 @@ void SectionWorker::TrainFiles() {
VLOG(3) << "thread " << thread_id_ << " completed.";
lk.unlock();
threads_completed = false;
local_completed = true;
break;
}
lk.unlock();
......@@ -282,6 +290,9 @@ void SectionWorker::TrainFiles() {
}
}
dev_ctx_->Wait();
if (local_completed) {
return;
}
}
}
}
......@@ -479,6 +490,7 @@ void SectionWorker::TrainFilesWithProfiler() {
if (real_microbatch_num == 0) {
batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
return;
}
// update pass
int op_idx = 0;
......@@ -528,14 +540,15 @@ void SectionWorker::TrainFilesWithProfiler() {
<< "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
<< "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
}
struct timeval wait_start;
struct timeval wait_end;
gettimeofday(&wait_start, NULL);
dev_ctx_->Wait();
gettimeofday(&wait_end, NULL);
VLOG(0) << "device wait: " << wait_end.tv_sec * 1e6 + wait_end.tv_usec - wait_start.tv_sec * 1e6 - wait_start.tv_usec;
batch_timer.Pause();
VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
{
std::unique_lock<std::mutex> lk(thread_mutex);
if (threads_completed) {
return;
}
}
}
} else {
struct timeval start;
......@@ -545,6 +558,7 @@ void SectionWorker::TrainFilesWithProfiler() {
cudaEvent_t cu_start, cu_stop;
cudaEventCreate(&cu_start);
cudaEventCreate(&cu_stop);
bool local_completed = false;
while (true) {
// forward pass:
int real_microbatch_num = 0;
......@@ -563,6 +577,7 @@ void SectionWorker::TrainFilesWithProfiler() {
VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
<< local_batch_id_ << " batch_id_ " << batch_id_;
if (threads_completed) {
local_completed = true;
VLOG(3) << "thread " << thread_id_ << " completed.";
lk.unlock();
VLOG(0) << "============timeline============";
......@@ -742,6 +757,9 @@ void SectionWorker::TrainFilesWithProfiler() {
<< "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
}
dev_ctx_->Wait();
if (local_completed) {
return;
}
}
}
}
......
......@@ -3772,30 +3772,31 @@ class PipelineOptimizer(object):
return programs
#def _find_post_op(self, ops, cur_op, var_name):
# """
# Find the real post op that has variable named var_name as input.
# Args:
# ops (list): A list of ops.
# cur_op (Operator): Current operator which has variable named
# var_name as output.
# var_name (string): Variable name.
# """
# post_op = []
# before = True
# for op in ops:
# if op == cur_op:
# before = False
# continue
# if before:
# continue
# for in_var_name in op.input_arg_names:
# if in_var_name == var_name:
# post_op.append(op)
# if post_op:
# return post_op[0]
# return None
def _find_post_op(self, ops, cur_op, var_name):
"""
Find the real post op that has variable named var_name as input.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has variable named
var_name as output.
var_name (string): Variable name.
"""
post_op = []
before = True
for op in ops:
if op == cur_op:
before = False
continue
if before:
continue
for in_var_name in op.input_arg_names:
if in_var_name == var_name:
post_op.append(op)
break
if post_op:
return post_op[0]
return None
def _find_real_prev_op(self, ops, cur_op, var_name):
"""
......@@ -4009,12 +4010,8 @@ class PipelineOptimizer(object):
assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
assert core.grad_var_suffix() in out_name
param_name = self._strip_grad_suffix(out_name)
assert param_name in self._param_device_map
device = self._param_device_map[param_name]
#post_op = self._find_post_op(block.ops, op, out_name)
#device = post_op.attr(self._op_device_key)
post_op = self._find_post_op(block.ops, op, out_name)
device = post_op.attr(self._op_device_key)
assert device
op._set_attr(self._op_device_key, device)
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册