def match_state_dict(model_state_dict, weight_state_dict):
    """
    Match between the model state dict and pretrained weight state dict.
    Return the matched state dict.

    The method supposes that all the names in the pretrained weight state
    dict are suffixes of the names in the model (after the 'backbone.'
    prefix of pretrained 'backbone.res5*' keys is removed). Every model key
    therefore has a set of candidate weight keys, and the candidate with the
    longest name is selected as the final match. For example, if the model
    state dict has the name 'backbone.res2.res2a.branch2a.conv.weight' and
    the pretrained weights have the names 'res2.res2a.branch2a.conv.weight'
    and 'branch2a.conv.weight', then 'res2.res2a.branch2a.conv.weight' is
    matched to the model key.

    Args:
        model_state_dict (dict): maps parameter names of the model to
            objects exposing a ``shape`` attribute.
        weight_state_dict (dict): maps parameter names of the pretrained
            weights to objects exposing a ``shape`` attribute.

    Returns:
        dict: maps model parameter names to the matched pretrained values.
        Model keys without a candidate, or whose best candidate has a
        different shape, are left out.

    Raises:
        ValueError: if one pretrained weight is selected for more than one
            model key.
    """

    backbone_prefix = 'backbone.'

    def match(model_key, weight_key):
        if weight_key.startswith('backbone.res5'):
            # In Faster RCNN, res5 pretrained weights have the prefix
            # 'backbone', however the corresponding model weights have a
            # different prefix, 'bbox_head'. Drop exactly the leading
            # 'backbone.'; str.strip() must not be used here because it
            # removes a *set of characters* from both ends and would
            # corrupt names such as '..._mean' or '..._variance'.
            weight_key = weight_key[len(backbone_prefix):]
        return model_key == weight_key or model_key.endswith("." +
                                                             weight_key)

    model_keys = sorted(model_state_dict.keys())
    weight_keys = sorted(weight_state_dict.keys())

    # weight key -> model key it has been assigned to (ambiguity detection)
    matched_keys = {}
    result_state_dict = {}
    for model_key in model_keys:
        candidates = [w_k for w_k in weight_keys if match(model_key, w_k)]
        if not candidates:
            continue
        # Longest name wins; max() keeps the first one on ties, i.e. the
        # smallest key in sorted order.
        weight_key = max(candidates, key=len)
        weight_value = weight_state_dict[weight_key]
        model_value_shape = list(model_state_dict[model_key].shape)

        if list(weight_value.shape) != model_value_shape:
            logger.info(
                'The shape {} in pretrained weight {} is unmatched with '
                'the shape {} in model {}. And the weight {} will not be '
                'loaded'.format(weight_value.shape, weight_key,
                                model_value_shape, model_key, weight_key))
            continue

        if weight_key in matched_keys:
            raise ValueError('Ambiguity weight {} loaded, it matches at least '
                             '{} and {} in the model'.format(
                                 weight_key, model_key, matched_keys[
                                     weight_key]))
        result_state_dict[model_key] = weight_value
        matched_keys[weight_key] = model_key
    return result_state_dict
- for k in param_state_dict.keys(): - if 'backbone.res5' in k: - new_k = k.replace('backbone', 'bbox_head.head') - if new_k in model_dict.keys(): - value = param_state_dict.pop(k) - param_state_dict[new_k] = value - - for name, weight in param_state_dict.items(): - if name in model_dict.keys(): - if list(weight.shape) != list(model_dict[name].shape): - logger.info( - '{} not used, shape {} unmatched with {} in model.'.format( - name, weight.shape, list(model_dict[name].shape))) - ignore_weights.add(name) - else: - logger.info('Redundant weight {} and ignore it.'.format(name)) - ignore_weights.add(name) - - for weight in ignore_weights: - param_state_dict.pop(weight, None) + param_state_dict = match_state_dict(model_dict, param_state_dict) model.set_dict(param_state_dict) logger.info('Finish loading model weights: {}'.format(weights_path))