未验证 提交 1f8f8539 编写于 作者: G Guanghua Yu 提交者: GitHub

fix reader of print_flops (#4959)

上级 ea2f81d8
...@@ -370,7 +370,9 @@ class Trainer(object): ...@@ -370,7 +370,9 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(self.loader) flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num)
self._flops(flops_loader)
profiler_options = self.cfg.get('profiler_options', None) profiler_options = self.cfg.get('profiler_options', None)
self._compose_callback.on_train_begin(self.status) self._compose_callback.on_train_begin(self.status)
...@@ -469,7 +471,9 @@ class Trainer(object): ...@@ -469,7 +471,9 @@ class Trainer(object):
self.status['mode'] = 'eval' self.status['mode'] = 'eval'
self.model.eval() self.model.eval()
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(loader) flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
for step_id, data in enumerate(loader): for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
...@@ -520,7 +524,8 @@ class Trainer(object): ...@@ -520,7 +524,8 @@ class Trainer(object):
self.status['mode'] = 'test' self.status['mode'] = 'test'
self.model.eval() self.model.eval()
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(loader) flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = [] results = []
for step_id, data in enumerate(loader): for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册