diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index db29f03ea823bd1e9b5e299d687bcbe4133a5b3d..671d373bb13c33d7aa4d4734ad606d62035d7dd8 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -364,7 +364,9 @@ class Trainer(object): self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) if self.cfg.get('print_flops', False): - self._flops(self.loader) + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num) + self._flops(flops_loader) profiler_options = self.cfg.get('profiler_options', None) self._compose_callback.on_train_begin(self.status) @@ -463,7 +465,9 @@ class Trainer(object): self.status['mode'] = 'eval' self.model.eval() if self.cfg.get('print_flops', False): - self._flops(loader) + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num, self._eval_batch_sampler) + self._flops(flops_loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) @@ -510,7 +514,8 @@ class Trainer(object): self.status['mode'] = 'test' self.model.eval() if self.cfg.get('print_flops', False): - self._flops(loader) + flops_loader = create('TestReader')(self.dataset, 0) + self._flops(flops_loader) results = [] for step_id, data in enumerate(loader): self.status['step_id'] = step_id