未验证 提交 befec463 编写于 作者: G Guanghua Yu 提交者: GitHub

add flops (#3629)

* add flops
上级 1f6087ad
......@@ -2,3 +2,4 @@ use_gpu: true
log_iter: 20
save_dir: output
snapshot_epoch: 1
print_flops: false
......@@ -94,3 +94,7 @@ TestDataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
```
**Q:** 如何打印网络FLOPs?
**A:**`configs/runtime.yml`中设置`print_flops: true`,同时需要安装PaddleSlim(比如:pip install paddleslim),即可打印模型的FLOPs。
......@@ -325,6 +325,9 @@ class Trainer(object):
self.cfg.log_iter, fmt='{avg:.4f}')
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
if self.cfg.get('print_flops', False):
self._flops(self.loader)
for epoch_id in range(self.start_epoch, self.cfg.epoch):
self.status['mode'] = 'train'
self.status['epoch_id'] = epoch_id
......@@ -405,6 +408,8 @@ class Trainer(object):
self._compose_callback.on_epoch_begin(self.status)
self.status['mode'] = 'eval'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
......@@ -450,6 +455,8 @@ class Trainer(object):
# Run Infer
self.status['mode'] = 'test'
self.model.eval()
if self.cfg.get('print_flops', False):
self._flops(loader)
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
# forward
......@@ -587,3 +594,28 @@ class Trainer(object):
pass
paddle.disable_static()
return pruned_input_spec
def _flops(self, loader):
self.model.eval()
try:
import paddleslim
except Exception as e:
logger.warning(
'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
)
return
from paddleslim.analysis import dygraph_flops as flops
input_data = None
for data in loader:
input_data = data
break
input_spec = [{
"image": input_data['image'][0].unsqueeze(0),
"im_shape": input_data['im_shape'][0].unsqueeze(0),
"scale_factor": input_data['scale_factor'][0].unsqueeze(0)
}]
flops = flops(self.model, input_spec) / (1000**3)
logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
flops, input_data['image'][0].unsqueeze(0).shape))
......@@ -65,7 +65,7 @@ class Pruner(object):
if self.print_params:
print_prune_params(model)
ori_flops = flops(model, input_spec) / 1000
ori_flops = flops(model, input_spec) / (1000**3)
logger.info("FLOPs before pruning: {}GFLOPs".format(ori_flops))
if self.criterion == 'fpgm':
pruner = paddleslim.dygraph.FPGMFilterPruner(model, input_spec)
......@@ -78,7 +78,7 @@ class Pruner(object):
for i, param in enumerate(self.pruned_params):
ratios[param] = pruned_ratios[i]
pruner.prune_vars(ratios, [0])
pruned_flops = flops(model, input_spec) / 1000
pruned_flops = flops(model, input_spec) / (1000**3)
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
pruned_flops, (ori_flops - pruned_flops) / ori_flops))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册