未验证 提交 ba9a787d 编写于 作者: W wanghuancoder 提交者: GitHub

fix resnet usetime bug (#4869)

上级 e07327e8
......@@ -34,6 +34,24 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TimeCostAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.cnt = 0
self.total_time = 0
def record(self, usetime):
self.cnt += 1
self.total_time += usetime
def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt
def build_program(is_train, main_prog, startup_prog, args):
"""build program, and add backward op in program accroding to different mode
......@@ -225,7 +243,11 @@ def train(args):
compiled_train_prog = best_strategy_compiled(args, train_prog,
train_fetch_vars[0], exe)
batch_cost_avg = TimeCostAverage()
#NOTE: this for benchmark
total_batch_num = 0
for pass_id in range(args.num_epochs):
if num_trainers > 1 and not args.use_dali:
......@@ -234,7 +256,6 @@ def train(args):
train_batch_id = 0
train_batch_time_record = []
train_batch_metrics_record = []
train_batch_time_print_step = []
if not args.use_dali:
train_iter = train_data_loader()
......@@ -252,25 +273,18 @@ def train(args):
t2 = time.time()
train_batch_elapse = t2 - t1
train_batch_time_record.append(train_batch_elapse)
batch_cost_avg.record(train_batch_elapse)
train_batch_metrics_avg = np.mean(
np.array(train_batch_metrics), axis=1)
train_batch_metrics_record.append(train_batch_metrics_avg)
if trainer_id == 0:
if train_batch_id % args.print_step == 0:
if len(train_batch_time_print_step) == 0:
train_batch_time_print_step_avg = train_batch_elapse
else:
train_batch_time_print_step_avg = np.mean(
train_batch_time_print_step)
train_batch_time_print_step = []
print_info("batch", train_batch_metrics_avg,
train_batch_time_print_step_avg, pass_id,
train_batch_id, args.print_step)
else:
train_batch_time_print_step.append(train_batch_elapse)
print_info("batch", train_batch_metrics_avg,
batch_cost_avg.get_average(), pass_id,
train_batch_id, args.print_step)
sys.stdout.flush()
if train_batch_id % args.print_step == 0:
batch_cost_avg.reset()
train_batch_id += 1
t1 = time.time()
#NOTE: this for benchmark profiler
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册