# Environment-variable configuration that controls which GPU(s) are used.
# Documentation: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
# Expose only GPU 0 to this process. This is set before importing paddlex so
# that it is already in place when the framework initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import paddlex as pdx
from paddlex.seg import transforms

# Download and extract the optic-disc segmentation dataset into the current
# directory (it is read below from './optic_disc_seg').
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')

# Define the transforms used during training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
#
# Training pipeline: random flip, random rescaling, and a random 512 crop
# (with padding) for augmentation, followed by normalization.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ResizeRangeScaling(),
    transforms.RandomPaddingCrop(crop_size=512),
    transforms.Normalize(),
])

# Evaluation pipeline: no random augmentation. It resizes by the long side to
# 512, pads to 512, and normalizes.
eval_transforms = transforms.Compose([
    transforms.ResizeByLong(long_size=512),
    transforms.Padding(target_size=512),
    transforms.Normalize(),
])

# Define the datasets used for training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
#
# Both splits share the same data root and label list. Only the training
# split is shuffled.
train_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/train_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/val_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=eval_transforms)

# Initialize the model and run training.
# Training metrics can be viewed with VisualDL; see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
#
# The class count is taken from the training dataset's labels. Presumably
# these are the entries of labels.txt; verify against the SegDataset docs.
num_classes = len(train_dataset.labels)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-hrnet
model = pdx.seg.HRNet(num_classes=num_classes)

# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# Parameter descriptions and tuning guidance: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=20,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    save_dir='output/hrnet',
    use_vdl=True)  # enable VisualDL logging (see link above)