From befec46347fd5d6698a7e99cbfb5cb2902e1055c Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 8 Jul 2021 15:02:22 +0800 Subject: [PATCH] add flops (#3629) * add flops --- configs/runtime.yml | 1 + docs/tutorials/FAQ.md | 4 ++++ ppdet/engine/trainer.py | 32 ++++++++++++++++++++++++++++++++ ppdet/slim/prune.py | 4 ++-- requirements.txt | 2 +- 5 files changed, 40 insertions(+), 3 deletions(-) diff --git a/configs/runtime.yml b/configs/runtime.yml index deac1cb05..c502ddabe 100644 --- a/configs/runtime.yml +++ b/configs/runtime.yml @@ -2,3 +2,4 @@ use_gpu: true log_iter: 20 save_dir: output snapshot_epoch: 1 +print_flops: false diff --git a/docs/tutorials/FAQ.md b/docs/tutorials/FAQ.md index a9b289343..0994a727d 100644 --- a/docs/tutorials/FAQ.md +++ b/docs/tutorials/FAQ.md @@ -94,3 +94,7 @@ TestDataset: !ImageFolder anno_path: annotations/instances_val2017.json ``` + +**Q:** 如何打印网络FLOPs? + +**A:** 在`configs/runtime.yml`中设置`print_flops: true`,同时需要安装PaddleSlim(比如:pip install paddleslim),即可打印模型的FLOPs。 diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 0cf2ae683..53e82792a 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -325,6 +325,9 @@ class Trainer(object): self.cfg.log_iter, fmt='{avg:.4f}') self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) + if self.cfg.get('print_flops', False): + self._flops(self.loader) + for epoch_id in range(self.start_epoch, self.cfg.epoch): self.status['mode'] = 'train' self.status['epoch_id'] = epoch_id @@ -405,6 +408,8 @@ class Trainer(object): self._compose_callback.on_epoch_begin(self.status) self.status['mode'] = 'eval' self.model.eval() + if self.cfg.get('print_flops', False): + self._flops(loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) @@ -450,6 +455,8 @@ class Trainer(object): # Run Infer self.status['mode'] = 'test' self.model.eval() + if self.cfg.get('print_flops', False): + self._flops(loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id # forward @@ -587,3 +594,28 @@ class Trainer(object): pass paddle.disable_static() return pruned_input_spec + + def _flops(self, loader): + self.model.eval() + try: + import paddleslim + except Exception as e: + logger.warning( + 'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`' + ) + return + + from paddleslim.analysis import dygraph_flops as flops + input_data = None + for data in loader: + input_data = data + break + + input_spec = [{ + "image": input_data['image'][0].unsqueeze(0), + "im_shape": input_data['im_shape'][0].unsqueeze(0), + "scale_factor": input_data['scale_factor'][0].unsqueeze(0) + }] + flops = flops(self.model, input_spec) / (1000**3) + logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format( + flops, input_data['image'][0].unsqueeze(0).shape)) diff --git a/ppdet/slim/prune.py b/ppdet/slim/prune.py index 2d01e30a8..70d3de369 100644 --- a/ppdet/slim/prune.py +++ b/ppdet/slim/prune.py @@ -65,7 +65,7 @@ class Pruner(object): if self.print_params: print_prune_params(model) - ori_flops = flops(model, input_spec) / 1000 + ori_flops = flops(model, input_spec) / (1000**3) logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops)) if self.criterion == 'fpgm': pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec) @@ -78,7 +78,7 @@ class Pruner(object): for i, param in enumerate(self.pruned_params): ratios[param] = pruned_ratios[i] pruner.prune_vars(ratios, [0]) - pruned_flops = flops(model, input_spec) / 1000 + pruned_flops = flops(model, input_spec) / (1000**3) logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format( pruned_flops, (ori_flops - pruned_flops) / ori_flops)) diff --git a/requirements.txt b/requirements.txt index d8864d8da..7ac38c2e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ lap sklearn motmetrics openpyxl -decord \ No newline at end of file +decord -- GitLab