未验证 提交 82646ce0 编写于 作者: W wangguanzhong 提交者: GitHub

fix checkpoint log when weight is not loaded (#4885)

* fix checkpoint log when weight not loaded

* fix loading pretrain weight in mask rcnn

* discard unmatch id instead of remove
上级 e27674ff
......@@ -124,7 +124,7 @@ def match_state_dict(model_state_dict, weight_state_dict):
weight_keys = sorted(weight_state_dict.keys())
def match(a, b):
if a.startswith('backbone.res5'):
if b.startswith('backbone.res5'):
# In Faster RCNN, res5 pretrained weights have prefix of backbone,
# however, the corresponding model weights have difficult prefix,
# bbox_head.
......@@ -139,10 +139,14 @@ def match_state_dict(model_state_dict, weight_state_dict):
max_id = match_matrix.argmax(1)
max_len = match_matrix.max(1)
max_id[max_len == 0] = -1
load_id = set(max_id)
load_id.discard(-1)
not_load_weight_name = []
for match_idx in range(len(max_id)):
if match_idx < len(weight_keys) and max_id[match_idx] == -1:
not_load_weight_name.append(weight_keys[match_idx])
for idx in range(len(weight_keys)):
if idx not in load_id:
not_load_weight_name.append(weight_keys[idx])
if len(not_load_weight_name) > 0:
logger.info('{} in pretrained weight is not used in the model, '
'and its will not be loaded'.format(not_load_weight_name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册