Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
31bf269e
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
31bf269e
编写于
5月 12, 2020
作者:
F
FlyingQianMM
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add resume training from a checkpoint
上级
1af3e044
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
145 addition
and
50 deletion
+145
-50
docs/apis/models.md
docs/apis/models.md
+12
-6
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+13
-1
paddlex/cv/models/classifier.py
paddlex/cv/models/classifier.py
+21
-7
paddlex/cv/models/deeplabv3p.py
paddlex/cv/models/deeplabv3p.py
+24
-8
paddlex/cv/models/faster_rcnn.py
paddlex/cv/models/faster_rcnn.py
+21
-6
paddlex/cv/models/mask_rcnn.py
paddlex/cv/models/mask_rcnn.py
+23
-7
paddlex/cv/models/unet.py
paddlex/cv/models/unet.py
+9
-8
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+22
-7
未找到文件。
docs/apis/models.md
浏览文件 @
31bf269e
...
...
@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
#### 分类器训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
> ```
>
> **参数:**
...
...
@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### 分类器评估函数接口
...
...
@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
#### YOLOv3训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
> ```
>
> **参数:**
...
...
@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### YOLOv3评估函数接口
...
...
@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
#### FasterRCNN训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
>
> ```
>
...
...
@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### FasterRCNN评估函数接口
...
...
@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
#### MaskRCNN训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
>
> ```
>
...
...
@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### MaskRCNN评估函数接口
...
...
@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
#### DeepLabv3训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
>
> ```
>
...
...
@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### DeepLabv3评估函数接口
...
...
@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
#### Unet训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5
, resume_checkpoint=None
)
> ```
>
> **参数:**
...
...
@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (bool): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### Unet评估函数接口
...
...
paddlex/cv/models/base.py
浏览文件 @
31bf269e
...
...
@@ -213,6 +213,17 @@ class BaseAPI:
prune_program
(
self
,
prune_params_ratios
)
self
.
status
=
'Prune'
def resume_checkpoint(self, path, startup_prog=None):
    """Load a saved training checkpoint into the training program.

    Args:
        path (str): Directory holding the checkpoint. When it contains a
            ``model.pdparams`` file, the ``model`` prefix inside that
            directory is what gets handed to ``fluid.load``; otherwise the
            directory path itself is passed through.
        startup_prog: Startup program to execute before loading. When None,
            ``fluid.default_startup_program()`` is used.

    Raises:
        Exception: If ``path`` is not an existing directory.
    """
    if not osp.isdir(path):
        # The check is isdir(), so the message covers both a missing path
        # and a path that points at a regular file.
        raise Exception(
            "Checkpoint path {} does not exist or is not a directory.".format(
                path))
    if osp.exists(osp.join(path, 'model.pdparams')):
        path = osp.join(path, 'model')
    if startup_prog is None:
        startup_prog = fluid.default_startup_program()
    # Run the startup program before loading; presumably this creates the
    # variables that fluid.load then overwrites -- TODO confirm.
    self.exe.run(startup_prog)
    fluid.load(self.train_prog, path, executor=self.exe)
def
get_model_info
(
self
):
info
=
dict
()
info
[
'version'
]
=
paddlex
.
__version__
...
...
@@ -334,6 +345,7 @@ class BaseAPI:
num_epochs
,
train_dataset
,
train_batch_size
,
start_epoch
=
0
,
eval_dataset
=
None
,
save_interval_epochs
=
1
,
log_interval_steps
=
10
,
...
...
@@ -408,7 +420,7 @@ class BaseAPI:
best_accuracy_key
=
""
best_accuracy
=
-
1.0
best_model_epoch
=
1
for
i
in
range
(
num_epochs
):
for
i
in
range
(
start_epoch
,
num_epochs
):
records
=
list
()
step_start_time
=
time
.
time
()
epoch_start_time
=
time
.
time
()
...
...
paddlex/cv/models/classifier.py
浏览文件 @
31bf269e
...
...
@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI):
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
...
...
@@ -155,15 +157,27 @@ class BaseClassifier(BaseAPI):
# 构建训练、验证、预测网络
self
.
build_program
()
# 初始化网络权重
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
if
resume_checkpoint
:
self
.
resume_checkpoint
(
path
=
resume_checkpoint
,
startup_prog
=
fluid
.
default_startup_program
())
scope
=
fluid
.
global_scope
()
v
=
scope
.
find_var
(
'@LR_DECAY_COUNTER@'
)
step
=
np
.
array
(
v
.
get_tensor
())[
0
]
if
v
else
0
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
start_epoch
=
step
//
num_steps_each_epoch
+
1
else
:
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
start_epoch
=
0
# 训练
self
.
train_loop
(
start_epoch
=
start_epoch
,
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
...
...
paddlex/cv/models/deeplabv3p.py
浏览文件 @
31bf269e
...
...
@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI):
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
...
...
@@ -279,14 +281,27 @@ class DeepLabv3p(BaseAPI):
# 构建训练、验证、预测网络
self
.
build_program
()
# 初始化网络权重
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
if
resume_checkpoint
:
self
.
resume_checkpoint
(
path
=
resume_checkpoint
,
startup_prog
=
fluid
.
default_startup_program
())
scope
=
fluid
.
global_scope
()
v
=
scope
.
find_var
(
'@LR_DECAY_COUNTER@'
)
step
=
np
.
array
(
v
.
get_tensor
())[
0
]
if
v
else
0
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
start_epoch
=
step
//
num_steps_each_epoch
+
1
else
:
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
start_epoch
=
0
# 训练
self
.
train_loop
(
start_epoch
=
start_epoch
,
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
...
...
@@ -405,5 +420,6 @@ class DeepLabv3p(BaseAPI):
w
,
h
=
info
[
1
][
1
],
info
[
1
][
0
]
pred
=
pred
[
0
:
h
,
0
:
w
]
else
:
raise
Exception
(
"Unexpected info '{}' in im_info"
.
format
(
info
[
0
]))
raise
Exception
(
"Unexpected info '{}' in im_info"
.
format
(
info
[
0
]))
return
{
'label_map'
:
pred
,
'score_map'
:
result
[
1
]}
paddlex/cv/models/faster_rcnn.py
浏览文件 @
31bf269e
...
...
@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI):
metric
=
None
,
use_vdl
=
False
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -227,13 +229,26 @@ class FasterRCNN(BaseAPI):
fuse_bn
=
True
if
self
.
with_fpn
and
self
.
backbone
in
[
'ResNet18'
,
'ResNet50'
]:
fuse_bn
=
False
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
fuse_bn
=
fuse_bn
,
save_dir
=
save_dir
)
if
resume_checkpoint
:
self
.
resume_checkpoint
(
path
=
resume_checkpoint
,
startup_prog
=
fluid
.
default_startup_program
())
scope
=
fluid
.
global_scope
()
v
=
scope
.
find_var
(
'@LR_DECAY_COUNTER@'
)
step
=
np
.
array
(
v
.
get_tensor
())[
0
]
if
v
else
0
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
start_epoch
=
step
//
num_steps_each_epoch
+
1
else
:
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
fuse_bn
=
fuse_bn
,
save_dir
=
save_dir
)
start_epoch
=
0
# 训练
self
.
train_loop
(
start_epoch
=
start_epoch
,
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
...
...
paddlex/cv/models/mask_rcnn.py
浏览文件 @
31bf269e
...
...
@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN):
metric
=
None
,
use_vdl
=
False
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN):
metric
=
'COCO'
else
:
raise
Exception
(
"train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
)
"train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
)
assert
metric
in
[
'COCO'
,
'VOC'
],
"Metric only support 'VOC' or 'COCO'"
self
.
metric
=
metric
if
not
self
.
trainable
:
...
...
@@ -193,13 +196,26 @@ class MaskRCNN(FasterRCNN):
fuse_bn
=
True
if
self
.
with_fpn
and
self
.
backbone
in
[
'ResNet18'
,
'ResNet50'
]:
fuse_bn
=
False
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
fuse_bn
=
fuse_bn
,
save_dir
=
save_dir
)
if
resume_checkpoint
:
self
.
resume_checkpoint
(
path
=
resume_checkpoint
,
startup_prog
=
fluid
.
default_startup_program
())
scope
=
fluid
.
global_scope
()
v
=
scope
.
find_var
(
'@LR_DECAY_COUNTER@'
)
step
=
np
.
array
(
v
.
get_tensor
())[
0
]
if
v
else
0
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
start_epoch
=
step
//
num_steps_each_epoch
+
1
else
:
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
fuse_bn
=
fuse_bn
,
save_dir
=
save_dir
)
start_epoch
=
0
# 训练
self
.
train_loop
(
start_epoch
=
start_epoch
,
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
...
...
paddlex/cv/models/unet.py
浏览文件 @
31bf269e
...
...
@@ -121,7 +121,8 @@ class UNet(DeepLabv3p):
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -145,14 +146,14 @@ class UNet(DeepLabv3p):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
"""
return
super
(
UNet
,
self
).
train
(
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
save_interval_epochs
,
log_interval_steps
,
save_dir
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
,
early_stop
,
early_stop_patience
)
return
super
(
UNet
,
self
).
train
(
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
save_interval_epochs
,
log_interval_steps
,
save_dir
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
,
early_stop
,
early_stop_patience
,
resume_checkpoint
)
paddlex/cv/models/yolo_v3.py
浏览文件 @
31bf269e
...
...
@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI):
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
"""训练。
Args:
...
...
@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -231,14 +233,27 @@ class YOLOv3(BaseAPI):
# 构建训练、验证、预测网络
self
.
build_program
()
# 初始化网络权重
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
if
resume_checkpoint
:
self
.
resume_checkpoint
(
path
=
resume_checkpoint
,
startup_prog
=
fluid
.
default_startup_program
())
scope
=
fluid
.
global_scope
()
v
=
scope
.
find_var
(
'@LR_DECAY_COUNTER@'
)
step
=
np
.
array
(
v
.
get_tensor
())[
0
]
if
v
else
0
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
start_epoch
=
step
//
num_steps_each_epoch
+
1
else
:
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
)
start_epoch
=
0
# 训练
self
.
train_loop
(
start_epoch
=
start_epoch
,
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录