From 5e9fc1ffc45821d1344048006f06ab96276d91e6 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 20 Dec 2021 22:59:18 +0800 Subject: [PATCH] fix reader of print_flops (#4960) --- ppdet/engine/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index db29f03ea..671d373bb 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -364,7 +364,9 @@ class Trainer(object): self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) if self.cfg.get('print_flops', False): - self._flops(self.loader) + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num) + self._flops(flops_loader) profiler_options = self.cfg.get('profiler_options', None) self._compose_callback.on_train_begin(self.status) @@ -463,7 +465,9 @@ class Trainer(object): self.status['mode'] = 'eval' self.model.eval() if self.cfg.get('print_flops', False): - self._flops(loader) + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num, self._eval_batch_sampler) + self._flops(flops_loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) @@ -510,7 +514,8 @@ class Trainer(object): self.status['mode'] = 'test' self.model.eval() if self.cfg.get('print_flops', False): - self._flops(loader) + flops_loader = create('TestReader')(self.dataset, 0) + self._flops(flops_loader) results = [] for step_id, data in enumerate(loader): self.status['step_id'] = step_id -- GitLab