未验证 提交 d3367e13 编写于 作者: W wangguanzhong 提交者: GitHub

simplify log of loading weights (#2674)

上级 e3049ab2
...@@ -188,7 +188,7 @@ The relationship between COCO mAP and FPS on Tesla V100 of representative models ...@@ -188,7 +188,7 @@ The relationship between COCO mAP and FPS on Tesla V100 of representative models
- `PP-YOLO` achieves mAP of 45.9% on COCO and 72.9FPS on Tesla V100. Both precision and speed surpass [YOLOv4](https://arxiv.org/abs/2004.10934) - `PP-YOLO` achieves mAP of 45.9% on COCO and 72.9FPS on Tesla V100. Both precision and speed surpass [YOLOv4](https://arxiv.org/abs/2004.10934)
- `PP-YOLO v2` is optimized version of `PP-YOLO` which has mAP of 49.5% and 60FPS on Tesla V100 - `PP-YOLO v2` is optimized version of `PP-YOLO` which has mAP of 49.5% and 68.9FPS on Tesla V100
- All these models can be get in [Model Zoo](#ModelZoo) - All these models can be get in [Model Zoo](#ModelZoo)
......
...@@ -157,28 +157,21 @@ def load_pretrain_weight(model, pretrain_weight): ...@@ -157,28 +157,21 @@ def load_pretrain_weight(model, pretrain_weight):
weights_path = path + '.pdparams' weights_path = path + '.pdparams'
param_state_dict = paddle.load(weights_path) param_state_dict = paddle.load(weights_path)
lack_backbone_weights_cnt = 0 ignore_weights = set()
lack_modules = set()
for name, weight in model_dict.items(): for name, weight in param_state_dict.items():
if name in param_state_dict.keys(): if name in model_dict.keys():
if weight.shape != list(param_state_dict[name].shape): if list(weight.shape) != list(model_dict[name].shape):
logger.info( logger.info(
'{} not used, shape {} unmatched with {} in model.'.format( '{} not used, shape {} unmatched with {} in model.'.format(
name, list(param_state_dict[name].shape), weight.shape)) name, weight.shape, list(model_dict[name].shape)))
param_state_dict.pop(name, None) ignore_weights.add(name)
else: else:
lack_modules.add(name.split('.')[0]) logger.info('Redundant weight {} and ignore it.'.format(name))
if name.find('backbone') >= 0: ignore_weights.add(name)
logger.info('Lack backbone weights: {}'.format(name))
lack_backbone_weights_cnt += 1 for weight in ignore_weights:
param_state_dict.pop(weight, None)
if lack_backbone_weights_cnt > 0:
logger.info('Lack {} weights in backbone.'.format(
lack_backbone_weights_cnt))
if len(lack_modules) > 0:
logger.info('Lack weights of modules: {}'.format(', '.join(
list(lack_modules))))
model.set_dict(param_state_dict) model.set_dict(param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path)) logger.info('Finish loading model weights: {}'.format(weights_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册