未验证 提交 fc7b5d22 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #1181 from WenmuZhou/dygraph_rc

日志符合benckmark规范
......@@ -65,5 +65,8 @@ class TrainingStats(object):
def log(self, extras=None):
d = self.get(extras)
strs = ', '.join(str(dict({x: y})).strip('{}') for x, y in d.items())
strs = []
for k, v in d.items():
strs.append('{}: {:x<6f}'.format(k, v))
strs = ', '.join(strs)
return strs
......@@ -185,12 +185,15 @@ def train(config,
for epoch in range(start_epoch, epoch_num):
if epoch > 0:
train_dataloader = build_dataloader(config, 'Train', device, logger)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
for idx, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - batch_start
if idx >= len(train_dataloader):
break
lr = optimizer.get_lr()
t1 = time.time()
images = batch[0]
preds = model(images)
loss = loss_class(preds, batch)
......@@ -198,6 +201,10 @@ def train(config,
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
train_batch_cost += time.time() - batch_start
batch_sum += len(images)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
......@@ -213,9 +220,6 @@ def train(config,
metirc = eval_class.get_metric()
train_stats.update(metirc)
t2 = time.time()
train_batch_elapse = t2 - t1
if vdl_writer is not None and dist.get_rank() == 0:
for k, v in train_stats.get().items():
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
......@@ -224,9 +228,15 @@ def train(config,
if dist.get_rank(
) == 0 and global_step > 0 and global_step % print_batch_step == 0:
logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
epoch, epoch_num, global_step, logs, train_batch_elapse)
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, train_batch_cost / print_batch_step,
batch_sum, batch_sum / train_batch_cost)
logger.info(strs)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册