未验证 提交 5e9fc1ff 编写于 作者: G Guanghua Yu 提交者: GitHub

fix reader of print_flops (#4960)

上级 4b95e7a2
......@@ -364,7 +364,9 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
if self.cfg.get('print_flops', False):
self._flops(self.loader)
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num)
self._flops(flops_loader)
profiler_options = self.cfg.get('profiler_options', None)
self._compose_callback.on_train_begin(self.status)
......@@ -463,7 +465,9 @@ class Trainer(object):
self.status['mode'] = 'eval'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
......@@ -510,7 +514,8 @@ class Trainer(object):
self.status['mode'] = 'test'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = []
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册