From 31bf269ed7b8d14e0567ffa773fe9dd263250fb4 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Tue, 12 May 2020 16:13:10 +0800 Subject: [PATCH] add resume training from a checkpoint --- docs/apis/models.md | 18 ++++++++++++------ paddlex/cv/models/base.py | 14 +++++++++++++- paddlex/cv/models/classifier.py | 28 +++++++++++++++++++++------- paddlex/cv/models/deeplabv3p.py | 32 ++++++++++++++++++++++++-------- paddlex/cv/models/faster_rcnn.py | 27 +++++++++++++++++++++------ paddlex/cv/models/mask_rcnn.py | 30 +++++++++++++++++++++++------- paddlex/cv/models/unet.py | 17 +++++++++-------- paddlex/cv/models/yolo_v3.py | 29 ++++++++++++++++++++++------- 8 files changed, 145 insertions(+), 50 deletions(-) diff --git a/docs/apis/models.md b/docs/apis/models.md index e0a3547..a76dbbe 100644 --- a/docs/apis/models.md +++ b/docs/apis/models.md @@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000) #### 分类器训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5) +> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None) > ``` > > **参数:** @@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000) > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### 分类器评估函数接口 @@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ #### YOLOv3训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5) +> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None) > ``` > > **参数:** @@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### YOLOv3评估函数接口 @@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec #### FasterRCNN训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5) +> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None) > > ``` > @@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### FasterRCNN评估函数接口 @@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ #### MaskRCNN训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5) +> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None) > > ``` > @@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### MaskRCNN评估函数接口 @@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride #### DeepLabv3训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5): +> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None): > > ``` > @@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### DeepLabv3评估函数接口 @@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us #### Unet训练函数接口 > ```python -> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5): +> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None): > ``` > > **参数:** @@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 +> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 #### Unet评估函数接口 diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index 234193e..246c1e5 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -213,6 +213,17 @@ class BaseAPI: prune_program(self, prune_params_ratios) self.status = 'Prune' + def resume_checkpoint(self, path, startup_prog=None): + if not osp.isdir(path): + raise Exception("Model pretrain path {} does not " + "exists.".format(path)) + if osp.exists(osp.join(path, 'model.pdparams')): + path = osp.join(path, 'model') + if startup_prog is None: + startup_prog = fluid.default_startup_program() + self.exe.run(startup_prog) + fluid.load(self.train_prog, path, executor=self.exe) + def get_model_info(self): info = dict() info['version'] = paddlex.__version__ @@ -334,6 +345,7 @@ class BaseAPI: num_epochs, train_dataset, train_batch_size, + start_epoch=0, eval_dataset=None, save_interval_epochs=1, log_interval_steps=10, @@ -408,7 +420,7 @@ class BaseAPI: best_accuracy_key = "" best_accuracy = -1.0 best_model_epoch = 1 - for i in range(num_epochs): + for i in range(start_epoch, num_epochs): records = list() step_start_time = time.time() epoch_start_time = time.time() diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index 49a1f0f..aaf439b 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI): sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 模型从inference model进行加载。 @@ -155,15 +157,27 @@ class BaseClassifier(BaseAPI): # 构建训练、验证、预测网络 self.build_program() # 初始化网络权重 - self.net_initialize( - startup_prog=fluid.default_startup_program(), - pretrain_weights=pretrain_weights, - save_dir=save_dir, - sensitivities_file=sensitivities_file, - eval_metric_loss=eval_metric_loss) + if resume_checkpoint: + self.resume_checkpoint( + path=resume_checkpoint, + startup_prog=fluid.default_startup_program()) + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + start_epoch = step // num_steps_each_epoch + 1 + else: + self.net_initialize( + startup_prog=fluid.default_startup_program(), + pretrain_weights=pretrain_weights, + save_dir=save_dir, + sensitivities_file=sensitivities_file, + eval_metric_loss=eval_metric_loss) + start_epoch = 0 # 训练 self.train_loop( + start_epoch=start_epoch, num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py index ce523b5..f1ccf10 100644 --- a/paddlex/cv/models/deeplabv3p.py +++ b/paddlex/cv/models/deeplabv3p.py @@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI): sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 模型从inference model进行加载。 @@ -279,14 +281,27 @@ class DeepLabv3p(BaseAPI): # 构建训练、验证、预测网络 self.build_program() # 初始化网络权重 - self.net_initialize( - startup_prog=fluid.default_startup_program(), - pretrain_weights=pretrain_weights, - save_dir=save_dir, - sensitivities_file=sensitivities_file, - eval_metric_loss=eval_metric_loss) + if resume_checkpoint: + self.resume_checkpoint( + path=resume_checkpoint, + startup_prog=fluid.default_startup_program()) + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + start_epoch = step // num_steps_each_epoch + 1 + else: + self.net_initialize( + startup_prog=fluid.default_startup_program(), + pretrain_weights=pretrain_weights, + save_dir=save_dir, + sensitivities_file=sensitivities_file, + eval_metric_loss=eval_metric_loss) + start_epoch = 0 + # 训练 self.train_loop( + start_epoch=start_epoch, num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, @@ -405,5 +420,6 @@ class DeepLabv3p(BaseAPI): w, h = info[1][1], info[1][0] pred = pred[0:h, 0:w] else: - raise Exception("Unexpected info '{}' in im_info".format(info[0])) + raise Exception("Unexpected info '{}' in im_info".format( + info[0])) return {'label_map': pred, 'score_map': result[1]} diff --git a/paddlex/cv/models/faster_rcnn.py b/paddlex/cv/models/faster_rcnn.py index f5b7e94..1d97d84 100644 --- a/paddlex/cv/models/faster_rcnn.py +++ b/paddlex/cv/models/faster_rcnn.py @@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI): metric=None, use_vdl=False, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 评估类型不在指定列表中。 @@ -227,13 +229,26 @@ class FasterRCNN(BaseAPI): fuse_bn = True if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']: fuse_bn = False - self.net_initialize( - startup_prog=fluid.default_startup_program(), - pretrain_weights=pretrain_weights, - fuse_bn=fuse_bn, - save_dir=save_dir) + if resume_checkpoint: + self.resume_checkpoint( + path=resume_checkpoint, + startup_prog=fluid.default_startup_program()) + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + start_epoch = step // num_steps_each_epoch + 1 + else: + self.net_initialize( + startup_prog=fluid.default_startup_program(), + pretrain_weights=pretrain_weights, + fuse_bn=fuse_bn, + save_dir=save_dir) + start_epoch = 0 + # 训练 self.train_loop( + start_epoch=start_epoch, num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, diff --git a/paddlex/cv/models/mask_rcnn.py b/paddlex/cv/models/mask_rcnn.py index 77b2bd3..fd75df4 100644 --- a/paddlex/cv/models/mask_rcnn.py +++ b/paddlex/cv/models/mask_rcnn.py @@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN): metric=None, use_vdl=False, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 评估类型不在指定列表中。 @@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN): metric = 'COCO' else: raise Exception( - "train_dataset should be datasets.COCODetection or datasets.EasyDataDet.") + "train_dataset should be datasets.COCODetection or datasets.EasyDataDet." + ) assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" self.metric = metric if not self.trainable: @@ -193,13 +196,26 @@ class MaskRCNN(FasterRCNN): fuse_bn = True if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']: fuse_bn = False - self.net_initialize( - startup_prog=fluid.default_startup_program(), - pretrain_weights=pretrain_weights, - fuse_bn=fuse_bn, - save_dir=save_dir) + if resume_checkpoint: + self.resume_checkpoint( + path=resume_checkpoint, + startup_prog=fluid.default_startup_program()) + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + start_epoch = step // num_steps_each_epoch + 1 + else: + self.net_initialize( + startup_prog=fluid.default_startup_program(), + pretrain_weights=pretrain_weights, + fuse_bn=fuse_bn, + save_dir=save_dir) + start_epoch = 0 + # 训练 self.train_loop( + start_epoch=start_epoch, num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, diff --git a/paddlex/cv/models/unet.py b/paddlex/cv/models/unet.py index d7ce60c..d7bf80e 100644 --- a/paddlex/cv/models/unet.py +++ b/paddlex/cv/models/unet.py @@ -121,7 +121,8 @@ class UNet(DeepLabv3p): sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -145,14 +146,14 @@ class UNet(DeepLabv3p): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 模型从inference model进行加载。 """ - return super( - UNet, - self).train(num_epochs, train_dataset, train_batch_size, - eval_dataset, save_interval_epochs, log_interval_steps, - save_dir, pretrain_weights, optimizer, learning_rate, - lr_decay_power, use_vdl, sensitivities_file, - eval_metric_loss, early_stop, early_stop_patience) + return super(UNet, self).train( + num_epochs, train_dataset, train_batch_size, eval_dataset, + save_interval_epochs, log_interval_steps, save_dir, + pretrain_weights, optimizer, learning_rate, lr_decay_power, + use_vdl, sensitivities_file, eval_metric_loss, early_stop, + early_stop_patience, resume_checkpoint) diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py index a205554..3ea1317 100644 --- a/paddlex/cv/models/yolo_v3.py +++ b/paddlex/cv/models/yolo_v3.py @@ -166,7 +166,8 @@ class YOLOv3(BaseAPI): sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, - early_stop_patience=5): + early_stop_patience=5, + resume_checkpoint=None): """训练。 Args: @@ -195,6 +196,7 @@ class YOLOv3(BaseAPI): early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 + resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 Raises: ValueError: 评估类型不在指定列表中。 @@ -231,14 +233,27 @@ class YOLOv3(BaseAPI): # 构建训练、验证、预测网络 self.build_program() # 初始化网络权重 - self.net_initialize( - startup_prog=fluid.default_startup_program(), - pretrain_weights=pretrain_weights, - save_dir=save_dir, - sensitivities_file=sensitivities_file, - eval_metric_loss=eval_metric_loss) + if resume_checkpoint: + self.resume_checkpoint( + path=resume_checkpoint, + startup_prog=fluid.default_startup_program()) + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + num_steps_each_epoch = train_dataset.num_samples // train_batch_size + start_epoch = step // num_steps_each_epoch + 1 + else: + self.net_initialize( + startup_prog=fluid.default_startup_program(), + pretrain_weights=pretrain_weights, + save_dir=save_dir, + sensitivities_file=sensitivities_file, + eval_metric_loss=eval_metric_loss) + start_epoch = 0 + # 训练 self.train_loop( + start_epoch=start_epoch, num_epochs=num_epochs, train_dataset=train_dataset, train_batch_size=train_batch_size, -- GitLab