Unverified commit 407de039, authored by Zhou Wei, committed by GitHub

[2.0API] Reconstruct all API related to LR Scheduler, unify dygraph and static (#26550)

* Reconstruct all API related to lr scheduler, unify dygraph and static

* Reconstruct all API related to lr scheduler, unify dygraph and static

* fix doc

* fix doc

* fix doc of lr_scheduler

* fix unittest and english doc

* fix english doc

* fix conflict

* fix doc
Parent commit: 6e823cfe
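A minimal sketch of how the unified 2.0 API is intended to be used (adapted from the docstrings added in this PR; the same scheduler object drives both dygraph and static-graph training, with scheduler.step() called once per epoch):

    import paddle
    import numpy as np

    # dygraph mode: pass the scheduler to the optimizer, call step() each epoch
    paddle.disable_static()
    x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
    linear = paddle.nn.Linear(10, 10)
    scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100)
    sgd = paddle.optimizer.SGD(learning_rate=scheduler,
                               parameter_list=linear.parameters())
    for epoch in range(20):
        for batch_id in range(2):
            out = linear(paddle.to_tensor(x))
            loss = paddle.reduce_mean(out)
            loss.backward()
            sgd.minimize(loss)
            linear.clear_gradients()
        scheduler.step()

    # static mode: the same scheduler fills the persistable learning-rate
    # variable that the executor feeds before every run (see the executor
    # changes below)
    paddle.enable_static()
    main_prog = paddle.static.Program()
    start_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, start_prog):
        data = paddle.static.data(name='x', shape=[-1, 4, 5])
        z = paddle.static.nn.fc(data, 100)
        loss = paddle.mean(z)
        scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100)
        sgd = paddle.optimizer.SGD(learning_rate=scheduler)
        sgd.minimize(loss)

    exe = paddle.static.Executor()
    exe.run(start_prog)
    for epoch in range(20):
        for batch_id in range(2):
            exe.run(main_prog,
                    feed={'x': np.random.randn(3, 4, 5).astype('float32')})
        scheduler.step()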
@@ -850,6 +850,7 @@ class Executor(object):

    def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
                      return_numpy, return_merged):
        from paddle.optimizer.lr_scheduler import _LRScheduler
        exe = program._executor
        # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram
        # instead of program. We will add support for checking Vars in Graph
@@ -893,6 +894,16 @@ class Executor(object):
                res.append(res_dict)
            exe.feed_tensors_into_local_scopes(res)

        if hasattr(program._program, 'lr_sheduler'):
            lr_sheduler = program._program.lr_sheduler
            assert isinstance(lr_sheduler, _LRScheduler), "must be _LRScheduler"
            lr_value = lr_sheduler()
            lr_var = program._program.global_block().vars[lr_sheduler._var_name]
            lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
            exe.feed_and_split_tensor_into_local_scopes({
                lr_sheduler._var_name: lr_tensor
            })

        fetch_var_names = list(map(_to_name_str, fetch_list))
        tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
        return as_numpy(tensors) if return_numpy else tensors
@@ -1222,7 +1233,7 @@ class Executor(object):

    def _run_program(self, program, feed, fetch_list, feed_var_name,
                     fetch_var_name, scope, return_numpy, use_program_cache):
        from paddle.optimizer.lr_scheduler import _LRScheduler
        if feed is None:
            feed = {}
        elif isinstance(feed, (list, tuple)):
@@ -1278,6 +1289,16 @@ class Executor(object):
                fetch_var_name=fetch_var_name)
            self._feed_data(program, feed, feed_var_name, scope)

        if hasattr(program, 'lr_sheduler'):
            assert isinstance(program.lr_sheduler,
                              _LRScheduler), "must be _LRScheduler"
            lr_sheduler = program.lr_sheduler
            lr_value = lr_sheduler()
            lr_var = program.global_block().vars[lr_sheduler._var_name]
            data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
            tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
            tensor.set(data, self.place)

        if not use_program_cache:
            self._default_executor.run(program.desc, scope, 0, True, True,
                                       fetch_var_name)
......
@@ -4450,6 +4450,8 @@ class Program(object):
            p._current_role = self._current_role
            p.__op_role_var = self.__op_role_var
            p._appending_grad_times = self._appending_grad_times
            if hasattr(self, 'lr_sheduler'):
                p.lr_sheduler = self.lr_sheduler

            #NOTE(zhiqiu): we sync the cloned program, to update its program by
            # its desc.
......
@@ -68,14 +68,16 @@ class Optimizer(object):
                 regularization=None,
                 grad_clip=None,
                 name=None):
        # Import here to avoid a circular import.
        from paddle.optimizer.lr_scheduler import _LRScheduler
        self._parameter_list = list(
            parameter_list) if parameter_list is not None else None
        self._name = name
        if framework.in_dygraph_mode():
            if not isinstance(learning_rate,
                              (float, LearningRateDecay, _LRScheduler)):
                raise TypeError(
                    "learning rate should be float or _LRScheduler, got %s here"
                    % type(learning_rate))
            if self._parameter_list is None:
                raise AttributeError(
@@ -90,11 +92,11 @@ class Optimizer(object):
                            % regularization.__str__())
                        break
        else:
            if not isinstance(learning_rate,
                              (float, framework.Variable, _LRScheduler)):
                raise TypeError(
                    "learning rate should be float or _LRScheduler, got %s here"
                    % type(learning_rate))

        if grad_clip is not None:
            if not isinstance(grad_clip, GradientClipBase):
@@ -144,11 +146,15 @@ class Optimizer(object):
                state_dict = adam.state_dict()
        '''
        from paddle.optimizer.lr_scheduler import _LRScheduler
        state_dict = {}
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
                state_dict[var_tmp.name] = var_tmp
        # global step if use lr decay
        if isinstance(self._learning_rate, _LRScheduler):
            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
            return state_dict
        if isinstance(self._learning_rate, LearningRateDecay):
            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -192,6 +198,9 @@ class Optimizer(object):
                adam.set_dict(opti_state_dict)
        '''
        from paddle.optimizer.lr_scheduler import _LRScheduler
        if isinstance(self._learning_rate, _LRScheduler):
            self._learning_rate.set_dict(state_dict["LR_Scheduler"])

        if isinstance(self._learning_rate, LearningRateDecay):
            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
@@ -252,6 +261,30 @@ class Optimizer(object):
        return self._opti_name_list

    def _create_global_learning_rate(self):
        from paddle.optimizer.lr_scheduler import _LRScheduler
        if isinstance(self._learning_rate, _LRScheduler):
            lr_var = self._global_learning_rate()
            # only create global lr_var once
            if not isinstance(lr_var, framework.Variable):
                lr_name = unique_name.generate('learning_rate')
                self._learning_rate._var_name = lr_name
                lr_var = self.helper.create_global_variable(
                    name=lr_name,
                    shape=[1],
                    persistable=True,
                    stop_gradient=True,
                    dtype='float32' if self._dtype is None else self._dtype)
                main_prog = framework.default_main_program()
                main_prog.lr_sheduler = self._learning_rate
                main_prog.lr_var = lr_var
                self._learning_rate_map[framework.default_main_program(
                )] = lr_var

            lr_value = float(self._learning_rate())
            self.helper.set_variable_initializer(
                lr_var, initializer=Constant(value=lr_value))
            return

        if imperative_base.enabled():
            # create learning rate Variable
            if isinstance(self._learning_rate, float):
......
@@ -19,6 +19,7 @@ import math
import numpy as np
import unittest

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
@@ -553,79 +554,459 @@ def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,

class TestReduceLROnPlateauDecay(unittest.TestCase):
    def test_ReduceLR(self):
        # the decay rate must be less than 1.0
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=2.0)
        # the mode must be "min" or "max"
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, mode="test")
        # the threshold_mode must be "rel" or "abs"
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(
                learning_rate=1.0, threshold_mode="test")
        with self.assertRaises(TypeError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate="test")
        with self.assertRaises(TypeError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=0.5).step("test")

        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        for place in places:
            for m, n in zip(['min', 'max', 'min', 'max'],
                            ['rel', 'rel', 'abs', 'abs']):
                kwargs = {
                    'learning_rate': 1.0,
                    'mode': m,
                    'factor': 0.5,
                    'patience': 3,
                    'threshold': 1e-4,
                    'threshold_mode': n,
                    'cooldown': 1,
                    'min_lr': 0,
                    'epsilon': 1e-8,
                    'verbose': False,
                }
                paddle.enable_static()
                self._test_static(place, kwargs)
                paddle.disable_static(place)
                self._test_dygraph(place, kwargs)
                paddle.enable_static()

    def _test_static(self, place, kwargs):
        paddle.enable_static()

        best = float("-10000") if kwargs['mode'] == "max" else float("10000")
        current_lr = 1.0
        cooldown_counter = 0
        num_bad_epochs = 0
        var_list = [best, current_lr, cooldown_counter, num_bad_epochs]

        main_prog = fluid.Program()
        start_prog = fluid.Program()
        with fluid.program_guard(main_prog, start_prog):
            x = fluid.layers.create_global_var(
                [1], 1, 'float32', persistable=True)
            paddle.increment(x)
            loss = paddle.sin(x)
            scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
            adam = fluid.optimizer.Adam(learning_rate=scheduler)
            adam.minimize(loss)
            lr_var = adam._global_learning_rate()
            test_prog = main_prog.clone()

        exe = fluid.Executor(place)
        exe.run(start_prog)

        for epoch in range(20):
            for batch_id in range(1):
                out, actual_lr = exe.run(main_prog,
                                         fetch_list=[loss.name, lr_var.name])
                expected_lr = reduce_lr_on_plateau(
                    kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
                    kwargs['patience'], kwargs['mode'],
                    kwargs['threshold_mode'], out[0], var_list)
            scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
for epoch in range(10):
for batch_id in range(1):
out, actual_lr = exe.run(test_prog,
fetch_list=[loss.name, lr_var.name])
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'],
kwargs['threshold_mode'], out[0], var_list)
scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
def _test_dygraph(self, place, kwargs):
paddle.disable_static(place)
best = float("-10000") if kwargs['mode'] == "max" else float("10000")
current_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, current_lr, cooldown_counter, num_bad_epochs]
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(1):
x = paddle.to_tensor(epoch).astype('float32')
loss = paddle.sin(x)
loss.backward()
sgd.minimize(loss)
scheduler.step(loss)
# get lr from paddle
current_lr = scheduler()
# get lr from python
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'], kwargs['threshold_mode'],
loss, var_list)
self.assertEqual(current_lr, expected_lr)
state_dict = sgd.state_dict()
scheduler1 = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd1 = paddle.optimizer.SGD(learning_rate=scheduler1,
parameter_list=linear.parameters())
sgd1.set_dict(state_dict)
self.assertEqual(scheduler.cooldown_counter,
scheduler1.cooldown_counter)
self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False):
if epoch_num == 0:
a = 1
else:
a = math.pow(epoch_num, -0.5)
b = math.pow(warmup_steps, -1.5) * epoch_num
return learning_rate * math.pow(d_model, -0.5) * min(a, b)
def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
return learning_rate * lr_lambda(epoch_num)
def piecewise_lr(epoch_num, boundaries, values, verbose=False):
assert len(boundaries) + 1 == len(values)
for i in range(len(boundaries)):
if epoch_num < boundaries[i]:
return values[i]
return values[len(values) - 1]
def exponential_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * gamma**epoch_num
def natural_exp_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * math.exp(-1 * gamma * epoch_num)
def inverse_time_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate / (1 + gamma * epoch_num)
def polynomial_lr(epoch_num,
learning_rate,
decay_steps,
end_lr=0.0001,
power=1.0,
cycle=False,
verbose=False):
if cycle:
div = math.ceil(epoch_num / float(decay_steps))
if epoch_num == 0:
div = 1
decay_steps = decay_steps * div
else:
epoch_num = min(epoch_num, decay_steps)
return (learning_rate - end_lr) * (
(1 - float(epoch_num) / float(decay_steps))**power) + end_lr
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
math.pi / self.T_max)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
self.last_lr - self.eta_min) + self.eta_min
cosine_annealing_lr_current = None
def cosine_annealing_lr(epoch_num,
learning_rate,
T_max,
eta_min=0,
verbose=False):
global cosine_annealing_lr_current
if epoch_num == 0:
cosine_annealing_lr_current = learning_rate
elif (epoch_num - 1 - T_max) % (2 * T_max) == 0:
cosine_annealing_lr_current = cosine_annealing_lr_current + (
learning_rate - eta_min) * (1 - math.cos(math.pi / float(T_max))
) / 2
else:
cosine_annealing_lr_current = (1 + math.cos(
math.pi * epoch_num / float(T_max))) / (1 + math.cos(math.pi * (
epoch_num - 1) / float(T_max))) * (cosine_annealing_lr_current -
eta_min) + eta_min
return cosine_annealing_lr_current
def linear_warmup_lr(epoch_num,
learning_rate,
warmup_steps,
start_lr,
end_lr,
verbose=False):
if epoch_num < warmup_steps:
return start_lr + (end_lr - start_lr) * (float(epoch_num) /
float(warmup_steps))
else:
return learning_rate
def multi_step_lr(epoch_num,
learning_rate,
milestones,
gamma=0.1,
verbose=False):
for i in range(len(milestones)):
if epoch_num < milestones[i]:
return learning_rate * (gamma**i)
return learning_rate * (gamma**len(milestones))
def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
return learning_rate * math.pow(gamma, epoch_num // step_size)
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.data(name='x', shape=[3, 4, 5])
y = fluid.data(name='y', shape=[3, 4, 5])
z = fluid.layers.fc(x, 100)
loss = fluid.layers.mean(z)
scheduler = paddle_api(**kwarg)
adam = fluid.optimizer.Adam(learning_rate=scheduler)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
num = 0
exe = fluid.Executor(place)
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
test_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
if isinstance(place, fluid.CPUPlace):
compiled_train_prog = fluid.CompiledProgram(
main_prog).with_data_parallel(
loss_name=loss.name, places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_train_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_train_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
compiled_test_prog = fluid.CompiledProgram(
test_prog).with_data_parallel(
loss_name=loss.name,
share_vars_from=compiled_train_prog,
places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_test_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_test_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
def _test_dygraph(self, python_func, paddle_api, kwarg, place):
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle_api(**kwarg)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
self.assertAlmostEqual(sgd.current_step_lr(),
python_func(epoch, **kwarg))
if paddle_api.__name__ != "CosineAnnealingLR":
scheduler.step()
else:
scheduler.step(epoch + 1)
def test_scheduler(self):
with self.assertRaises(NotImplementedError):
paddle.optimizer.lr_scheduler._LRScheduler().step()
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(
learning_rate="test", milestones=[1, 2, 3])
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones='test')
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[3, 2, 1])
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
func_api_kwargs = [(noam_lr, paddle.optimizer.NoamLR, {
"d_model": 0.01,
"warmup_steps": 100,
"verbose": False
}), (piecewise_lr, paddle.optimizer.PiecewiseLR, {
"boundaries": [3, 6, 9, 15, 20],
"values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"verbose": False
}), (natural_exp_lr, paddle.optimizer.NaturalExpLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": False
}), (inverse_time_lr, paddle.optimizer.InverseTimeLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": True
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": False,
"verbose": False
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": True,
"verbose": False
}), (linear_warmup_lr, paddle.optimizer.LinearLrWarmup, {
'learning_rate': 0.5,
'warmup_steps': 20,
'start_lr': 0,
'end_lr': 0.5,
"verbose": False
}), (exponential_lr, paddle.optimizer.ExponentialLR, {
"learning_rate": 0.5,
"gamma": 0.9,
"verbose": False
}), (multi_step_lr, paddle.optimizer.MultiStepLR, {
"learning_rate": 0.5,
"milestones": [3, 6, 9, 15, 20],
"gamma": 0.8,
"verbose": True
}), (step_lr, paddle.optimizer.StepLR, {
"learning_rate": 0.5,
"step_size": 2,
"gamma": 0.8,
"verbose": False
}), (lambda_lr, paddle.optimizer.LambdaLR, {
"learning_rate": 0.5,
"lr_lambda": lambda x: 0.95**x,
"verbose": False
}), (cosine_annealing_lr, paddle.optimizer.CosineAnnealingLR, {
"learning_rate": 0.5,
"T_max": 10,
"verbose": True
})]
for python_func, paddle_api, kwarg in func_api_kwargs:
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
paddle.enable_static()
self._test_static(python_func, paddle_api, kwarg, place)
paddle.disable_static(place)
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()
if __name__ == '__main__':
......
@@ -19,7 +19,10 @@ __all__ = [
    'ExponentialMovingAverage', 'Ftrl', 'FtrlOptimizer', 'LambOptimizer',
    'LarsMomentum', 'LarsMomentumOptimizer', 'LookaheadOptimizer',
    'ModelAverage', 'Momentum', 'MomentumOptimizer', 'PipelineOptimizer',
    'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer',
    '_LRScheduler', 'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR',
    'PolynomialLR', 'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR',
    'LambdaLR', 'ReduceLROnPlateau', 'CosineAnnealingLR'
]
@@ -36,3 +39,7 @@ from .adam import Adam
from .adamw import AdamW
from .adamax import Adamax
from .rmsprop import RMSProp
from . import lr_scheduler
from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \
LinearLrWarmup, ExponentialLR, MultiStepLR, StepLR, LambdaLR, ReduceLROnPlateau, CosineAnnealingLR
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy
import warnings
from paddle import Tensor
__all__ = [
'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR', 'PolynomialLR',
'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR', 'LambdaLR',
'ReduceLROnPlateau', 'CosineAnnealingLR'
]
class _LRScheduler(object):
"""LRScheduler Base class.
Define the common interface of an LRScheduler.
User can 'form paddle.optimizer.lr_scheduler import _LRScheduler'
And inherit from it to have a custom implementation of get_lr().
"""
def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
if not isinstance(learning_rate, (float, int)):
raise TypeError(
"The type of learning rate must be float, but received {}".
format(type(learning_rate)))
self.base_lr = float(learning_rate)
self.last_lr = float(learning_rate)
self.last_epoch = last_epoch
self.verbose = verbose
self._var_name = None
self.step()
def __call__(self):
"""
Return last computed learning rate on current epoch.
"""
return self.last_lr
def step(self, epoch=None):
"""
'step' should be called after 'minimize' . It will update the learning rate in optimizer according to 'epoch'.
The new learning rate will take effect on next epoch.
Args:
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
Examples:
Please refer to the example of current _LRScheduler.
"""
if epoch is None:
self.last_epoch += 1
self.last_lr = self.get_lr()
else:
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
self.last_lr = self._get_closed_form_lr()
else:
self.last_lr = self.get_lr()
if self.verbose:
print('Epoch {}: {} set learning rate to {}.'.format(
self.last_epoch, self.__class__.__name__, self.last_lr))
def state_dict(self):
"""
Returns the state of the scheduler as a :class:`dict`.
It is a subset of self.__dict__ .
"""
self._state_keys()
state_dict = {}
for key in self.keys:
if key not in self.__dict__:
continue
value = self.__dict__[key]
if isinstance(value, Tensor):
assert value.shape == [
1
], "shape of Tensor in state_dict must be [1] {}".format(
value.shape)
value = value.numpy()[0]
state_dict[key] = value
return state_dict
# For subclasses that overload _LRScheduler, "last_epoch" and "last_lr" will be saved by default.
# (Note): you can change it for your subclass.
def _state_keys(self):
"""
set the keys in self.__dict__ that are needed to be saved.
"""
self.keys = ['last_epoch', 'last_lr']
def set_dict(self, state_dict):
"""
Loads the schedulers state.
"""
self._state_keys()
for key in self.keys:
if key in state_dict:
self.__dict__[key] = state_dict[key]
else:
raise RuntimeError(
"Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
format(key))
if len(state_dict) > len(self.keys):
warnings.warn(
"There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
)
# alias for set_dict
set_state_dict = set_dict
def get_lr(self):
# calculate by python float
raise NotImplementedError
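# Illustrative sketch (not part of this file): following the docstring above,
# a user-defined scheduler only needs to subclass _LRScheduler and implement
# get_lr(); 'last_epoch' and 'last_lr' are saved/restored by state_dict() and
# set_dict() by default (see _state_keys). The class below is a hypothetical
# example, not an API added by this PR.
class _ExampleHalvingLR(_LRScheduler):
    """Halve the learning rate every `halve_every` epochs (illustration only)."""

    def __init__(self, learning_rate, halve_every, last_epoch=-1, verbose=False):
        self.halve_every = halve_every
        super(_ExampleHalvingLR, self).__init__(learning_rate, last_epoch,
                                                verbose)

    def get_lr(self):
        # computed with plain python floats, as required by the base class
        return self.base_lr * 0.5**(self.last_epoch // self.halve_every)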
class NoamLR(_LRScheduler):
"""
Applies Noam decay to the initial learning rate.
The algorithm can be described as follows.
.. math::
new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})
Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_
Args:
d_model(int): The dimensionality of the input and output feature vectors of the model. It is a python int number.
warmup_steps(int): The number of warmup steps. A hyperparameter. It is a python int number.
learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``NoamLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
d_model,
warmup_steps,
learning_rate=1.0,
last_epoch=-1,
verbose=False):
self.d_model = d_model
self.warmup_steps = warmup_steps
super(NoamLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
if self.last_epoch == 0:
a = 1
else:
a = self.last_epoch**-0.5
b = self.warmup_steps**-1.5 * self.last_epoch
return self.base_lr * (self.d_model**-0.5) * min(a, b)
class PiecewiseLR(_LRScheduler):
"""
Piecewise learning rate scheduler.
The algorithm can be described as the code below:
.. code-block:: text
boundaries = [100, 200]
values = [1.0, 0.5, 0.1]
if epoch < 100:
learning_rate = 1.0
elif 100 <= epoch < 200:
learning_rate = 0.5
else:
learning_rate = 0.1
Args:
boundaries(list): A list of steps numbers. The type of element in the list is python int.
values(list): A list of learning rate values that will be picked during different epoch boundaries.
The type of element in the list is python float.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``PiecewiseLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
self.boundaries = boundaries
self.values = values
super(PiecewiseLR, self).__init__(
last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
for i in range(len(self.boundaries)):
if self.last_epoch < self.boundaries[i]:
return self.values[i]
return self.values[len(self.values) - 1]
class NaturalExpLR(_LRScheduler):
"""
Applies natural exponential decay to the initial learning rate.
The algorithm can be described as follows:
.. math::
new\_learning\_rate = learning\_rate * e^{- gamma * epoch}
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): A Ratio to update the learning rate. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``NaturalExpLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(NaturalExpLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)
class InverseTimeLR(_LRScheduler):
"""
Applies inverse time decay to the initial learning rate.
The algorithm can be described as follows:
.. math::
new\_learning\_rate = \\frac{learning\_rate}{1 + gamma * epoch}
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): The decay rate. The new learning rate is ``origin_lr / (1 + gamma * epoch)`` . Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``InverseTimeLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(InverseTimeLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr / (1 + self.gamma * self.last_epoch)
class PolynomialLR(_LRScheduler):
"""
Applies polynomial decay to the initial learning rate.
The algorithm can be described as follows.
If cycle is set to True, then:
.. math::
decay\_steps & = decay\_steps * math.ceil(\\frac{epoch}{decay\_steps})
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr
If cycle is set to False, then:
.. math::
epoch & = min(epoch, decay\_steps)
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr
Args:
learning_rate (float): The initial learning rate. It is a python float number.
decay_steps(int): The decay step size. It determines the decay cycle.
end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
power(float, optional): Power of polynomial. Default: 1.0.
cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
to ``end_lr`` . If False, the learning rate is monotonically decreasing. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``PolynomialLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
decay_steps,
end_lr=0.0001,
power=1.0,
cycle=False,
last_epoch=-1,
verbose=False):
self.decay_steps = decay_steps
self.end_lr = end_lr
self.power = power
self.cycle = cycle
super(PolynomialLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
tmp_epoch_num = self.last_epoch
tmp_decay_steps = self.decay_steps
if self.cycle:
div_res = math.ceil(
float(self.last_epoch) / float(self.decay_steps))
if self.last_epoch == 0:
div_res = 1
tmp_decay_steps = self.decay_steps * div_res
else:
tmp_epoch_num = min(self.last_epoch, self.decay_steps)
return (self.base_lr - self.end_lr) * (
(1 - float(tmp_epoch_num) / float(tmp_decay_steps)
)**self.power) + self.end_lr
class LinearLrWarmup(_LRScheduler):
"""
Linear learning rate warm-up strategy. Gradually increases the learning rate before the normal learning rate scheduler takes over.
For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
When epoch < warmup_steps, learning rate is updated as:
.. code-block:: text
lr = start_lr + (end_lr - start_lr) * (epoch / warmup_steps)
where start_lr is the initial learning rate, and end_lr is the final learning rate;
When epoch >= warmup_steps, learning rate is updated as:
.. code-block:: text
lr = learning_rate
where lr is float or any subclass of ``_LRScheduler`` .
Args:
learning_rate (float|_LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``_LRScheduler`` .
warmup_steps (int): total steps of warm up.
start_lr (float): Initial learning rate of warm up.
end_lr (float): Final learning rate of warm up.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``LinearLrWarmup`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.LinearLrWarmup(
learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.LinearLrWarmup(
learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
warmup_steps,
start_lr,
end_lr,
last_epoch=-1,
verbose=False):
type_check = isinstance(learning_rate, float) or isinstance(
learning_rate, int) or isinstance(learning_rate, _LRScheduler)
if not type_check:
raise TypeError(
"the type of learning_rate should be [int, float or _LRScheduler], the current type is {}".
format(learning_rate))
self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.start_lr = start_lr
self.end_lr = end_lr
assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
end_lr, start_lr)
super(LinearLrWarmup, self).__init__(start_lr, last_epoch, verbose)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return (self.end_lr - self.start_lr) * float(
self.last_epoch) / float(self.warmup_steps) + self.start_lr
else:
if isinstance(self.learning_rate, _LRScheduler):
self.learning_rate.step()
return self.learning_rate()
return self.learning_rate
class ExponentialLR(_LRScheduler):
"""
Update the learning rate by ``gamma`` each epoch.
The algorithm can be described as follows.
.. math::
new\_learning\_rate = last\_learning\_rate * gamma
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``ExponentialLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(ExponentialLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * (self.gamma**self.last_epoch)
class MultiStepLR(_LRScheduler):
"""
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
milestones = [30, 50]
gamma = 0.1
if epoch < 30:
learning_rate = 0.5
elif epoch < 50:
learning_rate = 0.05
else:
learning_rate = 0.005
Args:
learning_rate (float): The initial learning rate. It is a python float number.
milestones (tuple|list): List or tuple of epoch boundaries. Must be increasing.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``MultiStepLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
milestones,
gamma=0.1,
last_epoch=-1,
verbose=False):
if not isinstance(milestones, (tuple, list)):
raise TypeError(
"The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
% type(milestones))
if not all([
milestones[i] < milestones[i + 1]
for i in range(len(milestones) - 1)
]):
raise ValueError('The elements of milestones must be increasing')
if gamma >= 1.0:
raise ValueError('gamma should be < 1.0.')
self.milestones = milestones
self.gamma = gamma
super(MultiStepLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
for i in range(len(self.milestones)):
if self.last_epoch < self.milestones[i]:
return self.base_lr * (self.gamma**i)
return self.base_lr * (self.gamma**len(self.milestones))
class StepLR(_LRScheduler):
"""
Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` epochs.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
step_size = 30
gamma = 0.1
learning_rate = 0.5 if epoch < 30
learning_rate = 0.05 if 30 <= epoch < 60
learning_rate = 0.005 if 60 <= epoch < 90
...
Args:
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``StepLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
step_size,
gamma=0.1,
last_epoch=-1,
verbose=False):
if not isinstance(step_size, int):
raise TypeError(
"The type of 'step_size' must be 'int', but received %s." %
type(step_size))
if gamma >= 1.0:
raise ValueError('gamma should be < 1.0.')
self.step_size = step_size
self.gamma = gamma
super(StepLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
i = self.last_epoch // self.step_size
return self.base_lr * (self.gamma**i)
class LambdaLR(_LRScheduler):
"""
Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5 # init learning_rate
lr_lambda = lambda epoch: 0.95 ** epoch
learning_rate = 0.5 # epoch 0
learning_rate = 0.475 # epoch 1
learning_rate = 0.45125 # epoch 2
Args:
learning_rate (float): The initial learning rate. It is a python float number.
lr_lambda (function): A function which computes a factor from ``epoch`` ; the initial learning rate is then multiplied by this factor.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``LambdaLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
if not callable(lr_lambda):
raise TypeError(
"The type of 'lr_lambda' in 'LambdaLR' must be 'function', but received %s."
% type(lr_lambda))
self.lr_lambda = lr_lambda
super(LambdaLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * self.lr_lambda(self.last_epoch)
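# Illustrative sketch only, assuming the hypothetical values base_lr=0.5 and
# lr_lambda=lambda epoch: 0.95 ** epoch: at last_epoch=2, get_lr() returns
# 0.5 * 0.95**2 = 0.45125, matching the schedule shown in the docstring above.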
class ReduceLROnPlateau(_LRScheduler):
"""
Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
by 2 to 10 times once model performance stops improving.
The ``metrics`` is the one which has been passed into ``step`` , it must be a 1-D Tensor with shape [1]. When ``metrics``
stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
(Specially, ``mode`` can also be set to ``'max'`` , in this case, when ``metrics`` stops ascending for a ``patience``
number of epochs, the learning rate will be reduced.)
In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming the normal operation.
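For example, assuming the hypothetical setting ``factor=0.5`` and ``patience=2`` , the learning rate may evolve as the text below.
.. code-block:: text
learning_rate = 1.0     # init learning_rate
learning_rate = 1.0     # loss still improves, lr unchanged
learning_rate = 0.5     # loss has not improved for more than patience epochs, lr = lr * factor
learning_rate = 0.25    # loss stalls again, lr is reduced once more (not below min_lr)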
Args:
learning_rate (float): The initial learning rate. It is a python float number.
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning
rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
factor (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
It should be less than 1.0. Default: 0.1.
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
Default: 10.
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
Changes of ``loss`` smaller than this will be ignored. Default: 1e-4.
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than ``epsilon`` , the update is
ignored. Default: 1e-8.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
Returns:
``ReduceLROnPlateau`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step(loss)
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step(out[0])
"""
def __init__(self,
learning_rate,
mode='min',
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0,
epsilon=1e-8,
verbose=False):
mode = mode.lower()
if mode not in ['min', 'max']:
raise ValueError('mode: ' + mode + ' is unknown!')
self.mode = mode
if factor >= 1.0:
raise ValueError(
'new_lr = origin_lr * factor and factor should be < 1.0.')
self.factor = factor
threshold_mode = threshold_mode.lower()
if threshold_mode not in ['rel', 'abs']:
raise ValueError('threshold mode: ' + threshold_mode +
' is unknown!')
self.threshold_mode = threshold_mode
if not isinstance(learning_rate, (float, int)):
raise TypeError(
"The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float', but received %s."
% type(learning_rate))
self.verbose = verbose
self.patience = patience
self.threshold = threshold
self.threshold_mode = threshold_mode
self.cooldown = cooldown
self.min_lr = min_lr
self.epsilon = epsilon
self.cooldown_counter = 0
self.best = None
self.num_bad_epochs = 0
# Cannot call the parent __init__, so implement it here.
self.base_lr = float(learning_rate)
self.last_lr = float(learning_rate)
self.last_epoch = 0
self._var_name = None
# "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
def _state_keys(self):
self.keys = [
'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch',
'last_lr'
]
def step(self, metrics, epoch=None):
"""
``step`` should be called after ``minimize`` . It will update the learning rate in the optimizer according to ``metrics`` .
The new learning rate will take effect on next epoch.
Args:
metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
If it stops descending for a ``patience`` number of epochs, the learning rate will be reduced. If it's 'Tensor' or
'numpy.ndarray', its shape must be [1].
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
Examples:
Please refer to the example of current _LRScheduler.
"""
if epoch is None:
self.last_epoch = self.last_epoch + 1
else:
self.last_epoch = epoch
# loss must be 1-D Tensor with shape [1]
if isinstance(metrics, (Tensor, numpy.ndarray)):
assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
"should be (1,), but the current metrics.shape is {}. Maybe " \
"you should call paddle.mean to process it first.".format(metrics.shape)
elif not isinstance(metrics,
(int, float, numpy.float32, numpy.float64)):
raise TypeError(
"metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
format(type(metrics)))
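# After a reduction, wait `cooldown` epochs before tracking bad epochs again.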
if self.cooldown_counter > 0:
self.cooldown_counter -= 1
else:
if self.best is None or self._is_better(metrics, self.best):
self.best = metrics
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
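# `metrics` has not improved for more than `patience` epochs:
# reduce lr by `factor` (not below `min_lr`) and enter the cooldown window.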
if self.num_bad_epochs > self.patience:
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
new_lr = max(self.last_lr * self.factor, self.min_lr)
if self.last_lr - new_lr > self.epsilon:
self.last_lr = new_lr
if self.verbose:
print('Epoch {}: {} set learning rate to {}.'.format(
self.last_epoch, self.__class__.__name__,
self.last_lr))
def _is_better(self, current, best):
if self.mode == 'min' and self.threshold_mode == 'rel':
return current < best - best * self.threshold
elif self.mode == 'min' and self.threshold_mode == 'abs':
return current < best - self.threshold
elif self.mode == 'max' and self.threshold_mode == 'rel':
return current > best + best * self.threshold
else:
return current > best + self.threshold
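# Illustrative sketch only: in 'min' mode with the hypothetical threshold=1e-4,
# 'rel' counts current < best * (1 - 1e-4) as an improvement, while
# 'abs' counts current < best - 1e-4 as an improvement.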
class CosineAnnealingLR(_LRScheduler):
"""
Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
SGDR.
The algorithm can be described as follows.
.. math::
\begin{aligned}
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max}; \\
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
\end{aligned}
It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
Note that this only implements the cosine annealing part of SGDR, and not the restarts.
Args:
learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate.
eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``CosineAnnealingLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
T_max,
eta_min=0,
last_epoch=-1,
verbose=False):
if not isinstance(T_max, int):
raise TypeError(
"The type of 'T_max' in 'CosineAnnealingLR' must be 'int', but received %s."
% type(T_max))
if not isinstance(eta_min, (float, int)):
raise TypeError(
"The type of 'eta_min' in 'CosineAnnealingLR' must be 'float, int', but received %s."
% type(eta_min))
self.T_max = T_max
self.eta_min = float(eta_min)
super(CosineAnnealingLR, self).__init__(learning_rate, last_epoch,
verbose)
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
math.pi / self.T_max)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
self.last_lr - self.eta_min) + self.eta_min
def _get_closed_form_lr(self):
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
math.pi * self.last_epoch / self.T_max)) / 2
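# Illustrative sketch only, assuming the hypothetical values base_lr=0.5,
# T_max=10, eta_min=0, the closed form above gives:
#   import math
#   [(0.5 - 0) * (1 + math.cos(math.pi * e / 10)) / 2 for e in range(11)]
#   # epoch 0 -> 0.5, epoch 5 -> 0.25, epoch 10 -> 0.0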
...@@ -21,6 +21,7 @@ __all__ = [
'load', 'data', 'InputSpec'
]
from . import nn
from .input import data #DEFINE_ALIAS
from .input import InputSpec #DEFINE_ALIAS
from ..fluid.executor import Executor #DEFINE_ALIAS
...