Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
05624d19
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05624d19
编写于
6月 26, 2020
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove new tutorials
上级
5de66166
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
0 addition
and
416 deletion
+0
-416
new_tutorials/train/README.md
new_tutorials/train/README.md
+0
-21
new_tutorials/train/classification/mobilenetv2.py
new_tutorials/train/classification/mobilenetv2.py
+0
-47
new_tutorials/train/classification/resnet50.py
new_tutorials/train/classification/resnet50.py
+0
-56
new_tutorials/train/detection/faster_rcnn_r50_fpn.py
new_tutorials/train/detection/faster_rcnn_r50_fpn.py
+0
-49
new_tutorials/train/detection/mask_rcnn_r50_fpn.py
new_tutorials/train/detection/mask_rcnn_r50_fpn.py
+0
-48
new_tutorials/train/detection/yolov3_darknet53.py
new_tutorials/train/detection/yolov3_darknet53.py
+0
-50
new_tutorials/train/segmentation/deeplabv3p.py
new_tutorials/train/segmentation/deeplabv3p.py
+0
-51
new_tutorials/train/segmentation/hrnet.py
new_tutorials/train/segmentation/hrnet.py
+0
-47
new_tutorials/train/segmentation/unet.py
new_tutorials/train/segmentation/unet.py
+0
-47
tutorials/train/segmentation/fast_scnn.py
tutorials/train/segmentation/fast_scnn.py
+0
-0
未找到文件。
new_tutorials/train/README.md
已删除
100644 → 0
浏览文件 @
5de66166
# 使用教程——训练模型
本目录下整理了使用PaddleX训练模型的示例代码,代码中均提供了示例数据的自动下载,并均使用单张GPU卡进行训练。
|代码 | 模型任务 | 数据 |
|------|--------|---------|
|classification/mobilenetv2.py | 图像分类MobileNetV2 | 蔬菜分类 |
|classification/resnet50.py | 图像分类ResNet50 | 蔬菜分类 |
|detection/faster_rcnn_r50_fpn.py | 目标检测FasterRCNN | 昆虫检测 |
|detection/mask_rcnn_f50_fpn.py | 实例分割MaskRCNN | 垃圾分拣 |
|segmentation/deeplabv3p.py | 语义分割DeepLabV3| 视盘分割 |
|segmentation/unet.py | 语义分割UNet | 视盘分割 |
|segmentation/hrnet.py | 语义分割HRNet | 视盘分割 |
|segmentation/fast_scnn.py | 语义分割FastSCNN | 视盘分割 |
## 开始训练
在安装PaddleX后,使用如下命令开始训练
```
python classification/mobilenetv2.py
```
new_tutorials/train/classification/mobilenetv2.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
from
paddlex.cls
import
transforms
import
paddlex
as
pdx
# 下载和解压蔬菜分类数据集
veg_dataset
=
'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx
.
utils
.
download_and_decompress
(
veg_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/cls_transforms.html#composedclstransforms
train_transforms
=
transforms
.
ComposedClsTransforms
(
mode
=
'train'
,
crop_size
=
[
224
,
224
])
eval_transforms
=
transforms
.
ComposedClsTransforms
(
mode
=
'eval'
,
crop_size
=
[
224
,
224
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/classification.html#imagenet
train_dataset
=
pdx
.
datasets
.
ImageNet
(
data_dir
=
'vegetables_cls'
,
file_list
=
'vegetables_cls/train_list.txt'
,
label_list
=
'vegetables_cls/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
ImageNet
(
data_dir
=
'vegetables_cls'
,
file_list
=
'vegetables_cls/val_list.txt'
,
label_list
=
'vegetables_cls/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/classification.html#resnet50
model
=
pdx
.
cls
.
MobileNetV2
(
num_classes
=
len
(
train_dataset
.
labels
))
model
.
train
(
num_epochs
=
10
,
train_dataset
=
train_dataset
,
train_batch_size
=
32
,
eval_dataset
=
eval_dataset
,
lr_decay_epochs
=
[
4
,
6
,
8
],
learning_rate
=
0.025
,
save_dir
=
'output/mobilenetv2'
,
use_vdl
=
True
)
new_tutorials/train/classification/resnet50.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
import
paddle.fluid
as
fluid
from
paddlex.cls
import
transforms
import
paddlex
as
pdx
# 下载和解压蔬菜分类数据集
veg_dataset
=
'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx
.
utils
.
download_and_decompress
(
veg_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/cls_transforms.html#composedclstransforms
train_transforms
=
transforms
.
ComposedClsTransforms
(
mode
=
'train'
,
crop_size
=
[
224
,
224
])
eval_transforms
=
transforms
.
ComposedClsTransforms
(
mode
=
'eval'
,
crop_size
=
[
224
,
224
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/classification.html#imagenet
train_dataset
=
pdx
.
datasets
.
ImageNet
(
data_dir
=
'vegetables_cls'
,
file_list
=
'vegetables_cls/train_list.txt'
,
label_list
=
'vegetables_cls/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
ImageNet
(
data_dir
=
'vegetables_cls'
,
file_list
=
'vegetables_cls/val_list.txt'
,
label_list
=
'vegetables_cls/labels.txt'
,
transforms
=
eval_transforms
)
# PaddleX支持自定义构建优化器
step_each_epoch
=
train_dataset
.
num_samples
//
32
learning_rate
=
fluid
.
layers
.
cosine_decay
(
learning_rate
=
0.025
,
step_each_epoch
=
step_each_epoch
,
epochs
=
10
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
learning_rate
,
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
4e-5
))
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/resnet50/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/classification.html#resnet50
model
=
pdx
.
cls
.
ResNet50
(
num_classes
=
len
(
train_dataset
.
labels
))
model
.
train
(
num_epochs
=
10
,
train_dataset
=
train_dataset
,
train_batch_size
=
32
,
eval_dataset
=
eval_dataset
,
optimizer
=
optimizer
,
save_dir
=
'output/resnet50'
,
use_vdl
=
True
)
new_tutorials/train/detection/faster_rcnn_r50_fpn.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
from
paddlex.det
import
transforms
import
paddlex
as
pdx
# 下载和解压昆虫检测数据集
insect_dataset
=
'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx
.
utils
.
download_and_decompress
(
insect_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedrcnntransforms
train_transforms
=
transforms
.
ComposedRCNNTransforms
(
mode
=
'train'
,
min_max_size
=
[
800
,
1333
])
eval_transforms
=
transforms
.
ComposedRCNNTransforms
(
mode
=
'eval'
,
min_max_size
=
[
800
,
1333
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#vocdetection
train_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/train_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/val_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/faster_rcnn_r50_fpn/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# num_classes 需要设置为包含背景类的类别数,即: 目标类别数量 + 1
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/detection.html#fasterrcnn
num_classes
=
len
(
train_dataset
.
labels
)
+
1
model
=
pdx
.
det
.
FasterRCNN
(
num_classes
=
num_classes
)
model
.
train
(
num_epochs
=
12
,
train_dataset
=
train_dataset
,
train_batch_size
=
2
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.0025
,
lr_decay_epochs
=
[
8
,
11
],
save_dir
=
'output/faster_rcnn_r50_fpn'
,
use_vdl
=
True
)
new_tutorials/train/detection/mask_rcnn_r50_fpn.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
from
paddlex.det
import
transforms
import
paddlex
as
pdx
# 下载和解压小度熊分拣数据集
xiaoduxiong_dataset
=
'https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_det.tar.gz'
pdx
.
utils
.
download_and_decompress
(
xiaoduxiong_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedrcnntransforms
train_transforms
=
transforms
.
ComposedRCNNTransforms
(
mode
=
'train'
,
min_max_size
=
[
800
,
1333
])
eval_transforms
=
transforms
.
ComposedRCNNTransforms
(
mode
=
'eval'
,
min_max_size
=
[
800
,
1333
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#cocodetection
train_dataset
=
pdx
.
datasets
.
CocoDetection
(
data_dir
=
'xiaoduxiong_ins_det/JPEGImages'
,
ann_file
=
'xiaoduxiong_ins_det/train.json'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
CocoDetection
(
data_dir
=
'xiaoduxiong_ins_det/JPEGImages'
,
ann_file
=
'xiaoduxiong_ins_det/val.json'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/mask_rcnn_r50_fpn/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# num_classes 需要设置为包含背景类的类别数,即: 目标类别数量 + 1
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/instance_segmentation.html#maskrcnn
num_classes
=
len
(
train_dataset
.
labels
)
+
1
model
=
pdx
.
det
.
MaskRCNN
(
num_classes
=
num_classes
)
model
.
train
(
num_epochs
=
12
,
train_dataset
=
train_dataset
,
train_batch_size
=
1
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.00125
,
warmup_steps
=
10
,
lr_decay_epochs
=
[
8
,
11
],
save_dir
=
'output/mask_rcnn_r50_fpn'
,
use_vdl
=
True
)
new_tutorials/train/detection/yolov3_darknet53.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
from
paddlex.det
import
transforms
import
paddlex
as
pdx
# 下载和解压昆虫检测数据集
insect_dataset
=
'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx
.
utils
.
download_and_decompress
(
insect_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#composedyolotransforms
train_transforms
=
transforms
.
ComposedYOLOv3Transforms
(
mode
=
'train'
,
shape
=
[
608
,
608
])
eval_transforms
=
transforms
.
ComposedYOLOv3Transforms
(
mode
=
'eval'
,
shape
=
[
608
,
608
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/detection.html#vocdetection
train_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/train_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/val_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/detection.html#yolov3
num_classes
=
len
(
train_dataset
.
labels
)
model
=
pdx
.
det
.
YOLOv3
(
num_classes
=
num_classes
,
backbone
=
'DarkNet53'
)
model
.
train
(
num_epochs
=
270
,
train_dataset
=
train_dataset
,
train_batch_size
=
8
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.000125
,
lr_decay_epochs
=
[
210
,
240
],
save_dir
=
'output/yolov3_darknet53'
,
use_vdl
=
True
)
new_tutorials/train/segmentation/deeplabv3p.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
import
paddlex
as
pdx
from
paddlex.seg
import
transforms
# 下载和解压视盘分割数据集
optic_dataset
=
'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx
.
utils
.
download_and_decompress
(
optic_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'train'
,
train_crop_size
=
[
769
,
769
])
eval_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'eval'
)
train_transforms
.
add_augmenters
([
transforms
.
RandomRotate
()
])
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/train_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/val_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#deeplabv3p
num_classes
=
len
(
train_dataset
.
labels
)
model
=
pdx
.
seg
.
DeepLabv3p
(
num_classes
=
num_classes
)
model
.
train
(
num_epochs
=
40
,
train_dataset
=
train_dataset
,
train_batch_size
=
4
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.01
,
save_dir
=
'output/deeplab'
,
use_vdl
=
True
)
new_tutorials/train/segmentation/hrnet.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
import
paddlex
as
pdx
from
paddlex.seg
import
transforms
# 下载和解压视盘分割数据集
optic_dataset
=
'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx
.
utils
.
download_and_decompress
(
optic_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'train'
,
train_crop_size
=
[
769
,
769
])
eval_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'eval'
)
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/train_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/val_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes
=
len
(
train_dataset
.
labels
)
model
=
pdx
.
seg
.
HRNet
(
num_classes
=
num_classes
)
model
.
train
(
num_epochs
=
20
,
train_dataset
=
train_dataset
,
train_batch_size
=
4
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.01
,
save_dir
=
'output/hrnet'
,
use_vdl
=
True
)
new_tutorials/train/segmentation/unet.py
已删除
100644 → 0
浏览文件 @
5de66166
import
os
# 选择使用0号卡
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
import
paddlex
as
pdx
from
paddlex.seg
import
transforms
# 下载和解压视盘分割数据集
optic_dataset
=
'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx
.
utils
.
download_and_decompress
(
optic_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'train'
,
train_crop_size
=
[
769
,
769
])
eval_transforms
=
transforms
.
ComposedSegTransforms
(
mode
=
'eval'
)
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/train_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
SegDataset
(
data_dir
=
'optic_disc_seg'
,
file_list
=
'optic_disc_seg/val_list.txt'
,
label_list
=
'optic_disc_seg/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#unet
num_classes
=
len
(
train_dataset
.
labels
)
model
=
pdx
.
seg
.
UNet
(
num_classes
=
num_classes
)
model
.
train
(
num_epochs
=
20
,
train_dataset
=
train_dataset
,
train_batch_size
=
4
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.01
,
save_dir
=
'output/unet'
,
use_vdl
=
True
)
new_
tutorials/train/segmentation/fast_scnn.py
→
tutorials/train/segmentation/fast_scnn.py
浏览文件 @
05624d19
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录