From 82646ce05c3468e80c7e6453bd3de704897b7b0c Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 15 Dec 2021 14:15:52 +0800 Subject: [PATCH] fix checkpoint log when weight is not loaded (#4885) * fix checkpoint log when weight not loaded * fix loading pretrain weight in mask rcnn * discard unmatch id instead of remove --- ppdet/utils/checkpoint.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index b5aa84697..25d021a9b 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -124,7 +124,7 @@ def match_state_dict(model_state_dict, weight_state_dict): weight_keys = sorted(weight_state_dict.keys()) def match(a, b): - if a.startswith('backbone.res5'): + if b.startswith('backbone.res5'): # In Faster RCNN, res5 pretrained weights have prefix of backbone, # however, the corresponding model weights have difficult prefix, # bbox_head. @@ -139,10 +139,14 @@ def match_state_dict(model_state_dict, weight_state_dict): max_id = match_matrix.argmax(1) max_len = match_matrix.max(1) max_id[max_len == 0] = -1 + + load_id = set(max_id) + load_id.discard(-1) not_load_weight_name = [] - for match_idx in range(len(max_id)): - if match_idx < len(weight_keys) and max_id[match_idx] == -1: - not_load_weight_name.append(weight_keys[match_idx]) + for idx in range(len(weight_keys)): + if idx not in load_id: + not_load_weight_name.append(weight_keys[idx]) + if len(not_load_weight_name) > 0: logger.info('{} in pretrained weight is not used in the model, ' 'and its will not be loaded'.format(not_load_weight_name)) -- GitLab