未验证 提交 1f8f8539 编写于 作者: G Guanghua Yu 提交者: GitHub

fix reader of print_flops (#4959)

上级 ea2f81d8
......@@ -370,7 +370,9 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
if self.cfg.get('print_flops', False):
self._flops(self.loader)
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num)
self._flops(flops_loader)
profiler_options = self.cfg.get('profiler_options', None)
self._compose_callback.on_train_begin(self.status)
......@@ -469,7 +471,9 @@ class Trainer(object):
self.status['mode'] = 'eval'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
......@@ -520,7 +524,8 @@ class Trainer(object):
self.status['mode'] = 'test'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = []
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册