未验证 提交 78c6b82f 编写于 作者: W wjm 提交者: GitHub

Fix bug in load weight (#8097)

* fix_weight_load

* fix_load_weight
上级 4ee5e38b
...@@ -155,7 +155,7 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'): ...@@ -155,7 +155,7 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
return a == b or a.endswith("." + b) or b.endswith("." + a) return a == b or a.endswith("." + b) or b.endswith("." + a)
def match(a, b): def match(a, b):
if a.startswith('backbone.res5'): if b.startswith('backbone.res5'):
b = b[9:] b = b[9:]
return a == b or a.endswith("." + b) return a == b or a.endswith("." + b)
...@@ -174,15 +174,28 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'): ...@@ -174,15 +174,28 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
max_id = match_matrix.argmax(1) max_id = match_matrix.argmax(1)
max_len = match_matrix.max(1) max_len = match_matrix.max(1)
max_id[max_len == 0] = -1 max_id[max_len == 0] = -1
load_id = set(max_id)
load_id.discard(-1)
not_load_weight_name = [] not_load_weight_name = []
if weight_keys[0].startswith('modelStudent') or weight_keys[0].startswith(
'modelTeacher'):
for match_idx in range(len(max_id)):
if max_id[match_idx] == -1:
not_load_weight_name.append(model_keys[match_idx])
if len(not_load_weight_name) > 0:
logger.info('{} in model is not matched with pretrained weights, '
'and its will be trained from scratch'.format(
not_load_weight_name))
for match_idx in range(len(max_id)): else:
if max_id[match_idx] == -1: for idx in range(len(weight_keys)):
not_load_weight_name.append(model_keys[match_idx]) if idx not in load_id:
if len(not_load_weight_name) > 0: not_load_weight_name.append(weight_keys[idx])
logger.info('{} in model is not matched with pretrained weights, '
'and its will be trained from scratch'.format( if len(not_load_weight_name) > 0:
not_load_weight_name)) logger.info('{} in pretrained weight is not used in the model, '
'and its will not be loaded'.format(
not_load_weight_name))
matched_keys = {} matched_keys = {}
result_state_dict = {} result_state_dict = {}
for model_id, weight_id in enumerate(max_id): for model_id, weight_id in enumerate(max_id):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册