Unverified commit fc7b5d22, authored by zhoujun and committed by GitHub

Merge pull request #1181 from WenmuZhou/dygraph_rc

Make training logs conform to the benchmark specification
@@ -65,5 +65,8 @@ class TrainingStats(object):
     def log(self, extras=None):
         d = self.get(extras)
-        strs = ', '.join(str(dict({x: y})).strip('{}') for x, y in d.items())
+        strs = []
+        for k, v in d.items():
+            strs.append('{}: {:x<6f}'.format(k, v))
+        strs = ', '.join(strs)
         return strs
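The rewritten `log` prints every metric as a fixed six-decimal float instead of relying on dict `repr`. A minimal sketch of what the new formatting produces (the metric names and values below are made up for illustration):

```python
# Illustrative only: mimics the new TrainingStats.log formatting.
# In the spec '{:x<6f}', 'x' is the fill char, '<' left-aligns, 6 is the
# minimum width, and 'f' defaults to six decimal places -- so the value is
# always at least 8 chars wide and the fill never actually appears.
d = {'loss': 0.4375, 'acc': 0.912345678, 'lr': 0.001}
strs = ', '.join('{}: {:x<6f}'.format(k, v) for k, v in d.items())
print(strs)  # loss: 0.437500, acc: 0.912346, lr: 0.001000
```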
@@ -185,12 +185,15 @@ def train(config,
     for epoch in range(start_epoch, epoch_num):
         if epoch > 0:
             train_dataloader = build_dataloader(config, 'Train', device, logger)
+        train_batch_cost = 0.0
+        train_reader_cost = 0.0
+        batch_sum = 0
+        batch_start = time.time()
         for idx, batch in enumerate(train_dataloader):
+            train_reader_cost += time.time() - batch_start
             if idx >= len(train_dataloader):
                 break
             lr = optimizer.get_lr()
-            t1 = time.time()
             images = batch[0]
             preds = model(images)
             loss = loss_class(preds, batch)
@@ -198,6 +201,10 @@ def train(config,
             avg_loss.backward()
             optimizer.step()
             optimizer.clear_grad()
+            train_batch_cost += time.time() - batch_start
+            batch_sum += len(images)
             if not isinstance(lr_scheduler, float):
                 lr_scheduler.step()
@@ -213,9 +220,6 @@ def train(config,
             metirc = eval_class.get_metric()
             train_stats.update(metirc)
-            t2 = time.time()
-            train_batch_elapse = t2 - t1
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
@@ -224,9 +228,15 @@ def train(config,
             if dist.get_rank(
             ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
-                    epoch, epoch_num, global_step, logs, train_batch_elapse)
+                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+                    epoch, epoch_num, global_step, logs, train_reader_cost /
+                    print_batch_step, train_batch_cost / print_batch_step,
+                    batch_sum, batch_sum / train_batch_cost)
                 logger.info(strs)
+                train_batch_cost = 0.0
+                train_reader_cost = 0.0
+                batch_sum = 0
+                batch_start = time.time()
             # eval
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
......
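The benchmark log convention this PR adopts splits per-iteration time into `reader_cost` (time blocked waiting on the dataloader) and `batch_cost` (total iteration time), and reports throughput as `ips` (samples per second over the logging window). Below is a minimal, self-contained sketch of the same accumulate-and-reset pattern; `dataloader` and `train_step` are stand-ins for illustration, not PaddleOCR APIs, and unlike the diff above it refreshes `batch_start` every iteration so `reader_cost` covers only dataloader wait time:

```python
import time

def train_loop(dataloader, train_step, print_batch_step=10):
    """Toy loop showing the reader_cost / batch_cost / ips bookkeeping."""
    train_batch_cost = 0.0   # accumulated full-iteration time
    train_reader_cost = 0.0  # accumulated time spent fetching data
    batch_sum = 0            # samples seen since the last log line
    batch_start = time.time()
    for idx, batch in enumerate(dataloader):
        # Everything between batch_start and here was spent in the reader.
        train_reader_cost += time.time() - batch_start
        train_step(batch)  # forward / backward / optimizer step
        train_batch_cost += time.time() - batch_start
        batch_sum += len(batch)
        if (idx + 1) % print_batch_step == 0:
            print('reader_cost: {:.5f} s, batch_cost: {:.5f} s, '
                  'samples: {}, ips: {:.5f}'.format(
                      train_reader_cost / print_batch_step,
                      train_batch_cost / print_batch_step,
                      batch_sum, batch_sum / train_batch_cost))
            # Reset the window so each log line covers print_batch_step iters.
            train_batch_cost = 0.0
            train_reader_cost = 0.0
            batch_sum = 0
        batch_start = time.time()
```

Dividing the accumulated costs by `print_batch_step` gives per-iteration averages, while `batch_sum / train_batch_cost` gives the average samples-per-second over the same window.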