未验证 提交 78c6b82f 编写于 作者: W wjm 提交者: GitHub

Fix bug in load weight (#8097)

* fix_weight_load

* fix_load_weight
上级 4ee5e38b
......@@ -155,7 +155,7 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
return a == b or a.endswith("." + b) or b.endswith("." + a)
def match(a, b):
if a.startswith('backbone.res5'):
if b.startswith('backbone.res5'):
b = b[9:]
return a == b or a.endswith("." + b)
......@@ -174,8 +174,11 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
max_id = match_matrix.argmax(1)
max_len = match_matrix.max(1)
max_id[max_len == 0] = -1
load_id = set(max_id)
load_id.discard(-1)
not_load_weight_name = []
if weight_keys[0].startswith('modelStudent') or weight_keys[0].startswith(
'modelTeacher'):
for match_idx in range(len(max_id)):
if max_id[match_idx] == -1:
not_load_weight_name.append(model_keys[match_idx])
......@@ -183,6 +186,16 @@ def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
logger.info('{} in model is not matched with pretrained weights, '
'and its will be trained from scratch'.format(
not_load_weight_name))
else:
for idx in range(len(weight_keys)):
if idx not in load_id:
not_load_weight_name.append(weight_keys[idx])
if len(not_load_weight_name) > 0:
logger.info('{} in pretrained weight is not used in the model, '
'and its will not be loaded'.format(
not_load_weight_name))
matched_keys = {}
result_state_dict = {}
for model_id, weight_id in enumerate(max_id):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册