Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
1c205a14
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1c205a14
编写于
5月 08, 2020
作者:
F
FlyingQianMM
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add early_stop policy
上级
ece29fe5
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
123 addition
and
23 deletion
+123
-23
docs/apis/models.md
docs/apis/models.md
+18
-6
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+11
-1
paddlex/cv/models/classifier.py
paddlex/cv/models/classifier.py
+9
-2
paddlex/cv/models/deeplabv3p.py
paddlex/cv/models/deeplabv3p.py
+9
-2
paddlex/cv/models/faster_rcnn.py
paddlex/cv/models/faster_rcnn.py
+9
-2
paddlex/cv/models/mask_rcnn.py
paddlex/cv/models/mask_rcnn.py
+9
-2
paddlex/cv/models/unet.py
paddlex/cv/models/unet.py
+13
-6
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+9
-2
paddlex/utils/utils.py
paddlex/utils/utils.py
+36
-0
未找到文件。
docs/apis/models.md
浏览文件 @
1c205a14
...
...
@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
#### 分类器训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05)
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05
, early_stop=False, early_stop_patience=5
)
> ```
>
> **参数:**
...
...
@@ -37,6 +37,8 @@ paddlex.cls.ResNet50(num_classes=1000)
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### 分类器评估函数接口
...
...
@@ -109,7 +111,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
#### YOLOv3训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05)
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05
, early_stop=False, early_stop_patience=5
)
> ```
>
> **参数:**
...
...
@@ -132,6 +134,8 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### YOLOv3评估函数接口
...
...
@@ -186,7 +190,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
#### FasterRCNN训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False)
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False
, early_stop=False, early_stop_patience=5
)
>
> ```
>
...
...
@@ -208,6 +212,8 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### FasterRCNN评估函数接口
...
...
@@ -264,7 +270,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
#### MaskRCNN训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False)
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False
, early_stop=False, early_stop_patience=5
)
>
> ```
>
...
...
@@ -286,6 +292,8 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### MaskRCNN评估函数接口
...
...
@@ -350,7 +358,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
#### DeepLabv3训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05):
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05
, early_stop=False, early_stop_patience=5
):
>
> ```
>
...
...
@@ -370,6 +378,8 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### DeepLabv3评估函数接口
...
...
@@ -427,7 +437,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
#### Unet训练函数接口
> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05):
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05
, early_stop=False, early_stop_patience=5
):
> ```
>
> **参数:**
...
...
@@ -446,6 +456,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### Unet评估函数接口
...
...
paddlex/cv/models/base.py
浏览文件 @
1c205a14
...
...
@@ -24,6 +24,7 @@ import json
import
functools
import
paddlex.utils.logging
as
logging
from
paddlex.utils
import
seconds_to_hms
from
paddlex.utils.utils
import
EarlyStop
import
paddlex
from
collections
import
OrderedDict
from
os
import
path
as
osp
...
...
@@ -334,7 +335,9 @@ class BaseAPI:
save_interval_epochs
=
1
,
log_interval_steps
=
10
,
save_dir
=
'output'
,
use_vdl
=
False
):
use_vdl
=
False
,
early_stop
=
False
,
early_stop_patience
=
5
):
if
not
osp
.
isdir
(
save_dir
):
if
osp
.
exists
(
save_dir
):
os
.
remove
(
save_dir
)
...
...
@@ -396,6 +399,9 @@ class BaseAPI:
train_step_component
=
OrderedDict
()
eval_component
=
OrderedDict
()
thresh
=
0.0001
if
early_stop
:
earlystop
=
EarlyStop
(
early_stop_patience
,
thresh
)
best_accuracy_key
=
""
best_accuracy
=
-
1.0
best_model_epoch
=
1
...
...
@@ -507,3 +513,7 @@ class BaseAPI:
'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
.
format
(
best_model_epoch
,
best_accuracy_key
,
best_accuracy
))
if
eval_dataset
is
not
None
:
if
early_stop
:
if
earlystop
(
current_accuracy
):
break
paddlex/cv/models/classifier.py
浏览文件 @
1c205a14
...
...
@@ -102,7 +102,9 @@ class BaseClassifier(BaseAPI):
lr_decay_gamma
=
0.1
,
use_vdl
=
False
,
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
):
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -124,6 +126,9 @@ class BaseClassifier(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 模型从inference model进行加载。
...
...
@@ -158,7 +163,9 @@ class BaseClassifier(BaseAPI):
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
)
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
...
...
paddlex/cv/models/deeplabv3p.py
浏览文件 @
1c205a14
...
...
@@ -231,7 +231,9 @@ class DeepLabv3p(BaseAPI):
lr_decay_power
=
0.9
,
use_vdl
=
False
,
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
):
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -252,6 +254,9 @@ class DeepLabv3p(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 模型从inference model进行加载。
...
...
@@ -288,7 +293,9 @@ class DeepLabv3p(BaseAPI):
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
)
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
...
...
paddlex/cv/models/faster_rcnn.py
浏览文件 @
1c205a14
...
...
@@ -163,7 +163,9 @@ class FasterRCNN(BaseAPI):
lr_decay_epochs
=
[
8
,
11
],
lr_decay_gamma
=
0.1
,
metric
=
None
,
use_vdl
=
False
):
use_vdl
=
False
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -186,6 +188,9 @@ class FasterRCNN(BaseAPI):
lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -233,7 +238,9 @@ class FasterRCNN(BaseAPI):
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
)
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
...
...
paddlex/cv/models/mask_rcnn.py
浏览文件 @
1c205a14
...
...
@@ -128,7 +128,9 @@ class MaskRCNN(FasterRCNN):
lr_decay_epochs
=
[
8
,
11
],
lr_decay_gamma
=
0.1
,
metric
=
None
,
use_vdl
=
False
):
use_vdl
=
False
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -151,6 +153,9 @@ class MaskRCNN(FasterRCNN):
lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。
use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -199,7 +204,9 @@ class MaskRCNN(FasterRCNN):
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
)
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
...
...
paddlex/cv/models/unet.py
浏览文件 @
1c205a14
...
...
@@ -117,7 +117,9 @@ class UNet(DeepLabv3p):
lr_decay_power
=
0.9
,
use_vdl
=
False
,
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
):
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -138,12 +140,17 @@ class UNet(DeepLabv3p):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 模型从inference model进行加载。
"""
return
super
(
UNet
,
self
).
train
(
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
save_interval_epochs
,
log_interval_steps
,
save_dir
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
)
return
super
(
UNet
,
self
).
train
(
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
save_interval_epochs
,
log_interval_steps
,
save_dir
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
,
early_stop
,
early_stop_patience
)
paddlex/cv/models/yolo_v3.py
浏览文件 @
1c205a14
...
...
@@ -162,7 +162,9 @@ class YOLOv3(BaseAPI):
metric
=
None
,
use_vdl
=
False
,
sensitivities_file
=
None
,
eval_metric_loss
=
0.05
):
eval_metric_loss
=
0.05
,
early_stop
=
False
,
early_stop_patience
=
5
):
"""训练。
Args:
...
...
@@ -188,6 +190,9 @@ class YOLOv3(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises:
ValueError: 评估类型不在指定列表中。
...
...
@@ -238,7 +243,9 @@ class YOLOv3(BaseAPI):
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
)
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
...
...
paddlex/utils/utils.py
浏览文件 @
1c205a14
...
...
@@ -220,3 +220,39 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
len
(
vars_to_load
),
weights_dir
))
if
fuse_bn
:
fuse_bn_weights
(
exe
,
main_prog
,
weights_dir
)
class
EarlyStop
:
def
__init__
(
self
,
patience
,
thresh
):
self
.
patience
=
patience
self
.
counter
=
0
self
.
score
=
None
self
.
max
=
0
self
.
thresh
=
thresh
if
patience
<
1
:
raise
Exception
(
"Argument patience should be a positive integer."
)
def
__call__
(
self
,
current_score
):
if
self
.
score
is
None
:
self
.
score
=
current_score
return
False
elif
current_score
>
self
.
max
:
self
.
counter
=
0
self
.
score
=
current_score
self
.
max
=
current_score
return
False
else
:
if
(
abs
(
self
.
score
-
current_score
)
<
self
.
thresh
or
current_score
<
self
.
score
):
self
.
counter
+=
1
self
.
score
=
current_score
logging
.
debug
(
"EarlyStopping: %i / %i"
%
(
self
.
counter
,
self
.
patience
))
if
self
.
counter
>=
self
.
patience
:
logging
.
info
(
"EarlyStopping: Stop training"
)
return
True
return
False
else
:
self
.
counter
=
0
self
.
score
=
current_score
return
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录