Unverified · Commit a1036d6e authored by Jason, committed by GitHub

Merge pull request #38 from FlyingQianMM/develop_qh

add resume training from a checkpoint
...@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)

#### Classifier training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)

> > - **eval_metric_loss** (float): tolerable loss of accuracy. Defaults to 0.05.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
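The interaction between `pretrain_weights` and `resume_checkpoint` introduced here can be sketched in isolation (a hedged illustration of the rule visible in this PR's `net_initialize` change, not the PaddleX implementation; the function name `choose_init` is made up):

```python
def choose_init(pretrain_weights, resume_checkpoint):
    """A non-None resume_checkpoint takes precedence: pretrained weights
    are then neither downloaded nor loaded. Otherwise fall back to
    pretrain_weights, and finally to random initialization."""
    if resume_checkpoint:
        return ("resume", resume_checkpoint)
    if pretrain_weights is not None:
        return ("pretrain", pretrain_weights)
    return ("random", None)

print(choose_init("IMAGENET", None))               # default behaviour
print(choose_init("IMAGENET", "output/epoch_10"))  # resuming wins
```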
#### Classifier evaluation API
...@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_

#### YOLOv3 training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_

> > - **eval_metric_loss** (float): tolerable loss of accuracy. Defaults to 0.05.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
#### YOLOv3 evaluation API
...@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec

#### FasterRCNN training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec

> > - **use_vdl** (bool): whether to use VisualDL for visualization. Defaults to False.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
#### FasterRCNN evaluation API
...@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_

#### MaskRCNN training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0/2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_

> > - **use_vdl** (bool): whether to use VisualDL for visualization. Defaults to False.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
#### MaskRCNN evaluation API
...@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride

#### DeepLabv3 training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride

> > - **eval_metric_loss** (float): tolerable loss of accuracy. Defaults to 0.05.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
#### DeepLabv3 evaluation API
...@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us

#### UNet training API

> ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ```
>
> **Parameters:**

...@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us

> > - **eval_metric_loss** (float): tolerable loss of accuracy. Defaults to 0.05.
> > - **early_stop** (bool): whether to use the early-stopping strategy. Defaults to False.
> > - **early_stop_patience** (int): when early stopping is enabled, training terminates if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Defaults to 5.
> > - **resume_checkpoint** (str): path of the model saved by the previous run, from which training resumes. If None, training is not resumed. Defaults to None.
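The `early_stop`/`early_stop_patience` pair described throughout these docs follows a standard patience scheme. A minimal stand-alone sketch of that scheme (an assumption about the behavior; the real `EarlyStop` class in PaddleX is not shown in this diff and may differ in detail):

```python
class PatienceStopper:
    """Assumed behavior: stop once the validation metric fails to improve
    (drops or stays flat) for `patience` consecutive evaluations."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        """Feed one validation metric; return True when training should stop."""
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = PatienceStopper(patience=2)
history = [0.70, 0.75, 0.74, 0.74]  # two non-improving epochs in a row
print([stopper.step(m) for m in history])  # stops only after the last epoch
```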
#### UNet evaluation API
......
...@@ -70,6 +70,8 @@ class BaseAPI:

        self.sync_bn = False
        # current model status
        self.status = 'Normal'
        # number of epochs already completed; used as the start epoch when resuming training
        self.completed_epochs = 0

    def _get_single_card_bs(self, batch_size):
        if batch_size % len(self.places) == 0:
...@@ -182,24 +184,39 @@ class BaseAPI:

                       fuse_bn=False,
                       save_dir='.',
                       sensitivities_file=None,
                       eval_metric_loss=0.05,
                       resume_checkpoint=None):
        if not resume_checkpoint:
            pretrain_dir = osp.join(save_dir, 'pretrain')
            if not os.path.isdir(pretrain_dir):
                if os.path.exists(pretrain_dir):
                    os.remove(pretrain_dir)
                os.makedirs(pretrain_dir)
            if hasattr(self, 'backbone'):
                backbone = self.backbone
            else:
                backbone = self.__class__.__name__
            pretrain_weights = get_pretrain_weights(
                pretrain_weights, self.model_type, backbone, pretrain_dir)
        if startup_prog is None:
            startup_prog = fluid.default_startup_program()
        self.exe.run(startup_prog)
        if resume_checkpoint:
            logging.info(
                "Resume checkpoint from {}.".format(resume_checkpoint),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, resume_checkpoint, resume=True)
            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
                raise Exception(
                    "There is no model.yml in {}".format(resume_checkpoint))
            with open(osp.join(resume_checkpoint, "model.yml")) as f:
                info = yaml.load(f.read(), Loader=yaml.Loader)
            self.completed_epochs = info['completed_epochs']
        elif pretrain_weights is not None:
            logging.info(
                "Load pretrain weights from {}.".format(pretrain_weights),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, pretrain_weights, fuse_bn)
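The resume branch above reads `completed_epochs` back from the checkpoint's `model.yml`. A stdlib-only sketch of that bookkeeping (the real code parses the file with PyYAML; the hand-rolled one-key parser and the name `read_completed_epochs` are simplifications for illustration):

```python
import os.path as osp
import tempfile

def read_completed_epochs(checkpoint_dir):
    """Require model.yml in the checkpoint and return its completed_epochs entry."""
    yml_path = osp.join(checkpoint_dir, "model.yml")
    if not osp.exists(yml_path):
        raise Exception("There is no model.yml in {}".format(checkpoint_dir))
    with open(yml_path) as f:
        for line in f:
            key, sep, value = line.partition(":")
            if sep and key.strip() == "completed_epochs":
                return int(value)  # int() tolerates surrounding whitespace
    return 0

ckpt = tempfile.mkdtemp()
with open(osp.join(ckpt, "model.yml"), "w") as f:
    f.write("completed_epochs: 12\n")
print(read_completed_epochs(ckpt))  # 12
```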
        # prune the model
...@@ -211,7 +228,8 @@ class BaseAPI:

            from .slim.prune import get_params_ratios, prune_program
            logging.info(
                "Start to prune program with eval_metric_loss = {}".format(
                    eval_metric_loss),
                use_color=True)
            origin_flops = paddleslim.analysis.flops(self.test_prog)
            prune_params_ratios = get_params_ratios(
                sensitivities_file, eval_metric_loss=eval_metric_loss)
...@@ -220,7 +238,8 @@ class BaseAPI:

            remaining_ratio = current_flops / origin_flops
            logging.info(
                "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
                .format(origin_flops, current_flops, remaining_ratio),
                use_color=True)
            self.status = 'Prune'
    def get_model_info(self):
...@@ -258,6 +277,7 @@ class BaseAPI:

                name = op.__class__.__name__
                attr = op.__dict__
                info['Transforms'].append({name: attr})
        info['completed_epochs'] = self.completed_epochs
        return info
    def save_model(self, save_dir):
...@@ -418,7 +438,8 @@ class BaseAPI:

        best_accuracy_key = ""
        best_accuracy = -1.0
        best_model_epoch = -1
        start_epoch = self.completed_epochs
        for i in range(start_epoch, num_epochs):
            records = list()
            step_start_time = time.time()
            epoch_start_time = time.time()
...@@ -498,6 +519,7 @@ class BaseAPI:

                    return_details=True)
                logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                    i + 1, dict2str(self.eval_metrics)))
            self.completed_epochs += 1
            # save the best model
            best_accuracy_key = list(self.eval_metrics.keys())[0]
            current_accuracy = self.eval_metrics[best_accuracy_key]
......
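The training-loop change above (start from `self.completed_epochs` instead of 0, and increment it after each evaluation) can be illustrated with a self-contained sketch (names here are illustrative, not the PaddleX internals):

```python
def train_epochs(num_epochs, completed_epochs=0):
    """Run the epoch loop; on resume, epochs before `completed_epochs` are skipped."""
    executed = []
    for i in range(completed_epochs, num_epochs):
        executed.append(i)     # one epoch of training + evaluation happens here
        completed_epochs += 1  # mirrors `self.completed_epochs += 1`
    return executed, completed_epochs

print(train_epochs(5))                      # fresh run: epochs 0..4
print(train_epochs(5, completed_epochs=3))  # resumed run: only epochs 3 and 4
```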
...@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI):

              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the model is loaded from an inference model.
...@@ -160,8 +162,8 @@ class BaseClassifier(BaseAPI):

            pretrain_weights=pretrain_weights,
            save_dir=save_dir,
            sensitivities_file=sensitivities_file,
            eval_metric_loss=eval_metric_loss,
            resume_checkpoint=resume_checkpoint)
        # training
        self.train_loop(
            num_epochs=num_epochs,
......
...@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI):

              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the model is loaded from an inference model.
...@@ -284,7 +286,8 @@ class DeepLabv3p(BaseAPI):

            pretrain_weights=pretrain_weights,
            save_dir=save_dir,
            sensitivities_file=sensitivities_file,
            eval_metric_loss=eval_metric_loss,
            resume_checkpoint=resume_checkpoint)
        # training
        self.train_loop(
            num_epochs=num_epochs,
...@@ -405,5 +408,6 @@ class DeepLabv3p(BaseAPI):

                w, h = info[1][1], info[1][0]
                pred = pred[0:h, 0:w]
            else:
                raise Exception("Unexpected info '{}' in im_info".format(
                    info[0]))
        return {'label_map': pred, 'score_map': result[1]}
...@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI):

              metric=None,
              use_vdl=False,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the evaluation metric is not in the supported list.
...@@ -231,7 +233,9 @@ class FasterRCNN(BaseAPI):

            startup_prog=fluid.default_startup_program(),
            pretrain_weights=pretrain_weights,
            fuse_bn=fuse_bn,
            save_dir=save_dir,
            resume_checkpoint=resume_checkpoint)
        # training
        self.train_loop(
            num_epochs=num_epochs,
......
...@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN):

              metric=None,
              use_vdl=False,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the evaluation metric is not in the supported list.
...@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN):

                metric = 'COCO'
            else:
                raise Exception(
                    "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
                )
        assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
        self.metric = metric
        if not self.trainable:
...@@ -197,7 +200,8 @@ class MaskRCNN(FasterRCNN):

            startup_prog=fluid.default_startup_program(),
            pretrain_weights=pretrain_weights,
            fuse_bn=fuse_bn,
            save_dir=save_dir,
            resume_checkpoint=resume_checkpoint)
        # training
        self.train_loop(
            num_epochs=num_epochs,
......
...@@ -121,7 +121,8 @@ class UNet(DeepLabv3p):

              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -145,14 +146,14 @@ class UNet(DeepLabv3p):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the model is loaded from an inference model.
        """
        return super(UNet, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, optimizer, learning_rate, lr_decay_power,
            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
            early_stop_patience, resume_checkpoint)
...@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI):

              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
...@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI):

            early_stop (bool): whether to use the early-stopping strategy. Defaults to False.
            early_stop_patience (int): when early stopping is enabled, training terminates if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Defaults to 5.
            resume_checkpoint (str): path of the model saved by the previous run, from which
                training resumes. If None, training is not resumed. Defaults to None.

        Raises:
            ValueError: the evaluation metric is not in the supported list.
...@@ -236,7 +238,8 @@ class YOLOv3(BaseAPI):

            pretrain_weights=pretrain_weights,
            save_dir=save_dir,
            sensitivities_file=sensitivities_file,
            eval_metric_loss=eval_metric_loss,
            resume_checkpoint=resume_checkpoint)
        # training
        self.train_loop(
            num_epochs=num_epochs,
......
...@@ -170,11 +170,85 @@ def load_pdparams(exe, main_prog, model_dir):

            len(vars_to_load), model_dir))
def is_persistable(var):
    import paddle.fluid as fluid
    from paddle.fluid.proto.framework_pb2 import VarType
    if var.desc.type() == fluid.core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == fluid.core.VarDesc.VarType.FETCH_LIST or \
            var.desc.type() == fluid.core.VarDesc.VarType.READER:
        return False
    return var.persistable


def is_belong_to_optimizer(var):
    import paddle.fluid as fluid
    from paddle.fluid.proto.framework_pb2 import VarType
    if not (isinstance(var, fluid.framework.Parameter)
            or var.desc.need_check_feed()):
        return is_persistable(var)
    return False
def load_pdopt(exe, main_prog, model_dir):
    import paddle.fluid as fluid
    optimizer_var_list = list()
    vars_to_load = list()
    import pickle
    with open(osp.join(model_dir, 'model.pdopt'), 'rb') as f:
        opt_dict = pickle.load(f) if six.PY2 else pickle.load(
            f, encoding='latin1')
    optimizer_var_list = list(
        filter(is_belong_to_optimizer, main_prog.list_vars()))
    exception_message = "the training process can not be resumed because the optimizer is set differently from last time. It is recommended to use `pretrain_weights` instead of `resume_checkpoint`"
    if len(optimizer_var_list) > 0:
        for var in optimizer_var_list:
            if var.name not in opt_dict:
                raise Exception(
                    "{} is not in the saved paddlex optimizer, {}".format(
                        var.name, exception_message))
            if var.shape != opt_dict[var.name].shape:
                raise Exception(
                    "Shape of optimizer variable {} doesn't match. (Last: {}, Now: {}), {}"
                    .format(var.name, opt_dict[var.name].shape, var.shape,
                            exception_message))
        optimizer_varname_list = [var.name for var in optimizer_var_list]
        for k, v in opt_dict.items():
            if k not in optimizer_varname_list:
                raise Exception(
                    "{} in the saved paddlex optimizer is not in the model, {}".
                    format(k, exception_message))
        fluid.io.set_program_state(main_prog, opt_dict)
    if len(optimizer_var_list) == 0:
        raise Exception(
            "There are no optimizer parameters in the model, please set the optimizer!"
        )
    else:
        logging.info(
            "{} optimizer parameters in {} were loaded.".format(
                len(optimizer_var_list), model_dir))
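`load_pdopt` refuses to resume unless the optimizer state in the snapshot and the optimizer of the current run match exactly: same variable names on both sides, same shapes. That consistency rule, extracted into a self-contained form (shapes modeled as plain tuples; the helper name `check_optimizer_state` is made up for this sketch):

```python
def check_optimizer_state(current_shapes, saved_shapes):
    """Raise unless both sides describe exactly the same optimizer variables.

    current_shapes / saved_shapes: dicts mapping variable name -> shape tuple.
    """
    hint = ("the training process can not be resumed; "
            "use `pretrain_weights` instead of `resume_checkpoint`")
    if not current_shapes:
        raise Exception("There are no optimizer parameters in the model!")
    for name, shape in current_shapes.items():
        if name not in saved_shapes:
            raise Exception("{} is not in the saved optimizer, {}".format(name, hint))
        if saved_shapes[name] != shape:
            raise Exception("Shape of {} doesn't match (Last: {}, Now: {}), {}"
                            .format(name, saved_shapes[name], shape, hint))
    for name in saved_shapes:
        if name not in current_shapes:
            raise Exception("{} in the saved optimizer is not in the model, {}"
                            .format(name, hint))

# Identical states pass silently; a shape change (e.g. a resized layer) is rejected.
state = {"fc_w_moment": (128, 10), "learning_rate": (1,)}
check_optimizer_state(state, dict(state))
```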
def load_pretrain_weights(exe,
                          main_prog,
                          weights_dir,
                          fuse_bn=False,
                          resume=False):
    if not osp.exists(weights_dir):
        raise Exception("Path {} does not exist.".format(weights_dir))
    if osp.exists(osp.join(weights_dir, "model.pdparams")):
        load_pdparams(exe, main_prog, weights_dir)
        if resume:
            if osp.exists(osp.join(weights_dir, "model.pdopt")):
                load_pdopt(exe, main_prog, weights_dir)
            else:
                raise Exception(
                    "Optimizer file {} does not exist. Stop resuming training. It is recommended to use `pretrain_weights` instead of `resume_checkpoint`"
                    .format(osp.join(weights_dir, "model.pdopt")))
        return
    import paddle.fluid as fluid
    vars_to_load = list()
    for var in main_prog.list_vars():
...@@ -209,6 +283,45 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):

        len(vars_to_load), weights_dir))
    if fuse_bn:
        fuse_bn_weights(exe, main_prog, weights_dir)
    if resume:
        exception_message = "the training process can not be resumed because the optimizer is set differently from last time. It is recommended to use `pretrain_weights` instead of `resume_checkpoint`"
        optimizer_var_list = list(
            filter(is_belong_to_optimizer, main_prog.list_vars()))
        if len(optimizer_var_list) > 0:
            for var in optimizer_var_list:
                if not osp.exists(osp.join(weights_dir, var.name)):
                    raise Exception(
                        "Optimizer parameter {} doesn't exist, {}".format(
                            osp.join(weights_dir, var.name),
                            exception_message))
                pretrained_shape = parse_param_file(
                    osp.join(weights_dir, var.name))
                actual_shape = tuple(var.shape)
                if pretrained_shape != actual_shape:
                    raise Exception(
                        "Shape of optimizer variable {} doesn't match. (Last: {}, Now: {}), {}"
                        .format(var.name, pretrained_shape, actual_shape,
                                exception_message))
            optimizer_varname_list = [var.name for var in optimizer_var_list]
            if osp.exists(osp.join(weights_dir, 'learning_rate')
                          ) and 'learning_rate' not in optimizer_varname_list:
                raise Exception(
                    "Optimizer parameter {}/learning_rate is not in the model, {}"
                    .format(weights_dir, exception_message))
            fluid.io.load_vars(
                executor=exe,
                dirname=weights_dir,
                main_program=main_prog,
                vars=optimizer_var_list)
        if len(optimizer_var_list) == 0:
            raise Exception(
                "There are no optimizer parameters in the model, please set the optimizer!"
            )
        else:
            logging.info(
                "{} optimizer parameters in {} were loaded.".format(
                    len(optimizer_var_list), weights_dir))
class EarlyStop:
......