未验证 提交 9ca4f40f 编写于 作者: W WJJ1995 提交者: GitHub

fixed backbone bug (#684)

上级 7220d931
...@@ -302,7 +302,7 @@ def _resnet(arch: str, ...@@ -302,7 +302,7 @@ def _resnet(arch: str,
**kwargs: Any) -> ResNet: **kwargs: Any) -> ResNet:
model = ResNet(block, layers, **kwargs) model = ResNet(block, layers, **kwargs)
if pretrained: if pretrained:
state_dict = get_weights_path_from_url(model_urls[arch]) state_dict = paddle.load(get_weights_path_from_url(model_urls[arch]))
model.load_dict(state_dict) model.load_dict(state_dict)
return model return model
......
...@@ -109,7 +109,7 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, ...@@ -109,7 +109,7 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool,
kwargs['init_weights'] = False kwargs['init_weights'] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained: if pretrained:
state_dict = get_weights_path_from_url(model_urls[arch]) state_dict = paddle.load(get_weights_path_from_url(model_urls[arch]))
model.load_dict(state_dict) model.load_dict(state_dict)
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册