From d3367e1369c7ca6296a18c8eeda89ada5162a483 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Sat, 17 Apr 2021 00:43:05 +0800 Subject: [PATCH] simplify log of loading weights (#2674) --- README_en.md | 2 +- ppdet/utils/checkpoint.py | 31 ++++++++++++------------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/README_en.md b/README_en.md index 72fd84c37..301bfa323 100644 --- a/README_en.md +++ b/README_en.md @@ -188,7 +188,7 @@ The relationship between COCO mAP and FPS on Tesla V100 of representative models - `PP-YOLO` achieves mAP of 45.9% on COCO and 72.9FPS on Tesla V100. Both precision and speed surpass [YOLOv4](https://arxiv.org/abs/2004.10934) -- `PP-YOLO v2` is optimized version of `PP-YOLO` which has mAP of 49.5% and 60FPS on Tesla V100 +- `PP-YOLO v2` is optimized version of `PP-YOLO` which has mAP of 49.5% and 68.9FPS on Tesla V100 - All these models can be get in [Model Zoo](#ModelZoo) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 998a0747c..d4f08097f 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -157,28 +157,21 @@ def load_pretrain_weight(model, pretrain_weight): weights_path = path + '.pdparams' param_state_dict = paddle.load(weights_path) - lack_backbone_weights_cnt = 0 - lack_modules = set() - for name, weight in model_dict.items(): - if name in param_state_dict.keys(): - if weight.shape != list(param_state_dict[name].shape): + ignore_weights = set() + + for name, weight in param_state_dict.items(): + if name in model_dict.keys(): + if list(weight.shape) != list(model_dict[name].shape): logger.info( '{} not used, shape {} unmatched with {} in model.'.format( - name, list(param_state_dict[name].shape), weight.shape)) - param_state_dict.pop(name, None) + name, weight.shape, list(model_dict[name].shape))) + ignore_weights.add(name) else: - lack_modules.add(name.split('.')[0]) - if name.find('backbone') >= 0: - logger.info('Lack backbone weights: {}'.format(name)) - lack_backbone_weights_cnt += 1 - - if lack_backbone_weights_cnt > 0: - logger.info('Lack {} weights in backbone.'.format( - lack_backbone_weights_cnt)) - - if len(lack_modules) > 0: - logger.info('Lack weights of modules: {}'.format(', '.join( - list(lack_modules)))) + logger.info('Redundant weight {} and ignore it.'.format(name)) + ignore_weights.add(name) + + for weight in ignore_weights: + param_state_dict.pop(weight, None) model.set_dict(param_state_dict) logger.info('Finish loading model weights: {}'.format(weights_path)) -- GitLab