未验证 提交 00775c89 编写于 作者: W wangguanzhong 提交者: GitHub

fix loading rcnn pretrain (#2897)

上级 c7c0568f
......@@ -159,6 +159,16 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict = paddle.load(weights_path)
ignore_weights = set()
# hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
# while res5 module is located in bbox_head.head. Replace the prefix of
# res5 with 'bbox_head.head' to load pretrain weights correctly.
for k in param_state_dict.keys():
if 'backbone.res5' in k:
new_k = k.replace('backbone', 'bbox_head.head')
if new_k in model_dict.keys():
value = param_state_dict.pop(k)
param_state_dict[new_k] = value
for name, weight in param_state_dict.items():
if name in model_dict.keys():
if list(weight.shape) != list(model_dict[name].shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册