未验证 提交 526e112a 编写于 作者: J Jason 提交者: GitHub

Merge pull request #24 from FlyingQianMM/develop_draw

add early_stop policy
...@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000) ...@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
#### 分类器训练函数接口 #### 分类器训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05) > train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -37,6 +37,8 @@ paddlex.cls.ResNet50(num_classes=1000) ...@@ -37,6 +37,8 @@ paddlex.cls.ResNet50(num_classes=1000)
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### 分类器评估函数接口 #### 分类器评估函数接口
...@@ -109,7 +111,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ ...@@ -109,7 +111,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
#### YOLOv3训练函数接口 #### YOLOv3训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05) > train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5)
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -132,6 +134,8 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ ...@@ -132,6 +134,8 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在PascalVOC数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### YOLOv3评估函数接口 #### YOLOv3评估函数接口
...@@ -186,7 +190,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec ...@@ -186,7 +190,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
#### FasterRCNN训练函数接口 #### FasterRCNN训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False) > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
> >
> ``` > ```
> >
...@@ -208,6 +212,8 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec ...@@ -208,6 +212,8 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。 > > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### FasterRCNN评估函数接口 #### FasterRCNN评估函数接口
...@@ -264,7 +270,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ ...@@ -264,7 +270,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
#### MaskRCNN训练函数接口 #### MaskRCNN训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False) > train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5)
> >
> ``` > ```
> >
...@@ -286,6 +292,8 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ ...@@ -286,6 +292,8 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。 > > - **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
> > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### MaskRCNN评估函数接口 #### MaskRCNN评估函数接口
...@@ -350,7 +358,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -350,7 +358,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
#### DeepLabv3训练函数接口 #### DeepLabv3训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05): > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
> >
> ``` > ```
> >
...@@ -370,6 +378,8 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -370,6 +378,8 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### DeepLabv3评估函数接口 #### DeepLabv3评估函数接口
...@@ -427,7 +437,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us ...@@ -427,7 +437,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
#### Unet训练函数接口 #### Unet训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05): > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5):
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -446,6 +456,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us ...@@ -446,6 +456,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认False。
> > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 > > - **sensitivities_file** (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
#### Unet评估函数接口 #### Unet评估函数接口
......
...@@ -24,6 +24,7 @@ import json ...@@ -24,6 +24,7 @@ import json
import functools import functools
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
from paddlex.utils import seconds_to_hms from paddlex.utils import seconds_to_hms
from paddlex.utils.utils import EarlyStop
import paddlex import paddlex
from collections import OrderedDict from collections import OrderedDict
from os import path as osp from os import path as osp
...@@ -334,7 +335,9 @@ class BaseAPI: ...@@ -334,7 +335,9 @@ class BaseAPI:
save_interval_epochs=1, save_interval_epochs=1,
log_interval_steps=10, log_interval_steps=10,
save_dir='output', save_dir='output',
use_vdl=False): use_vdl=False,
early_stop=False,
early_stop_patience=5):
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
...@@ -396,6 +399,9 @@ class BaseAPI: ...@@ -396,6 +399,9 @@ class BaseAPI:
train_step_component = OrderedDict() train_step_component = OrderedDict()
eval_component = OrderedDict() eval_component = OrderedDict()
thresh = 0.0001
if early_stop:
earlystop = EarlyStop(early_stop_patience, thresh)
best_accuracy_key = "" best_accuracy_key = ""
best_accuracy = -1.0 best_accuracy = -1.0
best_model_epoch = 1 best_model_epoch = 1
...@@ -507,3 +513,6 @@ class BaseAPI: ...@@ -507,3 +513,6 @@ class BaseAPI:
'Current evaluated best model in eval_dataset is epoch_{}, {}={}' 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
.format(best_model_epoch, best_accuracy_key, .format(best_model_epoch, best_accuracy_key,
best_accuracy)) best_accuracy))
if eval_dataset is not None and early_stop:
if earlystop(current_accuracy):
break
...@@ -102,7 +102,9 @@ class BaseClassifier(BaseAPI): ...@@ -102,7 +102,9 @@ class BaseClassifier(BaseAPI):
lr_decay_gamma=0.1, lr_decay_gamma=0.1,
use_vdl=False, use_vdl=False,
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05): eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -124,6 +126,9 @@ class BaseClassifier(BaseAPI): ...@@ -124,6 +126,9 @@ class BaseClassifier(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT', sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。 eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
...@@ -158,7 +163,9 @@ class BaseClassifier(BaseAPI): ...@@ -158,7 +163,9 @@ class BaseClassifier(BaseAPI):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
use_vdl=use_vdl) use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
......
...@@ -231,7 +231,9 @@ class DeepLabv3p(BaseAPI): ...@@ -231,7 +231,9 @@ class DeepLabv3p(BaseAPI):
lr_decay_power=0.9, lr_decay_power=0.9,
use_vdl=False, use_vdl=False,
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05): eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -252,6 +254,9 @@ class DeepLabv3p(BaseAPI): ...@@ -252,6 +254,9 @@ class DeepLabv3p(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT', sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。 eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
...@@ -288,7 +293,9 @@ class DeepLabv3p(BaseAPI): ...@@ -288,7 +293,9 @@ class DeepLabv3p(BaseAPI):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
use_vdl=use_vdl) use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
......
...@@ -163,7 +163,9 @@ class FasterRCNN(BaseAPI): ...@@ -163,7 +163,9 @@ class FasterRCNN(BaseAPI):
lr_decay_epochs=[8, 11], lr_decay_epochs=[8, 11],
lr_decay_gamma=0.1, lr_decay_gamma=0.1,
metric=None, metric=None,
use_vdl=False): use_vdl=False,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -186,6 +188,9 @@ class FasterRCNN(BaseAPI): ...@@ -186,6 +188,9 @@ class FasterRCNN(BaseAPI):
lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。 lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。 metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。 use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -233,7 +238,9 @@ class FasterRCNN(BaseAPI): ...@@ -233,7 +238,9 @@ class FasterRCNN(BaseAPI):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
use_vdl=use_vdl) use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
......
...@@ -128,7 +128,9 @@ class MaskRCNN(FasterRCNN): ...@@ -128,7 +128,9 @@ class MaskRCNN(FasterRCNN):
lr_decay_epochs=[8, 11], lr_decay_epochs=[8, 11],
lr_decay_gamma=0.1, lr_decay_gamma=0.1,
metric=None, metric=None,
use_vdl=False): use_vdl=False,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -151,6 +153,9 @@ class MaskRCNN(FasterRCNN): ...@@ -151,6 +153,9 @@ class MaskRCNN(FasterRCNN):
lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。 lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。 metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。
use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。 use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -199,7 +204,9 @@ class MaskRCNN(FasterRCNN): ...@@ -199,7 +204,9 @@ class MaskRCNN(FasterRCNN):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
use_vdl=use_vdl) use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
......
...@@ -117,7 +117,9 @@ class UNet(DeepLabv3p): ...@@ -117,7 +117,9 @@ class UNet(DeepLabv3p):
lr_decay_power=0.9, lr_decay_power=0.9,
use_vdl=False, use_vdl=False,
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05): eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -138,12 +140,17 @@ class UNet(DeepLabv3p): ...@@ -138,12 +140,17 @@ class UNet(DeepLabv3p):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT', sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。 eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
return super(UNet, self).train( return super(
num_epochs, train_dataset, train_batch_size, eval_dataset, UNet,
save_interval_epochs, log_interval_steps, save_dir, self).train(num_epochs, train_dataset, train_batch_size,
pretrain_weights, optimizer, learning_rate, lr_decay_power, eval_dataset, save_interval_epochs, log_interval_steps,
use_vdl, sensitivities_file, eval_metric_loss) save_dir, pretrain_weights, optimizer, learning_rate,
lr_decay_power, use_vdl, sensitivities_file,
eval_metric_loss, early_stop, early_stop_patience)
...@@ -162,7 +162,9 @@ class YOLOv3(BaseAPI): ...@@ -162,7 +162,9 @@ class YOLOv3(BaseAPI):
metric=None, metric=None,
use_vdl=False, use_vdl=False,
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05): eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5):
"""训练。 """训练。
Args: Args:
...@@ -188,6 +190,9 @@ class YOLOv3(BaseAPI): ...@@ -188,6 +190,9 @@ class YOLOv3(BaseAPI):
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT', sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。 eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -238,7 +243,9 @@ class YOLOv3(BaseAPI): ...@@ -238,7 +243,9 @@ class YOLOv3(BaseAPI):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
use_vdl=use_vdl) use_vdl=use_vdl,
early_stop=early_stop,
early_stop_patience=early_stop_patience)
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
......
...@@ -220,3 +220,39 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False): ...@@ -220,3 +220,39 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
len(vars_to_load), weights_dir)) len(vars_to_load), weights_dir))
if fuse_bn: if fuse_bn:
fuse_bn_weights(exe, main_prog, weights_dir) fuse_bn_weights(exe, main_prog, weights_dir)
class EarlyStop:
def __init__(self, patience, thresh):
self.patience = patience
self.counter = 0
self.score = None
self.max = 0
self.thresh = thresh
if patience < 1:
raise Exception("Argument patience should be a positive integer.")
def __call__(self, current_score):
if self.score is None:
self.score = current_score
return False
elif current_score > self.max:
self.counter = 0
self.score = current_score
self.max = current_score
return False
else:
if (abs(self.score - current_score) < self.thresh
or current_score < self.score):
self.counter += 1
self.score = current_score
logging.debug(
"EarlyStopping: %i / %i" % (self.counter, self.patience))
if self.counter >= self.patience:
logging.info("EarlyStopping: Stop training")
return True
return False
else:
self.counter = 0
self.score = current_score
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册