未验证 提交 5e9fc1ff 编写于 作者: G Guanghua Yu 提交者: GitHub

fix reader of print_flops (#4960)

上级 4b95e7a2
...@@ -364,7 +364,9 @@ class Trainer(object): ...@@ -364,7 +364,9 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(self.loader) flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num)
self._flops(flops_loader)
profiler_options = self.cfg.get('profiler_options', None) profiler_options = self.cfg.get('profiler_options', None)
self._compose_callback.on_train_begin(self.status) self._compose_callback.on_train_begin(self.status)
...@@ -463,7 +465,9 @@ class Trainer(object): ...@@ -463,7 +465,9 @@ class Trainer(object):
self.status['mode'] = 'eval' self.status['mode'] = 'eval'
self.model.eval() self.model.eval()
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(loader) flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
self._flops(flops_loader)
for step_id, data in enumerate(loader): for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
...@@ -510,7 +514,8 @@ class Trainer(object): ...@@ -510,7 +514,8 @@ class Trainer(object):
self.status['mode'] = 'test' self.status['mode'] = 'test'
self.model.eval() self.model.eval()
if self.cfg.get('print_flops', False): if self.cfg.get('print_flops', False):
self._flops(loader) flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = [] results = []
for step_id, data in enumerate(loader): for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册