From 03b66704383fe863e80887f1807ab1638f1baac9 Mon Sep 17 00:00:00 2001
From: zhangxuefei
Date: Tue, 30 Apr 2019 18:29:55 +0800
Subject: [PATCH] Fix missing definition of the linear_warmup_decay function

---
 paddlehub/finetune/optimization.py | 27 ++++++++++++++++++++++++++-
 paddlehub/finetune/strategy.py     |  9 ++++++++-
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/paddlehub/finetune/optimization.py b/paddlehub/finetune/optimization.py
index d438660c..5b7363c0 100644
--- a/paddlehub/finetune/optimization.py
+++ b/paddlehub/finetune/optimization.py
@@ -19,6 +19,8 @@ from __future__ import print_function
 
 import numpy as np
 import paddle.fluid as fluid
+import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
+from paddle.fluid.layers import control_flow
 
 
 def adam_weight_decay_optimization(loss,
@@ -35,7 +37,7 @@ def adam_weight_decay_optimization(loss,
                                               warmup_steps)
     elif scheduler == 'linear_decay':
         scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
-                                           num_train_steps)
+                                           main_program)
     else:
         raise ValueError("Unkown learning rate scheduler, should be "
                          "'noam_decay' or 'linear_decay'")
@@ -76,3 +78,26 @@ def adam_weight_decay_optimization(loss,
             fluid.layers.assign(output=param, input=updated_param)
 
     return scheduled_lr
+
+
+def linear_warmup_decay(init_lr, num_warmup_steps, main_program):
+    with main_program._lr_schedule_guard():
+        global_step = lr_scheduler._decay_step_counter()
+
+        lr = fluid.layers.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")
+
+        with control_flow.Switch() as switch:
+            with switch.case(global_step < num_warmup_steps):
+                decayed_lr = init_lr * global_step * 1.0 / num_warmup_steps
+                fluid.layers.assign(decayed_lr, lr)
+            with switch.default():
+                last_value_var = fluid.layers.fill_constant(
+                    shape=[1], dtype='float32', value=float(init_lr))
+                fluid.layers.assign(last_value_var, lr)
+
+        return lr
diff --git a/paddlehub/finetune/strategy.py b/paddlehub/finetune/strategy.py
index 22fdb61b..ae5e88f1 100644
--- a/paddlehub/finetune/strategy.py
+++ b/paddlehub/finetune/strategy.py
@@ -89,7 +89,7 @@ class AdamWeightDecayStrategy(DefaultStrategy):
     def __init__(self,
                  learning_rate=1e-4,
                  lr_scheduler="linear_decay",
-                 warmup_proportion=0.0,
+                 warmup_proportion=0.1,
                  weight_decay=0.01,
                  optimizer_name="adam"):
         super(AdamWeightDecayStrategy, self).__init__(
@@ -118,6 +118,13 @@ class AdamWeightDecayStrategy(DefaultStrategy):
         # calculate wamrup step
         dev_count = self._get_dev_count(config)
         num_train_examples = data_reader.get_num_examples(phase='train')
+        data_reader.data_generator(
+            batch_size=config.batch_size, phase='train', shuffle=True)
+        data_reader.data_generator(
+            batch_size=config.batch_size, phase='val', shuffle=False)
+        data_reader.data_generator(
+            batch_size=config.batch_size, phase='dev', shuffle=False)
+        num_train_examples = data_reader.get_num_examples(phase='train')
         max_train_steps = config.num_epoch * num_train_examples // config.batch_size // dev_count
         warmup_steps = int(max_train_steps * self.warmup_proportion)
 
-- 
GitLab
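
Editor's note: for reference, here is a minimal, framework-free sketch of the schedule that the new linear_warmup_decay graph builds (a linear ramp over the warmup steps, then a constant init_lr). The helper name and the step/lr values below are illustrative only and are not part of the patch.

    def linear_warmup_lr(init_lr, step, num_warmup_steps):
        # Mirrors the two Switch branches above: ramp linearly while
        # warming up, then hold the base learning rate constant.
        if step < num_warmup_steps:
            return init_lr * step * 1.0 / num_warmup_steps
        return init_lr

    # Illustrative values (not from the patch):
    for step in (0, 50, 100, 200):
        print(step, linear_warmup_lr(1e-4, step, num_warmup_steps=100))
    # 0 0.0 / 50 5e-05 / 100 0.0001 / 200 0.0001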
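
Editor's note: the strategy.py change also raises the default warmup_proportion from 0.0 to 0.1. A quick back-of-the-envelope check of the resulting warmup_steps, using hypothetical training sizes (the numbers are made up to show the arithmetic, they do not come from the patch):

    # Hypothetical sizes, chosen only to illustrate the warmup-step
    # calculation in AdamWeightDecayStrategy.
    num_epoch, num_train_examples, batch_size, dev_count = 3, 10000, 32, 1
    warmup_proportion = 0.1  # new default in this patch

    max_train_steps = num_epoch * num_train_examples // batch_size // dev_count
    warmup_steps = int(max_train_steps * warmup_proportion)
    print(max_train_steps, warmup_steps)  # 937 93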