diff --git a/configs/runtime.yml b/configs/runtime.yml index deac1cb05f9a8b3fd07b370665b8b64f60394016..c502ddabeb93d95a850fe7cb83a1e68ceff3a4e4 100644 --- a/configs/runtime.yml +++ b/configs/runtime.yml @@ -2,3 +2,4 @@ use_gpu: true log_iter: 20 save_dir: output snapshot_epoch: 1 +print_flops: false diff --git a/docs/tutorials/FAQ.md b/docs/tutorials/FAQ.md index a9b289343c4647e863dd4948d9c1329a21c0c87d..0994a727da715f1ee3b472f823de7df1e107ea80 100644 --- a/docs/tutorials/FAQ.md +++ b/docs/tutorials/FAQ.md @@ -94,3 +94,7 @@ TestDataset: !ImageFolder anno_path: annotations/instances_val2017.json ``` + +**Q:** 如何打印网络FLOPs? + +**A:** 在`configs/runtime.yml`中设置`print_flops: true`,同时需要安装PaddleSlim(比如:pip install paddleslim),即可打印模型的FLOPs。 diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 0cf2ae683a42e5e64587aa1b2261b009636eb77c..53e82792ac2dc6e4558dcda7ba827cc675152283 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -325,6 +325,9 @@ class Trainer(object): self.cfg.log_iter, fmt='{avg:.4f}') self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) + if self.cfg.get('print_flops', False): + self._flops(self.loader) + for epoch_id in range(self.start_epoch, self.cfg.epoch): self.status['mode'] = 'train' self.status['epoch_id'] = epoch_id @@ -405,6 +408,8 @@ class Trainer(object): self._compose_callback.on_epoch_begin(self.status) self.status['mode'] = 'eval' self.model.eval() + if self.cfg.get('print_flops', False): + self._flops(loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) @@ -450,6 +455,8 @@ class Trainer(object): # Run Infer self.status['mode'] = 'test' self.model.eval() + if self.cfg.get('print_flops', False): + self._flops(loader) for step_id, data in enumerate(loader): self.status['step_id'] = step_id # forward @@ -587,3 +594,28 @@ class Trainer(object): pass paddle.disable_static() return pruned_input_spec + + def _flops(self, loader): + self.model.eval() + try: + import paddleslim + except Exception as e: + logger.warning( + 'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`' + ) + return + + from paddleslim.analysis import dygraph_flops as flops + input_data = None + for data in loader: + input_data = data + break + + input_spec = [{ + "image": input_data['image'][0].unsqueeze(0), + "im_shape": input_data['im_shape'][0].unsqueeze(0), + "scale_factor": input_data['scale_factor'][0].unsqueeze(0) + }] + flops = flops(self.model, input_spec) / (1000**3) + logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format( + flops, input_data['image'][0].unsqueeze(0).shape)) diff --git a/ppdet/slim/prune.py b/ppdet/slim/prune.py index 2d01e30a8722ffd4844c470f1e7b346040aa69eb..70d3de3692707b132ff398babf5a795a5a9e81ba 100644 --- a/ppdet/slim/prune.py +++ b/ppdet/slim/prune.py @@ -65,7 +65,7 @@ class Pruner(object): if self.print_params: print_prune_params(model) - ori_flops = flops(model, input_spec) / 1000 + ori_flops = flops(model, input_spec) / (1000**3) logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops)) if self.criterion == 'fpgm': pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec) @@ -78,7 +78,7 @@ class Pruner(object): for i, param in enumerate(self.pruned_params): ratios[param] = pruned_ratios[i] pruner.prune_vars(ratios, [0]) - pruned_flops = flops(model, input_spec) / 1000 + pruned_flops = flops(model, input_spec) / (1000**3) logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format( pruned_flops, (ori_flops - pruned_flops) / ori_flops)) diff --git a/requirements.txt b/requirements.txt index d8864d8da3bf7f922af4f942b379752845a9b0c6..7ac38c2e21ca0faef77760f388e8ab01faaa40cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ lap sklearn motmetrics openpyxl -decord \ No newline at end of file +decord