提交 c31931eb 编写于 作者: L lyuwenyu

fix ShuffleNet problem

上级 569215a2
......@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
dependencies = ['paddle', 'numpy']
import paddle
from ppcls.modeling.architectures import alexnet as _alexnet
from ppcls.modeling.architectures import vgg as _vgg
from ppcls.modeling.architectures import resnet as _resnet
from ppcls.modeling.architectures import vgg as _vgg
from ppcls.modeling.architectures import resnet as _resnet
from ppcls.modeling.architectures import squeezenet as _squeezenet
from ppcls.modeling.architectures import densenet as _densenet
from ppcls.modeling.architectures import inception_v3 as _inception_v3
......@@ -32,13 +31,13 @@ from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
from ppcls.modeling.architectures import resnext as _resnext
def _load_pretrained_parameters(model, name):
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(name)
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(
name)
path = paddle.utils.download.get_weights_path_from_url(url)
model.set_state_dict(paddle.load(path))
return model
def AlexNet(pretrained=False, **kwargs):
"""
......@@ -182,7 +181,7 @@ def ResNet50(pretrained=False, **kwargs):
model = _resnet.ResNet50(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet50')
return model
......@@ -404,19 +403,19 @@ def GoogLeNet(pretrained=False, **kwargs):
return model
def ShuffleNet(pretrained=False, **kwargs):
def ShuffleNetV2_x0_25(pretrained=False, **kwargs):
"""
ShuffleNet
ShuffleNetV2_x0_25
Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs:
class_dim: int=1000. Output dim of last fc layer.
Returns:
model: nn.Layer. Specific `ShuffleNet` model depends on args.
model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args.
"""
model = _shufflenet_v2.ShuffleNet(**kwargs)
model = _shufflenet_v2.ShuffleNetV2_x0_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ShuffleNet')
model = _load_pretrained_parameters(model, 'ShuffleNetV2_x0_25')
return model
......@@ -744,7 +743,6 @@ def MobileNetV3_small_x1_25(pretrained=False, **kwargs):
return model
def ResNeXt101_32x4d(pretrained=False, **kwargs):
"""
ResNeXt101_32x4d
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册