diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index b5aa8469734433e49e4ea0d44542fe27f1cd2fc1..25d021a9b1b4f2e10beb40e718a6f307b52c9a0f 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -124,7 +124,7 @@ def match_state_dict(model_state_dict, weight_state_dict): weight_keys = sorted(weight_state_dict.keys()) def match(a, b): - if a.startswith('backbone.res5'): + if b.startswith('backbone.res5'): # In Faster RCNN, res5 pretrained weights have prefix of backbone, # however, the corresponding model weights have difficult prefix, # bbox_head. @@ -139,10 +139,14 @@ def match_state_dict(model_state_dict, weight_state_dict): max_id = match_matrix.argmax(1) max_len = match_matrix.max(1) max_id[max_len == 0] = -1 + + load_id = set(max_id) + load_id.discard(-1) not_load_weight_name = [] - for match_idx in range(len(max_id)): - if match_idx < len(weight_keys) and max_id[match_idx] == -1: - not_load_weight_name.append(weight_keys[match_idx]) + for idx in range(len(weight_keys)): + if idx not in load_id: + not_load_weight_name.append(weight_keys[idx]) + if len(not_load_weight_name) > 0: logger.info('{} in pretrained weight is not used in the model, ' 'and its will not be loaded'.format(not_load_weight_name))