提交 c31931eb 编写于 作者: L lyuwenyu

fix ShuffleNet problem

上级 569215a2
...@@ -12,14 +12,13 @@ ...@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
dependencies = ['paddle', 'numpy'] dependencies = ['paddle', 'numpy']
import paddle import paddle
from ppcls.modeling.architectures import alexnet as _alexnet from ppcls.modeling.architectures import alexnet as _alexnet
from ppcls.modeling.architectures import vgg as _vgg from ppcls.modeling.architectures import vgg as _vgg
from ppcls.modeling.architectures import resnet as _resnet from ppcls.modeling.architectures import resnet as _resnet
from ppcls.modeling.architectures import squeezenet as _squeezenet from ppcls.modeling.architectures import squeezenet as _squeezenet
from ppcls.modeling.architectures import densenet as _densenet from ppcls.modeling.architectures import densenet as _densenet
from ppcls.modeling.architectures import inception_v3 as _inception_v3 from ppcls.modeling.architectures import inception_v3 as _inception_v3
...@@ -32,13 +31,13 @@ from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3 ...@@ -32,13 +31,13 @@ from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
from ppcls.modeling.architectures import resnext as _resnext from ppcls.modeling.architectures import resnext as _resnext
def _load_pretrained_parameters(model, name): def _load_pretrained_parameters(model, name):
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(name) url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(
name)
path = paddle.utils.download.get_weights_path_from_url(url) path = paddle.utils.download.get_weights_path_from_url(url)
model.set_state_dict(paddle.load(path)) model.set_state_dict(paddle.load(path))
return model return model
def AlexNet(pretrained=False, **kwargs): def AlexNet(pretrained=False, **kwargs):
""" """
...@@ -182,7 +181,7 @@ def ResNet50(pretrained=False, **kwargs): ...@@ -182,7 +181,7 @@ def ResNet50(pretrained=False, **kwargs):
model = _resnet.ResNet50(**kwargs) model = _resnet.ResNet50(**kwargs)
if pretrained: if pretrained:
model = _load_pretrained_parameters(model, 'ResNet50') model = _load_pretrained_parameters(model, 'ResNet50')
return model return model
...@@ -404,19 +403,19 @@ def GoogLeNet(pretrained=False, **kwargs): ...@@ -404,19 +403,19 @@ def GoogLeNet(pretrained=False, **kwargs):
return model return model
def ShuffleNet(pretrained=False, **kwargs): def ShuffleNetV2_x0_25(pretrained=False, **kwargs):
""" """
ShuffleNet ShuffleNetV2_x0_25
Args: Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise. pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs: kwargs:
class_dim: int=1000. Output dim of last fc layer. class_dim: int=1000. Output dim of last fc layer.
Returns: Returns:
model: nn.Layer. Specific `ShuffleNet` model depends on args. model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args.
""" """
model = _shufflenet_v2.ShuffleNet(**kwargs) model = _shufflenet_v2.ShuffleNetV2_x0_25(**kwargs)
if pretrained: if pretrained:
model = _load_pretrained_parameters(model, 'ShuffleNet') model = _load_pretrained_parameters(model, 'ShuffleNetV2_x0_25')
return model return model
...@@ -744,7 +743,6 @@ def MobileNetV3_small_x1_25(pretrained=False, **kwargs): ...@@ -744,7 +743,6 @@ def MobileNetV3_small_x1_25(pretrained=False, **kwargs):
return model return model
def ResNeXt101_32x4d(pretrained=False, **kwargs): def ResNeXt101_32x4d(pretrained=False, **kwargs):
""" """
ResNeXt101_32x4d ResNeXt101_32x4d
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册