From aefe18c7a3df123940263e8558be48c648d7edb7 Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Tue, 2 Jun 2020 05:50:00 +0000
Subject: [PATCH] add warmup for classifier

---
 docs/appendix/parameters.md     | 40 ++++++++++++++++++++++++++++++++
 paddlex/cv/models/classifier.py | 36 +++++++++++++++++++++---------
 paddlex/utils/logging.py        | 15 ++++++-------
 3 files changed, 75 insertions(+), 16 deletions(-)

diff --git a/docs/appendix/parameters.md b/docs/appendix/parameters.md
index 732535d..2e83468 100644
--- a/docs/appendix/parameters.md
+++ b/docs/appendix/parameters.md
@@ -23,3 +23,43 @@ Batch Size refers to the number of samples the model processes at a time during training. If
 - [Instance Segmentation MaskRCNN-train](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/instance_segmentation.html#train)
 - [Semantic Segmentation DeepLabv3p-train](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#train)
 - [Semantic Segmentation UNet](https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#id2)
+
+## Notes on lr_decay_epochs, warmup_steps, and related parameters
+
+When training models with PaddleX (or other deep learning frameworks), you will often see parameter settings such as lr_decay_epochs, warmup_steps, and warmup_start_lr. This section explains what these parameters do.
+
+All of these parameters control how the learning rate changes during training. For example, if learning_rate is set to 0.1, the model would normally train with a constant learning rate of 0.1 from start to finish. To get a better model, however, we usually do not want the learning rate to stay constant.
+
+### warmup_steps and warmup_start_lr
+
+Training usually starts from pretrained weights; for example, a detection model initializes its backbone with weights pretrained on ImageNet. Because your own data may differ substantially from ImageNet, overly large gradients at the beginning can destabilize training, so it helps to start with a small learning rate that slowly grows to the configured value. This is exactly what `warmup_steps` and `warmup_start_lr` do: when training begins, the learning rate starts at `warmup_start_lr` and increases linearly to the configured learning rate within `warmup_steps` steps.
+
+### lr_decay_epochs and lr_decay_gamma
+
+`lr_decay_epochs` makes the learning rate decay step by step in the later stages of training. It is usually a list, such as [6, 8, 10], meaning the learning rate decays once at epoch 6, again at epoch 8, and once more at epoch 10. Each decay sets the learning rate to the previous learning rate multiplied by lr_decay_gamma.
+
+### Constraints on warmup_steps and lr_decay_epochs in PaddleX
+
+PaddleX requires warmup to finish before the first learning rate decay, so the following condition must hold:
+```
+warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch
+```
+where `num_steps_each_epoch = num_samples_in_train_dataset // train_batch_size`.
+
+> Therefore, if PaddleX reports `warmup_steps should be less than xxx` during training, adjust `lr_decay_epochs` or `warmup_steps` according to the formula above so that both parameters satisfy the condition.
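+
+As a worked example, here is a minimal sketch of checking this constraint before training (the dataset size, batch size, and warmup_steps values below are made-up numbers for illustration):
+
+```
+# Hypothetical values, for illustration only
+num_samples_in_train_dataset = 1200
+train_batch_size = 32
+lr_decay_epochs = [30, 60, 90]
+warmup_steps = 500
+
+# Steps per epoch, computed the same way PaddleX computes it (integer division)
+num_steps_each_epoch = num_samples_in_train_dataset // train_batch_size  # 37
+
+# Warmup must finish before the first decay boundary: 30 * 37 = 1110
+assert warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch
+```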
diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py
index ab746dd..1d3752b 100644
--- a/paddlex/cv/models/classifier.py
+++ b/paddlex/cv/models/classifier.py
@@ -52,8 +52,7 @@ class BaseClassifier(BaseAPI):
             input_shape = [
                 None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
             ]
-            image = fluid.data(
-                dtype='float32', shape=input_shape, name='image')
+            image = fluid.data(dtype='float32', shape=input_shape, name='image')
         else:
             image = fluid.data(
                 dtype='float32', shape=[None, 3, None, None], name='image')
@@ -81,7 +80,8 @@
         del outputs['loss']
         return inputs, outputs
 
-    def default_optimizer(self, learning_rate, lr_decay_epochs, lr_decay_gamma,
+    def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
+                          lr_decay_epochs, lr_decay_gamma,
                           num_steps_each_epoch):
         boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
         values = [
@@ -90,6 +90,22 @@
         ]
         lr_decay = fluid.layers.piecewise_decay(
             boundaries=boundaries, values=values)
+        if warmup_steps > 0:
+            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
+                logging.error(
+                    "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_steps_each_epoch",
+                    exit=False)
+                logging.error(
+                    "See this doc for more information: xxxx", exit=False)
+                logging.error(
+                    "warmup_steps should be less than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in the train function".
+                    format(lr_decay_epochs[0] * num_steps_each_epoch))
+
+            lr_decay = fluid.layers.linear_lr_warmup(
+                learning_rate=lr_decay,
+                warmup_steps=warmup_steps,
+                start_lr=warmup_start_lr,
+                end_lr=learning_rate)
         optimizer = fluid.optimizer.Momentum(
             lr_decay,
             momentum=0.9,
@@ -107,6 +123,8 @@
                     pretrain_weights='IMAGENET',
                     optimizer=None,
                     learning_rate=0.025,
+                    warmup_steps=0,
+                    warmup_start_lr=0.0,
                     lr_decay_epochs=[30, 60, 90],
                     lr_decay_gamma=0.1,
                     use_vdl=False,
@@ -129,6 +147,8 @@
             optimizer (paddle.fluid.optimizer): Optimizer. When None, the default optimizer is used:
                 the fluid.layers.piecewise_decay strategy with the fluid.optimizer.Momentum method.
             learning_rate (float): Initial learning rate of the default optimizer. Defaults to 0.025.
+            warmup_steps (int): Number of steps over which the learning rate rises from warmup_start_lr to the configured learning_rate. Defaults to 0.
+            warmup_start_lr (float): Starting learning rate of the warmup phase. Defaults to 0.0.
             lr_decay_epochs (list): Epochs at which the default optimizer decays the learning rate. Defaults to [30, 60, 90].
             lr_decay_gamma (float): Learning rate decay factor of the default optimizer. Defaults to 0.1.
             use_vdl (bool): Whether to use VisualDL for visualization. Defaults to False.
@@ -149,6 +169,8 @@
         num_steps_each_epoch = train_dataset.num_samples // train_batch_size
         optimizer = self.default_optimizer(
             learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            warmup_start_lr=warmup_start_lr,
             lr_decay_epochs=lr_decay_epochs,
             lr_decay_gamma=lr_decay_gamma,
             num_steps_each_epoch=num_steps_each_epoch)
@@ -193,8 +215,7 @@
             tuple (metrics, eval_details): When return_details is True, a dict is additionally returned,
                 containing keys 'true_labels' and 'pred_scores' for the ground-truth class ids and per-class prediction scores.
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         k = min(5, self.num_classes)
@@ -206,9 +227,8 @@
                 self.test_prog).with_data_parallel(
                     share_vars_from=self.parallel_train_prog)
         batch_size_each_gpu = self._get_single_card_bs(batch_size)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
+        logging.info("Start evaluating (total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data]).astype('float32')
diff --git a/paddlex/utils/logging.py b/paddlex/utils/logging.py
index c118a28..adfcea5 100644
--- a/paddlex/utils/logging.py
+++ b/paddlex/utils/logging.py
@@ -29,13 +29,11 @@ def log(level=2, message="", use_color=False):
     current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
     if paddlex.log_level >= level:
         if use_color:
-            print("\033[1;31;40m{} [{}]\t{}\033[0m".format(
-                current_time, levels[level],
-                message).encode("utf-8").decode("latin1"))
+            print("\033[1;31;40m{} [{}]\t{}\033[0m".format(current_time, levels[
+                level], message).encode("utf-8").decode("latin1"))
         else:
-            print(
-                "{} [{}]\t{}".format(current_time, levels[level],
-                                     message).encode("utf-8").decode("latin1"))
+            print("{} [{}]\t{}".format(current_time, levels[level], message)
+                  .encode("utf-8").decode("latin1"))
         sys.stdout.flush()


@@ -51,6 +49,7 @@ def warning(message="", use_color=True):
    log(level=1, message=message, use_color=use_color)


-def error(message="", use_color=True):
+def error(message="", use_color=True, exit=True):
     log(level=0, message=message, use_color=use_color)
-    sys.exit(-1)
+    if exit:
+        sys.exit(-1)
--
GitLab
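
For reference, a minimal usage sketch of the two new arguments. Everything except `warmup_steps` and `warmup_start_lr` follows the standard PaddleX 1.x classification workflow; the dataset paths, model choice, and numeric values below are hypothetical:

```
import paddlex as pdx
from paddlex.cls import transforms

# Hypothetical transforms and dataset; paths are placeholders
train_transforms = transforms.Compose([
    transforms.RandomCrop(crop_size=224),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])
train_dataset = pdx.datasets.ImageNet(
    data_dir='my_dataset',
    file_list='my_dataset/train_list.txt',
    label_list='my_dataset/labels.txt',
    transforms=train_transforms)

model = pdx.cls.ResNet50(num_classes=len(train_dataset.labels))
model.train(
    num_epochs=120,
    train_dataset=train_dataset,
    train_batch_size=32,
    learning_rate=0.025,
    warmup_steps=100,        # new in this patch: linear warmup over the first 100 steps
    warmup_start_lr=0.0025,  # new in this patch: learning rate at the start of warmup
    lr_decay_epochs=[30, 60, 90],
    lr_decay_gamma=0.1,
    save_dir='output/resnet50')
```

With these settings, the schedule rises from 0.0025 to 0.025 over the first 100 steps, then decays by a factor of 0.1 at epochs 30, 60, and 90.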