提交 31bf269e 编写于 作者: F FlyingQianMM

add resume training from a checkpoint

上级 1af3e044
...@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000) ...@@ -17,7 +17,7 @@ paddlex.cls.ResNet50(num_classes=1000)
#### 分类器训练函数接口 #### 分类器训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5) > train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.025, lr_decay_epochs=[30, 60, 90], lr_decay_gamma=0.1, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000) ...@@ -39,6 +39,7 @@ paddlex.cls.ResNet50(num_classes=1000)
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### 分类器评估函数接口 #### 分类器评估函数接口
...@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ ...@@ -111,7 +112,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
#### YOLOv3训练函数接口 #### YOLOv3训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5) > train(self, num_epochs, train_dataset, train_batch_size=8, eval_dataset=None, save_interval_epochs=20, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/8000, warmup_steps=1000, warmup_start_lr=0.0, lr_decay_epochs=[213, 240], lr_decay_gamma=0.1, metric=None, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_ ...@@ -136,6 +137,7 @@ paddlex.det.YOLOv3(num_classes=80, backbone='MobileNetV1', anchors=None, anchor_
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### YOLOv3评估函数接口 #### YOLOv3评估函数接口
...@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec ...@@ -190,7 +192,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
#### FasterRCNN训练函数接口 #### FasterRCNN训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5) > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, save_interval_epochs=1, log_interval_steps=2,save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.0025, warmup_steps=500, warmup_start_lr=1.0/1200, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> >
> ``` > ```
> >
...@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec ...@@ -214,6 +216,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### FasterRCNN评估函数接口 #### FasterRCNN评估函数接口
...@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ ...@@ -270,7 +273,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
#### MaskRCNN训练函数接口 #### MaskRCNN训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5) > train(self, num_epochs, train_dataset, train_batch_size=1, eval_dataset=None, save_interval_epochs=1, log_interval_steps=20, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=1.0/800, warmup_steps=500, warmup_start_lr=1.0 / 2400, lr_decay_epochs=[8, 11], lr_decay_gamma=0.1, metric=None, use_vdl=False, early_stop=False, early_stop_patience=5, resume_checkpoint=None)
> >
> ``` > ```
> >
...@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_ ...@@ -294,6 +297,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
> > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。 > > - **use_vdl** (bool): 是否使用VisualDL进行可视化。默认值为False。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### MaskRCNN评估函数接口 #### MaskRCNN评估函数接口
...@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -358,7 +362,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
#### DeepLabv3训练函数接口 #### DeepLabv3训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5): > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='IMAGENET', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
> >
> ``` > ```
> >
...@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -380,6 +384,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### DeepLabv3评估函数接口 #### DeepLabv3评估函数接口
...@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us ...@@ -437,7 +442,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
#### Unet训练函数接口 #### Unet训练函数接口
> ```python > ```python
> train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5): > train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='COCO', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
> ``` > ```
> >
> **参数:** > **参数:**
...@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us ...@@ -458,6 +463,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。 > > - **eval_metric_loss** (float): 可容忍的精度损失。默认为0.05。
> > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。 > > - **early_stop** (float): 是否使用提前终止训练策略。默认值为False。
> > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。 > > - **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认值为5。
> > - **resume_checkpoint** (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
#### Unet评估函数接口 #### Unet评估函数接口
......
...@@ -213,6 +213,17 @@ class BaseAPI: ...@@ -213,6 +213,17 @@ class BaseAPI:
prune_program(self, prune_params_ratios) prune_program(self, prune_params_ratios)
self.status = 'Prune' self.status = 'Prune'
def resume_checkpoint(self, path, startup_prog=None):
if not osp.isdir(path):
raise Exception("Model pretrain path {} does not "
"exists.".format(path))
if osp.exists(osp.join(path, 'model.pdparams')):
path = osp.join(path, 'model')
if startup_prog is None:
startup_prog = fluid.default_startup_program()
self.exe.run(startup_prog)
fluid.load(self.train_prog, path, executor=self.exe)
def get_model_info(self): def get_model_info(self):
info = dict() info = dict()
info['version'] = paddlex.__version__ info['version'] = paddlex.__version__
...@@ -334,6 +345,7 @@ class BaseAPI: ...@@ -334,6 +345,7 @@ class BaseAPI:
num_epochs, num_epochs,
train_dataset, train_dataset,
train_batch_size, train_batch_size,
start_epoch=0,
eval_dataset=None, eval_dataset=None,
save_interval_epochs=1, save_interval_epochs=1,
log_interval_steps=10, log_interval_steps=10,
...@@ -408,7 +420,7 @@ class BaseAPI: ...@@ -408,7 +420,7 @@ class BaseAPI:
best_accuracy_key = "" best_accuracy_key = ""
best_accuracy = -1.0 best_accuracy = -1.0
best_model_epoch = 1 best_model_epoch = 1
for i in range(num_epochs): for i in range(start_epoch, num_epochs):
records = list() records = list()
step_start_time = time.time() step_start_time = time.time()
epoch_start_time = time.time() epoch_start_time = time.time()
......
...@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI): ...@@ -112,7 +112,8 @@ class BaseClassifier(BaseAPI):
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05, eval_metric_loss=0.05,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI): ...@@ -137,6 +138,7 @@ class BaseClassifier(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
...@@ -155,15 +157,27 @@ class BaseClassifier(BaseAPI): ...@@ -155,15 +157,27 @@ class BaseClassifier(BaseAPI):
# 构建训练、验证、预测网络 # 构建训练、验证、预测网络
self.build_program() self.build_program()
# 初始化网络权重 # 初始化网络权重
self.net_initialize( if resume_checkpoint:
startup_prog=fluid.default_startup_program(), self.resume_checkpoint(
pretrain_weights=pretrain_weights, path=resume_checkpoint,
save_dir=save_dir, startup_prog=fluid.default_startup_program())
sensitivities_file=sensitivities_file, scope = fluid.global_scope()
eval_metric_loss=eval_metric_loss) v = scope.find_var('@LR_DECAY_COUNTER@')
step = np.array(v.get_tensor())[0] if v else 0
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
start_epoch = step // num_steps_each_epoch + 1
else:
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
save_dir=save_dir,
sensitivities_file=sensitivities_file,
eval_metric_loss=eval_metric_loss)
start_epoch = 0
# 训练 # 训练
self.train_loop( self.train_loop(
start_epoch=start_epoch,
num_epochs=num_epochs, num_epochs=num_epochs,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
......
...@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI): ...@@ -234,7 +234,8 @@ class DeepLabv3p(BaseAPI):
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05, eval_metric_loss=0.05,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI): ...@@ -258,6 +259,7 @@ class DeepLabv3p(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
...@@ -279,14 +281,27 @@ class DeepLabv3p(BaseAPI): ...@@ -279,14 +281,27 @@ class DeepLabv3p(BaseAPI):
# 构建训练、验证、预测网络 # 构建训练、验证、预测网络
self.build_program() self.build_program()
# 初始化网络权重 # 初始化网络权重
self.net_initialize( if resume_checkpoint:
startup_prog=fluid.default_startup_program(), self.resume_checkpoint(
pretrain_weights=pretrain_weights, path=resume_checkpoint,
save_dir=save_dir, startup_prog=fluid.default_startup_program())
sensitivities_file=sensitivities_file, scope = fluid.global_scope()
eval_metric_loss=eval_metric_loss) v = scope.find_var('@LR_DECAY_COUNTER@')
step = np.array(v.get_tensor())[0] if v else 0
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
start_epoch = step // num_steps_each_epoch + 1
else:
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
save_dir=save_dir,
sensitivities_file=sensitivities_file,
eval_metric_loss=eval_metric_loss)
start_epoch = 0
# 训练 # 训练
self.train_loop( self.train_loop(
start_epoch=start_epoch,
num_epochs=num_epochs, num_epochs=num_epochs,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
...@@ -405,5 +420,6 @@ class DeepLabv3p(BaseAPI): ...@@ -405,5 +420,6 @@ class DeepLabv3p(BaseAPI):
w, h = info[1][1], info[1][0] w, h = info[1][1], info[1][0]
pred = pred[0:h, 0:w] pred = pred[0:h, 0:w]
else: else:
raise Exception("Unexpected info '{}' in im_info".format(info[0])) raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
return {'label_map': pred, 'score_map': result[1]} return {'label_map': pred, 'score_map': result[1]}
...@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI): ...@@ -167,7 +167,8 @@ class FasterRCNN(BaseAPI):
metric=None, metric=None,
use_vdl=False, use_vdl=False,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI): ...@@ -193,6 +194,7 @@ class FasterRCNN(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -227,13 +229,26 @@ class FasterRCNN(BaseAPI): ...@@ -227,13 +229,26 @@ class FasterRCNN(BaseAPI):
fuse_bn = True fuse_bn = True
if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']: if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']:
fuse_bn = False fuse_bn = False
self.net_initialize( if resume_checkpoint:
startup_prog=fluid.default_startup_program(), self.resume_checkpoint(
pretrain_weights=pretrain_weights, path=resume_checkpoint,
fuse_bn=fuse_bn, startup_prog=fluid.default_startup_program())
save_dir=save_dir) scope = fluid.global_scope()
v = scope.find_var('@LR_DECAY_COUNTER@')
step = np.array(v.get_tensor())[0] if v else 0
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
start_epoch = step // num_steps_each_epoch + 1
else:
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
fuse_bn=fuse_bn,
save_dir=save_dir)
start_epoch = 0
# 训练 # 训练
self.train_loop( self.train_loop(
start_epoch=start_epoch,
num_epochs=num_epochs, num_epochs=num_epochs,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
......
...@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN): ...@@ -132,7 +132,8 @@ class MaskRCNN(FasterRCNN):
metric=None, metric=None,
use_vdl=False, use_vdl=False,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN): ...@@ -158,6 +159,7 @@ class MaskRCNN(FasterRCNN):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN): ...@@ -169,7 +171,8 @@ class MaskRCNN(FasterRCNN):
metric = 'COCO' metric = 'COCO'
else: else:
raise Exception( raise Exception(
"train_dataset should be datasets.COCODetection or datasets.EasyDataDet.") "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
)
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric self.metric = metric
if not self.trainable: if not self.trainable:
...@@ -193,13 +196,26 @@ class MaskRCNN(FasterRCNN): ...@@ -193,13 +196,26 @@ class MaskRCNN(FasterRCNN):
fuse_bn = True fuse_bn = True
if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']: if self.with_fpn and self.backbone in ['ResNet18', 'ResNet50']:
fuse_bn = False fuse_bn = False
self.net_initialize( if resume_checkpoint:
startup_prog=fluid.default_startup_program(), self.resume_checkpoint(
pretrain_weights=pretrain_weights, path=resume_checkpoint,
fuse_bn=fuse_bn, startup_prog=fluid.default_startup_program())
save_dir=save_dir) scope = fluid.global_scope()
v = scope.find_var('@LR_DECAY_COUNTER@')
step = np.array(v.get_tensor())[0] if v else 0
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
start_epoch = step // num_steps_each_epoch + 1
else:
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
fuse_bn=fuse_bn,
save_dir=save_dir)
start_epoch = 0
# 训练 # 训练
self.train_loop( self.train_loop(
start_epoch=start_epoch,
num_epochs=num_epochs, num_epochs=num_epochs,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
......
...@@ -121,7 +121,8 @@ class UNet(DeepLabv3p): ...@@ -121,7 +121,8 @@ class UNet(DeepLabv3p):
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05, eval_metric_loss=0.05,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -145,14 +146,14 @@ class UNet(DeepLabv3p): ...@@ -145,14 +146,14 @@ class UNet(DeepLabv3p):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
return super( return super(UNet, self).train(
UNet, num_epochs, train_dataset, train_batch_size, eval_dataset,
self).train(num_epochs, train_dataset, train_batch_size, save_interval_epochs, log_interval_steps, save_dir,
eval_dataset, save_interval_epochs, log_interval_steps, pretrain_weights, optimizer, learning_rate, lr_decay_power,
save_dir, pretrain_weights, optimizer, learning_rate, use_vdl, sensitivities_file, eval_metric_loss, early_stop,
lr_decay_power, use_vdl, sensitivities_file, early_stop_patience, resume_checkpoint)
eval_metric_loss, early_stop, early_stop_patience)
...@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI): ...@@ -166,7 +166,8 @@ class YOLOv3(BaseAPI):
sensitivities_file=None, sensitivities_file=None,
eval_metric_loss=0.05, eval_metric_loss=0.05,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5,
resume_checkpoint=None):
"""训练。 """训练。
Args: Args:
...@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI): ...@@ -195,6 +196,7 @@ class YOLOv3(BaseAPI):
early_stop (bool): 是否使用提前终止训练策略。默认值为False。 early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。 连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises: Raises:
ValueError: 评估类型不在指定列表中。 ValueError: 评估类型不在指定列表中。
...@@ -231,14 +233,27 @@ class YOLOv3(BaseAPI): ...@@ -231,14 +233,27 @@ class YOLOv3(BaseAPI):
# 构建训练、验证、预测网络 # 构建训练、验证、预测网络
self.build_program() self.build_program()
# 初始化网络权重 # 初始化网络权重
self.net_initialize( if resume_checkpoint:
startup_prog=fluid.default_startup_program(), self.resume_checkpoint(
pretrain_weights=pretrain_weights, path=resume_checkpoint,
save_dir=save_dir, startup_prog=fluid.default_startup_program())
sensitivities_file=sensitivities_file, scope = fluid.global_scope()
eval_metric_loss=eval_metric_loss) v = scope.find_var('@LR_DECAY_COUNTER@')
step = np.array(v.get_tensor())[0] if v else 0
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
start_epoch = step // num_steps_each_epoch + 1
else:
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
save_dir=save_dir,
sensitivities_file=sensitivities_file,
eval_metric_loss=eval_metric_loss)
start_epoch = 0
# 训练 # 训练
self.train_loop( self.train_loop(
start_epoch=start_epoch,
num_epochs=num_epochs, num_epochs=num_epochs,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册