"""Train a MobileNetV2 image classifier on the PaddleX vegetables dataset.

The script downloads and extracts the vegetables classification dataset into
the current working directory. It then builds the training and evaluation
transform pipelines and datasets, and trains ``pdx.cls.MobileNetV2`` for
10 epochs. Checkpoints and VisualDL logs are written under
``output/mobilenetv2``.
"""
import os
from paddlex.cls import transforms
import paddlex as pdx

# Download and extract the vegetable classification dataset into the current
# working directory. The archive is expected to unpack to ./vegetables_cls,
# which is the root the datasets below read from.
veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')

# Single source of truth for the dataset root, so the image directory and
# the list/label files cannot drift apart.
data_dir = 'vegetables_cls'

# Transforms for training (random crop + flip for augmentation) and for
# evaluation (deterministic resize + center crop to the same 224px input).
train_transforms = transforms.Compose([
    transforms.RandomCrop(crop_size=224),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])
eval_transforms = transforms.Compose([
    transforms.ResizeByShort(short_size=256),
    transforms.CenterCrop(crop_size=224),
    transforms.Normalize()
])

# Training and evaluation datasets. Both share the same label file so class
# indices agree between them. Only the training set is shuffled.
train_dataset = pdx.datasets.ImageNet(
    data_dir=data_dir,
    file_list=os.path.join(data_dir, 'train_list.txt'),
    label_list=os.path.join(data_dir, 'labels.txt'),
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.ImageNet(
    data_dir=data_dir,
    file_list=os.path.join(data_dir, 'val_list.txt'),
    label_list=os.path.join(data_dir, 'labels.txt'),
    transforms=eval_transforms)

# Build the model and train it.
# Training metrics can be inspected with VisualDL:
#     visualdl --logdir output/mobilenetv2/vdl_log --port 8001
# then open http://0.0.0.0:8001 in a browser (VisualDL serves plain HTTP).
# 0.0.0.0 works for local access; on a remote server, replace it with that
# machine's IP address.
# num_classes is taken from the training dataset's labels (presumably the
# entries of labels.txt -- confirm against paddlex.datasets.ImageNet).
model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=32,
    eval_dataset=eval_dataset,
    # Learning rate is decayed at epochs 4, 6 and 8; the decay factor is
    # PaddleX's default (not set here).
    lr_decay_epochs=[4, 6, 8],
    learning_rate=0.025,
    save_dir='output/mobilenetv2',
    use_vdl=True)