未验证 提交 c182df29 编写于 作者: W wangguanzhong 提交者: GitHub

update pretrain match (#3244)

* updatee pretrain match

* update match for rcnn
上级 5ee9a605
......@@ -141,6 +141,69 @@ def load_weight(model, weight, optimizer=None):
return last_epoch
def match_state_dict(model_state_dict, weight_state_dict):
"""
Match between the model state dict and pretrained weight state dict.
Return the matched state dict.
The method supposes that all the names in pretrained weight state dict are
subclass of the names in models`, if the prefix 'backbone.' in pretrained weight
keys is stripped. And we could get the candidates for each model key. Then we
select the name with the longest matched size as the final match result. For
example, the model state dict has the name of
'backbone.res2.res2a.branch2a.conv.weight' and the pretrained weight as
name of 'res2.res2a.branch2a.conv.weight' and 'branch2a.conv.weight'. We
match the 'res2.res2a.branch2a.conv.weight' to the model key.
"""
model_keys = sorted(model_state_dict.keys())
weight_keys = sorted(weight_state_dict.keys())
def match(a, b):
if a.startswith('backbone.res5'):
# In Faster RCNN, res5 pretrained weights have prefix of backbone,
# however, the corresponding model weights have difficult prefix,
# bbox_head.
b = b.strip('backbone.')
return a == b or a.endswith("." + b)
match_matrix = np.zeros([len(model_keys), len(weight_keys)])
for i, m_k in enumerate(model_keys):
for j, w_k in enumerate(weight_keys):
if match(m_k, w_k):
match_matrix[i, j] = len(w_k)
max_id = match_matrix.argmax(1)
max_len = match_matrix.max(1)
max_id[max_len == 0] = -1
matched_keys = {}
result_state_dict = {}
for model_id, weight_id in enumerate(max_id):
if weight_id == -1:
continue
model_key = model_keys[model_id]
weight_key = weight_keys[weight_id]
weight_value = weight_state_dict[weight_key]
model_value_shape = list(model_state_dict[model_key].shape)
if list(weight_value.shape) != model_value_shape:
logger.info(
'The shape {} in pretrained weight {} is unmatched with '
'the shape {} in model {}. And the weight {} will not be '
'loaded'.format(weight_value.shape, weight_key,
model_value_shape, model_key, weight_key))
continue
assert model_key not in result_state_dict
result_state_dict[model_key] = weight_value
if weight_key in matched_keys:
raise ValueError('Ambiguity weight {} loaded, it matches at least '
'{} and {} in the model'.format(
weight_key, model_key, matched_keys[
weight_key]))
matched_keys[weight_key] = model_key
return result_state_dict
def load_pretrain_weight(model, pretrain_weight):
if is_url(pretrain_weight):
pretrain_weight = get_weights_path_dist(pretrain_weight)
......@@ -157,31 +220,7 @@ def load_pretrain_weight(model, pretrain_weight):
weights_path = path + '.pdparams'
param_state_dict = paddle.load(weights_path)
ignore_weights = set()
# hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
# while res5 module is located in bbox_head.head. Replace the prefix of
# res5 with 'bbox_head.head' to load pretrain weights correctly.
for k in param_state_dict.keys():
if 'backbone.res5' in k:
new_k = k.replace('backbone', 'bbox_head.head')
if new_k in model_dict.keys():
value = param_state_dict.pop(k)
param_state_dict[new_k] = value
for name, weight in param_state_dict.items():
if name in model_dict.keys():
if list(weight.shape) != list(model_dict[name].shape):
logger.info(
'{} not used, shape {} unmatched with {} in model.'.format(
name, weight.shape, list(model_dict[name].shape)))
ignore_weights.add(name)
else:
logger.info('Redundant weight {} and ignore it.'.format(name))
ignore_weights.add(name)
for weight in ignore_weights:
param_state_dict.pop(weight, None)
param_state_dict = match_state_dict(model_dict, param_state_dict)
model.set_dict(param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册