diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index ed0433764ba6a13ec62f29298af536ddac3a3b6e..f3dafd40f0f8596c5ad0ac2ad0e99600b0134aab 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -155,7 +155,7 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'): return a == b or a.endswith("." + b) or b.endswith("." + a) def match(a, b): - if a.startswith('backbone.res5'): + if b.startswith('backbone.res5'): b = b[9:] return a == b or a.endswith("." + b) @@ -174,15 +174,28 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'): max_id = match_matrix.argmax(1) max_len = match_matrix.max(1) max_id[max_len == 0] = -1 + load_id = set(max_id) + load_id.discard(-1) not_load_weight_name = [] + if weight_keys[0].startswith('modelStudent') or weight_keys[0].startswith( + 'modelTeacher'): + for match_idx in range(len(max_id)): + if max_id[match_idx] == -1: + not_load_weight_name.append(model_keys[match_idx]) + if len(not_load_weight_name) > 0: + logger.info('{} in model is not matched with pretrained weights, ' + 'and its will be trained from scratch'.format( + not_load_weight_name)) - for match_idx in range(len(max_id)): - if max_id[match_idx] == -1: - not_load_weight_name.append(model_keys[match_idx]) - if len(not_load_weight_name) > 0: - logger.info('{} in model is not matched with pretrained weights, ' - 'and its will be trained from scratch'.format( - not_load_weight_name)) + else: + for idx in range(len(weight_keys)): + if idx not in load_id: + not_load_weight_name.append(weight_keys[idx]) + + if len(not_load_weight_name) > 0: + logger.info('{} in pretrained weight is not used in the model, ' + 'and its will not be loaded'.format( + not_load_weight_name)) matched_keys = {} result_state_dict = {} for model_id, weight_id in enumerate(max_id):