未验证 提交 407de039 编写于 作者: Z Zhou Wei 提交者: GitHub

[2.0API] Reconstruct all API related to LR Scheduler, unify dygraph and static (#26550)

* Reconstruct all API related to lr scheduler, unify dygraph and static

* Reconstruct all API related to lr scheduler, unify dygraph and static

* fix doc

* fix doc

* fix doc of lr_scheduler

* fix unittest and english doc

* fix english doc

* fix confilt

* fix doc
上级 6e823cfe
......@@ -850,6 +850,7 @@ class Executor(object):
def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy, return_merged):
from paddle.optimizer.lr_scheduler import _LRScheduler
exe = program._executor
# TODO(zhenghuihuang): quantization uses Graph in CompiledProgram
# instead of program. We will add support for checking Vars in Graph
......@@ -893,6 +894,16 @@ class Executor(object):
res.append(res_dict)
exe.feed_tensors_into_local_scopes(res)
if hasattr(program._program, 'lr_sheduler'):
lr_sheduler = program._program.lr_sheduler
assert isinstance(lr_sheduler, _LRScheduler), "must be _LRScheduler"
lr_value = lr_sheduler()
lr_var = program._program.global_block().vars[lr_sheduler._var_name]
lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
exe.feed_and_split_tensor_into_local_scopes({
lr_sheduler._var_name: lr_tensor
})
fetch_var_names = list(map(_to_name_str, fetch_list))
tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
return as_numpy(tensors) if return_numpy else tensors
......@@ -1222,7 +1233,7 @@ class Executor(object):
def _run_program(self, program, feed, fetch_list, feed_var_name,
fetch_var_name, scope, return_numpy, use_program_cache):
from paddle.optimizer.lr_scheduler import _LRScheduler
if feed is None:
feed = {}
elif isinstance(feed, (list, tuple)):
......@@ -1278,6 +1289,16 @@ class Executor(object):
fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope)
if hasattr(program, 'lr_sheduler'):
assert isinstance(program.lr_sheduler,
_LRScheduler), "must be _LRScheduler"
lr_sheduler = program.lr_sheduler
lr_value = lr_sheduler()
lr_var = program.global_block().vars[lr_sheduler._var_name]
data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
tensor.set(data, self.place)
if not use_program_cache:
self._default_executor.run(program.desc, scope, 0, True, True,
fetch_var_name)
......
......@@ -4450,6 +4450,8 @@ class Program(object):
p._current_role = self._current_role
p.__op_role_var = self.__op_role_var
p._appending_grad_times = self._appending_grad_times
if hasattr(self, 'lr_sheduler'):
p.lr_sheduler = self.lr_sheduler
#NOTE(zhiqiu): we sync the cloned program, to update its program by
# its desc.
......
......@@ -68,14 +68,16 @@ class Optimizer(object):
regularization=None,
grad_clip=None,
name=None):
# Because of the loop import, so place it in the function body
from paddle.optimizer.lr_scheduler import _LRScheduler
self._parameter_list = list(
parameter_list) if parameter_list is not None else None
self._name = name
if framework.in_dygraph_mode():
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, LearningRateDecay):
if not isinstance(learning_rate,
(float, LearningRateDecay, _LRScheduler)):
raise TypeError(
"learning rate should be float or LearningRateDecay, got %s here"
"learning rate should be float or _LRScheduler, got %s here"
% type(learning_rate))
if self._parameter_list is None:
raise AttributeError(
......@@ -90,11 +92,11 @@ class Optimizer(object):
% regularization.__str__())
break
else:
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
if not isinstance(learning_rate,
(float, framework.Variable, _LRScheduler)):
raise TypeError(
"learning rate should be float or Variable, got %s here" %
type(learning_rate))
"learning rate should be float or _LRScheduler, got %s here"
% type(learning_rate))
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
......@@ -144,11 +146,15 @@ class Optimizer(object):
state_dict = adam.state_dict()
'''
from paddle.optimizer.lr_scheduler import _LRScheduler
state_dict = {}
for k, v in self._accumulators.items():
for para_name, var_tmp in v.items():
state_dict[var_tmp.name] = var_tmp
# global step if use lr decay
if isinstance(self._learning_rate, _LRScheduler):
state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
return state_dict
if isinstance(self._learning_rate, LearningRateDecay):
state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
......@@ -192,6 +198,9 @@ class Optimizer(object):
adam.set_dict(opti_state_dict)
'''
from paddle.optimizer.lr_scheduler import _LRScheduler
if isinstance(self._learning_rate, _LRScheduler):
self._learning_rate.set_dict(state_dict["LR_Scheduler"])
if isinstance(self._learning_rate, LearningRateDecay):
self._learning_rate.set_dict(state_dict["LR_Scheduler"])
......@@ -252,6 +261,30 @@ class Optimizer(object):
return self._opti_name_list
def _create_global_learning_rate(self):
from paddle.optimizer.lr_scheduler import _LRScheduler
if isinstance(self._learning_rate, _LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype='float32' if self._dtype is None else self._dtype)
main_prog = framework.default_main_program()
main_prog.lr_sheduler = self._learning_rate
main_prog.lr_var = lr_var
self._learning_rate_map[framework.default_main_program(
)] = lr_var
lr_value = float(self._learning_rate())
self.helper.set_variable_initializer(
lr_var, initializer=Constant(value=lr_value))
return
if imperative_base.enabled():
# create learning rate Variable
if isinstance(self._learning_rate, float):
......
......@@ -19,6 +19,7 @@ import math
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
......@@ -553,79 +554,459 @@ def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
class TestReduceLROnPlateauDecay(unittest.TestCase):
def test_dygraph_mode(self):
with fluid.dygraph.guard():
# the decay rate must be less than 1.0
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(
learning_rate=1.0, decay_rate=2.0)
# the mode must be "min" or "max"
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(learning_rate=1.0, mode="test")
# the threshold_mode must be "rel" or "abs"
with self.assertRaises(ValueError):
fluid.dygraph.ReduceLROnPlateau(
learning_rate=1.0, threshold_mode="test")
base_lr = 1.0
patience = 3
cooldown = 1
decay_rate = 0.5
threshold = 1e-4
linear = fluid.dygraph.Linear(10, 10)
def test_ReduceLR(self):
# the decay rate must be less than 1.0
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=2.0)
# the mode must be "min" or "max"
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, mode="test")
# the threshold_mode must be "rel" or "abs"
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(
learning_rate=1.0, threshold_mode="test")
with self.assertRaises(TypeError):
paddle.optimizer.ReduceLROnPlateau(learning_rate="test")
with self.assertRaises(TypeError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=0.5).step("test")
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
for m, n in zip(['min', 'max', 'min', 'max'],
['rel', 'rel', 'abs', 'abs']):
kwargs = {
'learning_rate': base_lr,
'decay_rate': decay_rate,
'threshold': threshold,
'verbose': True,
'patience': patience,
'cooldown': cooldown,
'learning_rate': 1.0,
'mode': m,
'factor': 0.5,
'patience': 3,
'threshold': 1e-4,
'threshold_mode': n,
'eps': 1e-6
'cooldown': 1,
'min_lr': 0,
'epsilon': 1e-8,
'verbose': False,
}
print("class=" + fluid.dygraph.ReduceLROnPlateau.__name__ +
" kwargs=" + str(kwargs))
lr = fluid.dygraph.ReduceLROnPlateau(**kwargs)
sgd = fluid.optimizer.SGD(learning_rate=lr,
parameter_list=linear.parameters())
best = float("-10000") if m == "max" else float("10000")
expected_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, expected_lr, cooldown_counter, num_bad_epochs]
step_num = 0
epoch_num = 0
for epoch in range(30):
total_loss = 0
for batch_id in range(2):
step_num += 1
x = fluid.dygraph.to_variable(
np.array([step_num]).astype('float32'))
loss = layers.sin(x)
sgd.minimize(loss)
total_loss += loss
epoch_num += 1
# get expected lr from fluid
avg_loss = total_loss / 1
lr.step(avg_loss)
actual_lr = lr().numpy()[0]
# get expected lr form python
expected_lr = reduce_lr_on_plateau(decay_rate, threshold,
cooldown, patience, m, n,
avg_loss, var_list)
self.assertEqual(
expected_lr,
actual_lr,
msg='Failed reduce lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
format(epoch_num, expected_lr, actual_lr))
paddle.enable_static()
self._test_static(place, kwargs)
paddle.disable_static(place)
self._test_dygraph(place, kwargs)
paddle.enable_static()
def _test_static(self, place, kwargs):
paddle.enable_static()
best = float("-10000") if kwargs['mode'] == "max" else float("10000")
current_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, current_lr, cooldown_counter, num_bad_epochs]
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.layers.create_global_var(
[1], 1, 'float32', persistable=True)
paddle.increment(x)
loss = paddle.sin(x)
scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
adam = fluid.optimizer.Adam(learning_rate=scheduler)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
exe = fluid.Executor(place)
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(1):
out, actual_lr = exe.run(main_prog,
fetch_list=[loss.name, lr_var.name])
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'],
kwargs['threshold_mode'], out[0], var_list)
scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
for epoch in range(10):
for batch_id in range(1):
out, actual_lr = exe.run(test_prog,
fetch_list=[loss.name, lr_var.name])
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'],
kwargs['threshold_mode'], out[0], var_list)
scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
def _test_dygraph(self, place, kwargs):
paddle.disable_static(place)
best = float("-10000") if kwargs['mode'] == "max" else float("10000")
current_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, current_lr, cooldown_counter, num_bad_epochs]
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(1):
x = paddle.to_tensor(epoch).astype('float32')
loss = paddle.sin(x)
loss.backward()
sgd.minimize(loss)
scheduler.step(loss)
# get lr from paddle
current_lr = scheduler()
# get lr form python
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'], kwargs['threshold_mode'],
loss, var_list)
self.assertEqual(current_lr, expected_lr)
state_dict = sgd.state_dict()
scheduler1 = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd1 = paddle.optimizer.SGD(learning_rate=scheduler1,
parameter_list=linear.parameters())
sgd1.set_dict(state_dict)
self.assertEqual(scheduler.cooldown_counter,
scheduler1.cooldown_counter)
self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False):
if epoch_num == 0:
a = 1
else:
a = math.pow(epoch_num, -0.5)
b = math.pow(warmup_steps, -1.5) * epoch_num
return learning_rate * math.pow(d_model, -0.5) * min(a, b)
def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
return learning_rate * lr_lambda(epoch_num)
def piecewise_lr(epoch_num, boundaries, values, verbose=False):
assert len(boundaries) + 1 == len(values)
for i in range(len(boundaries)):
if epoch_num < boundaries[i]:
return values[i]
return values[len(values) - 1]
def exponential_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * gamma**epoch_num
def natural_exp_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * math.exp(-1 * gamma * epoch_num)
def inverse_time_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate / (1 + gamma * epoch_num)
def polynomial_lr(epoch_num,
learning_rate,
decay_steps,
end_lr=0.0001,
power=1.0,
cycle=False,
verbose=False):
if cycle:
div = math.ceil(epoch_num / float(decay_steps))
if epoch_num == 0:
div = 1
decay_steps = decay_steps * div
else:
epoch_num = min(epoch_num, decay_steps)
return (learning_rate - end_lr) * (
(1 - float(epoch_num) / float(decay_steps))**power) + end_lr
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
math.pi / self.T_max)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
self.last_lr - self.eta_min) + self.eta_min
cosine_annealing_lr_current = None
def cosine_annealing_lr(epoch_num,
learning_rate,
T_max,
eta_min=0,
verbose=False):
global cosine_annealing_lr_current
if epoch_num == 0:
cosine_annealing_lr_current = learning_rate
elif (epoch_num - 1 - T_max) % (2 * T_max) == 0:
cosine_annealing_lr_current = cosine_annealing_lr_current + (
learning_rate - eta_min) * (1 - math.cos(math.pi / float(T_max))
) / 2
else:
cosine_annealing_lr_current = (1 + math.cos(
math.pi * epoch_num / float(T_max))) / (1 + math.cos(math.pi * (
epoch_num - 1) / float(T_max))) * (cosine_annealing_lr_current -
eta_min) + eta_min
return cosine_annealing_lr_current
def linear_warmup_lr(epoch_num,
learning_rate,
warmup_steps,
start_lr,
end_lr,
verbose=False):
if epoch_num < warmup_steps:
return start_lr + (end_lr - start_lr) * (float(epoch_num) /
float(warmup_steps))
else:
return learning_rate
def multi_step_lr(epoch_num,
learning_rate,
milestones,
gamma=0.1,
verbose=False):
for i in range(len(milestones)):
if epoch_num < milestones[i]:
return learning_rate * (gamma**i)
return learning_rate * (gamma**len(milestones))
def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
return learning_rate * math.pow(gamma, epoch_num // step_size)
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.data(name='x', shape=[3, 4, 5])
y = fluid.data(name='y', shape=[3, 4, 5])
z = fluid.layers.fc(x, 100)
loss = fluid.layers.mean(z)
scheduler = paddle_api(**kwarg)
adam = fluid.optimizer.Adam(learning_rate=scheduler)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
num = 0
exe = fluid.Executor(place)
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
test_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
if isinstance(place, fluid.CPUPlace):
compiled_train_prog = fluid.CompiledProgram(
main_prog).with_data_parallel(
loss_name=loss.name, places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_train_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_train_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
compiled_test_prog = fluid.CompiledProgram(
test_prog).with_data_parallel(
loss_name=loss.name,
share_vars_from=compiled_train_prog,
places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_test_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_test_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
def _test_dygraph(self, python_func, paddle_api, kwarg, place):
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle_api(**kwarg)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
self.assertAlmostEqual(sgd.current_step_lr(),
python_func(epoch, **kwarg))
if paddle_api.__name__ != "CosineAnnealingLR":
scheduler.step()
else:
scheduler.step(epoch + 1)
def test_scheduler(self):
with self.assertRaises(NotImplementedError):
paddle.optimizer.lr_scheduler._LRScheduler().step()
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(
learning_rate="test", milestones=[1, 2, 3])
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones='test')
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[3, 2, 1])
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
func_api_kwargs = [(noam_lr, paddle.optimizer.NoamLR, {
"d_model": 0.01,
"warmup_steps": 100,
"verbose": False
}), (piecewise_lr, paddle.optimizer.PiecewiseLR, {
"boundaries": [3, 6, 9, 15, 20],
"values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"verbose": False
}), (natural_exp_lr, paddle.optimizer.NaturalExpLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": False
}), (inverse_time_lr, paddle.optimizer.InverseTimeLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": True
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": False,
"verbose": False
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": True,
"verbose": False
}), (linear_warmup_lr, paddle.optimizer.LinearLrWarmup, {
'learning_rate': 0.5,
'warmup_steps': 20,
'start_lr': 0,
'end_lr': 0.5,
"verbose": False
}), (exponential_lr, paddle.optimizer.ExponentialLR, {
"learning_rate": 0.5,
"gamma": 0.9,
"verbose": False
}), (multi_step_lr, paddle.optimizer.MultiStepLR, {
"learning_rate": 0.5,
"milestones": [3, 6, 9, 15, 20],
"gamma": 0.8,
"verbose": True
}), (step_lr, paddle.optimizer.StepLR, {
"learning_rate": 0.5,
"step_size": 2,
"gamma": 0.8,
"verbose": False
}), (lambda_lr, paddle.optimizer.LambdaLR, {
"learning_rate": 0.5,
"lr_lambda": lambda x: 0.95**x,
"verbose": False
}), (cosine_annealing_lr, paddle.optimizer.CosineAnnealingLR, {
"learning_rate": 0.5,
"T_max": 10,
"verbose": True
})]
for python_func, paddle_api, kwarg in func_api_kwargs:
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
paddle.enable_static()
self._test_static(python_func, paddle_api, kwarg, place)
paddle.disable_static(place)
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()
if __name__ == '__main__':
......
......@@ -19,7 +19,10 @@ __all__ = [
'ExponentialMovingAverage', 'Ftrl', 'FtrlOptimizer', 'LambOptimizer',
'LarsMomentum', 'LarsMomentumOptimizer', 'LookaheadOptimizer',
'ModelAverage', 'Momentum', 'MomentumOptimizer', 'PipelineOptimizer',
'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer'
'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer',
'_LRScheduler', 'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR',
'PolynomialLR', 'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR',
'LambdaLR', 'ReduceLROnPlateau', 'CosineAnnealingLR'
]
......@@ -36,3 +39,7 @@ from .adam import Adam
from .adamw import AdamW
from .adamax import Adamax
from .rmsprop import RMSProp
from . import lr_scheduler
from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \
LinearLrWarmup, ExponentialLR, MultiStepLR, StepLR, LambdaLR, ReduceLROnPlateau, CosineAnnealingLR
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy
import warnings
from paddle import Tensor
__all__ = [
'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR', 'PolynomialLR',
'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR', 'LambdaLR',
'ReduceLROnPlateau', 'CosineAnnealingLR'
]
class _LRScheduler(object):
"""LRScheduler Base class.
Define the common interface of an LRScheduler.
User can 'form paddle.optimizer.lr_scheduler import _LRScheduler'
And inherit from it to have a custom implementation of get_lr().
"""
def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
if not isinstance(learning_rate, (float, int)):
raise TypeError(
"The type of learning rate must be float, but received {}".
format(type(learning_rate)))
self.base_lr = float(learning_rate)
self.last_lr = float(learning_rate)
self.last_epoch = last_epoch
self.verbose = verbose
self._var_name = None
self.step()
def __call__(self):
"""
Return last computed learning rate on current epoch.
"""
return self.last_lr
def step(self, epoch=None):
"""
'step' should be called after 'minimize' . It will update the learning rate in optimizer according to 'epoch'.
The new learning rate will take effect on next epoch.
Args:
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
Examples:
Please refer to the example of current _LRScheduler.
"""
if epoch is None:
self.last_epoch += 1
self.last_lr = self.get_lr()
else:
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
self.last_lr = self._get_closed_form_lr()
else:
self.last_lr = self.get_lr()
if self.verbose:
print('Epoch {}: {} set learning rate to {}.'.format(
self.last_epoch, self.__class__.__name__, self.last_lr))
def state_dict(self):
"""
Returns the state of the scheduler as a :class:`dict`.
It is a subset of self.__dict__ .
"""
self._state_keys()
state_dict = {}
for key in self.keys:
if key not in self.__dict__:
continue
value = self.__dict__[key]
if isinstance(value, Tensor):
assert value.shape == [
1
], "shape of Tensor in state_dict must be [1] {}".format(
value.shape)
value = value.numpy()[0]
state_dict[key] = value
return state_dict
# For those subclass who overload _LRScheduler, "last_epoch, last_lr" will be saved by default.
# (Note): you can change it for your subclass.
def _state_keys(self):
"""
set the keys in self.__dict__ that are needed to be saved.
"""
self.keys = ['last_epoch', 'last_lr']
def set_dict(self, state_dict):
"""
Loads the schedulers state.
"""
self._state_keys()
for key in self.keys:
if key in state_dict:
self.__dict__[key] = state_dict[key]
else:
raise RuntimeError(
"Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
format(key))
if len(state_dict) > len(self.keys):
warnings.warn(
"There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
)
# alias for set_dict
set_state_dict = set_dict
def get_lr(self):
# calculate by python float
raise NotImplementedError
class NoamLR(_LRScheduler):
"""
Applies Noam Lear to the initial learning rate.
The algorithm can be described as following.
.. math::
new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})
Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_
Args:
d$_{model}$(int): The dimensionality of input and output feature vector of model. It is a python int number.
warmup_steps(int): The number of warmup steps. A super parameter. It is a python int number
learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``NoamLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
d_model,
warmup_steps,
learning_rate=1.0,
last_epoch=-1,
verbose=False):
self.d_model = d_model
self.warmup_steps = warmup_steps
super(NoamLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
if self.last_epoch == 0:
a = 1
else:
a = self.last_epoch**-0.5
b = self.warmup_steps**-1.5 * self.last_epoch
return self.base_lr * (self.d_model**-0.5) * min(a, b)
class PiecewiseLR(_LRScheduler):
"""
Piecewise learning rate scheduler.
The algorithm can be described as the code below:
.. code-block:: text
boundaries = [100, 200]
values = [1.0, 0.5, 0.1]
if epoch < 100:
learning_rate = 1.0
elif 100 <= global_step < 200:
learning_rate = 0.5
else:
learning_rate = 0.1
Args:
boundaries(list): A list of steps numbers. The type of element in the list is python int.
values(list): A list of learning rate values that will be picked during different epoch boundaries.
The type of element in the list is python float.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``PiecewiseLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
self.boundaries = boundaries
self.values = values
super(PiecewiseLR, self).__init__(
last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
for i in range(len(self.boundaries)):
if self.last_epoch < self.boundaries[i]:
return self.values[i]
return self.values[len(self.values) - 1]
class NaturalExpLR(_LRScheduler):
"""
Applies natural exponential decay to the initial learning rate.
The algorithm can be described as following:
.. math::
new\_learning\_rate = learning\_rate * e^{- gama * epoch}
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): A Ratio to update the learning rate. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``NaturalExpLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(NaturalExpLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)
class InverseTimeLR(_LRScheduler):
"""
Applies inverse time decay to the initial learning rate.
The algorithm can be described as following:
.. math::
new\_learning\_rate = \\frac{learning\_rate}{1 + gamma * epoch}
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``InverseTimeLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on static mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(InverseTimeLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr / (1 + self.gamma * self.last_epoch)
class PolynomialLR(_LRScheduler):
"""
Applies polynomial decay to the initial learning rate.
The algorithm can be described as following.
If cycle is set to True, then:
.. math::
decay\_steps & = decay\_steps * math.ceil(\\frac{epoch}{decay\_steps})
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr
If cycle is set to False, then:
.. math::
epoch & = min(epoch, decay\_steps)
new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr
Args:
learning_rate (float): The initial learning rate. It is a python float number.
decay_steps(int): The decay step size. It determines the decay cycle.
end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
power(float, optional): Power of polynomial. Default: 1.0.
cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease
to ``end_lr`` . If False, the learning rate is monotone decreasing. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``PolynomialLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
decay_steps,
end_lr=0.0001,
power=1.0,
cycle=False,
last_epoch=-1,
verbose=False):
self.decay_steps = decay_steps
self.end_lr = end_lr
self.power = power
self.cycle = cycle
super(PolynomialLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
tmp_epoch_num = self.last_epoch
tmp_decay_steps = self.decay_steps
if self.cycle:
div_res = math.ceil(
float(self.last_epoch) / float(self.decay_steps))
if self.last_epoch == 0:
div_res = 1
tmp_decay_steps = self.decay_steps * div_res
else:
tmp_epoch_num = min(self.last_epoch, self.decay_steps)
return (self.base_lr - self.end_lr) * (
(1 - float(tmp_epoch_num) / float(tmp_decay_steps)
)**self.power) + self.end_lr
class LinearLrWarmup(_LRScheduler):
"""
Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
When epoch < warmup_steps, learning rate is updated as:
.. code-block:: text
lr = start_lr + (end_lr - start_lr) * (epoch / warmup_steps)
where start_lr is the initial learning rate, and end_lr is the final learning rate;
When epoch >= warmup_steps, learning rate is updated as:
.. code-block:: text
lr = learning_rate
where lr is float or any subclass of ``_LRScheduler`` .
Args:
learning_rate (float|_LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``_LRScheduler`` .
warmup_steps (int): total steps of warm up.
start_lr (float): Initial learning rate of warm up.
end_lr (float): Final learning rate of warm up.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``LinearLrWarmup`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.LinearLrWarmup(
learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.LinearLrWarmup(
learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
warmup_steps,
start_lr,
end_lr,
last_epoch=-1,
verbose=False):
type_check = isinstance(learning_rate, float) or isinstance(
learning_rate, int) or isinstance(learning_rate, _LRScheduler)
if not type_check:
raise TypeError(
"the type of learning_rate should be [int, float or _LRScheduler], the current type is {}".
format(learning_rate))
self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.start_lr = start_lr
self.end_lr = end_lr
assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
end_lr, start_lr)
super(LinearLrWarmup, self).__init__(start_lr, last_epoch, verbose)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return (self.end_lr - self.start_lr) * float(
self.last_epoch) / float(self.warmup_steps) + self.start_lr
else:
if isinstance(self.learning_rate, _LRScheduler):
self.learning_rate.step()
return self.learning_rate()
return self.learning_rate
class ExponentialLR(_LRScheduler):
"""
Update learning rate by 'gamma' each epoch.
The algorithm can be described as following.
.. math::
new\_learning\_rate = last\_learning\_rate * gamma
Args:
learning_rate (float): The initial learning rate. It is a python float number.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``ExponentialLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(ExponentialLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * (self.gamma**self.last_epoch)
class MultiStepLR(_LRScheduler):
"""
Update the learning rate by ``gama`` once ``epoch`` reaches one of the milestones.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
milestones = [30, 50]
gamma = 0.1
if epoch < 30:
learning_rate = 0.5
elif epoch < 50:
learning_rate = 0.05
else:
learning_rate = 0.005
Args:
learning_rate (float): The initial learning rate. It is a python float number.
milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``MultiStepLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
milestones,
gamma=0.1,
last_epoch=-1,
verbose=False):
if not isinstance(milestones, (tuple, list)):
raise TypeError(
"The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
% type(milestones))
if not all([
milestones[i] < milestones[i + 1]
for i in range(len(milestones) - 1)
]):
raise ValueError('The elements of milestones must be incremented')
if gamma >= 1.0:
raise ValueError('gamma should be < 1.0.')
self.milestones = milestones
self.gamma = gamma
super(MultiStepLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
for i in range(len(self.milestones)):
if self.last_epoch < self.milestones[i]:
return self.base_lr * (self.gamma**i)
return self.base_lr * (self.gamma**len(self.milestones))
class StepLR(_LRScheduler):
"""
Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5
step_size = 30
gamma = 0.1
learning_rate = 0.5 if epoch < 30
learning_rate = 0.05 if 30 <= epoch < 60
learning_rate = 0.005 if 60 <= epoch < 90
...
Args:
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``StepLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
step_size,
gamma=0.1,
last_epoch=-1,
verbose=False):
if not isinstance(step_size, int):
raise TypeError(
"The type of 'step_size' must be 'int', but received %s." %
type(step_size))
if gamma >= 1.0:
raise ValueError('gamma should be < 1.0.')
self.step_size = step_size
self.gamma = gamma
super(StepLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
i = self.last_epoch // self.step_size
return self.base_lr * (self.gamma**i)
class LambdaLR(_LRScheduler):
"""
Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is funciton which receives ``epoch`` .
The algorithm can be described as the code below.
.. code-block:: text
learning_rate = 0.5 # init learning_rate
lr_lambda = lambda epoch: 0.95 ** epoch
learning_rate = 0.5 # epoch 0
learning_rate = 0.475 # epoch 1
learning_rate = 0.45125 # epoch 2
Args:
learning_rate (float): The initial learning rate. It is a python float number.
lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``LambdaLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
if not callable(lr_lambda):
raise TypeError(
"The type of 'lr_lambda' in 'LambdaLR' must be 'function', but received %s."
% type(lr_lambda))
self.lr_lambda = lr_lambda
super(LambdaLR, self).__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr * self.lr_lambda(self.last_epoch)
class ReduceLROnPlateau(_LRScheduler):
"""
Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
by 2 to 10 times once model performance has no longer improvement.
The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
(Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
number of epochs, the learning rate will be reduced.)
In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.
Args:
learning_rate (float): The initial learning rate. It is a python float number.
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning
rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
It should be less than 1.0. Default: 0.1.
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
Default: 10.
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
Returns:
``ReduceLROnPlateau`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step(loss)
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step(out[0])
"""
def __init__(self,
learning_rate,
mode='min',
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0,
epsilon=1e-8,
verbose=False):
mode = mode.lower()
if mode not in ['min', 'max']:
raise ValueError('mode: ' + mode + ' is unknown!')
self.mode = mode
if factor >= 1.0:
raise ValueError(
'new_lr = origin_lr * gamma and gamma should be < 1.0.')
self.factor = factor
threshold_mode = threshold_mode.lower()
if threshold_mode not in ['rel', 'abs']:
raise ValueError('threshold mode: ' + threshold_mode +
' is unknown!')
self.threshold_mode = threshold_mode
if not isinstance(learning_rate, (float, int)):
raise TypeError(
"The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float', but received %s."
% type(learning_rate))
self.verbose = verbose
self.patience = patience
self.threshold = threshold
self.threshold_mode = threshold_mode
self.cooldown = cooldown
self.min_lr = min_lr
self.epsilon = epsilon
self.cooldown_counter = 0
self.best = None
self.num_bad_epochs = 0
# Can not call Parent __init__, so implement here.
self.base_lr = float(learning_rate)
self.last_lr = float(learning_rate)
self.last_epoch = 0
self.verbose = verbose
self._var_name = None
# "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
def _state_keys(self):
self.keys = [
'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch',
'last_lr'
]
def step(self, metrics, epoch=None):
"""
step should be called after 'minimize' . It will update the learning rate in optimizer according to ``metrics`` .
The new learning rate will take effect on next epoch.
Args:
metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
'numpy.ndarray', its shape must be [1].
epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
Returns:
None
Examples:
Please refer to the example of current _LRScheduler.
"""
if epoch is None:
self.last_epoch = self.last_epoch + 1
else:
self.last_epoch = epoch
# loss must be 1-D Tensor with shape [1]
if isinstance(metrics, (Tensor, numpy.ndarray)):
assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
"should be (1L,), but the current metrics.shape is {}. Maybe that " \
"you should call paddle.mean to process it first.".format(loss.shape)
elif not isinstance(metrics,
(int, float, numpy.float32, numpy.float64)):
raise TypeError(
"metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
format(type(metrics)))
if self.cooldown_counter > 0:
self.cooldown_counter -= 1
else:
if self.best is None or self._is_better(metrics, self.best):
self.best = metrics
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
if self.num_bad_epochs > self.patience:
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
new_lr = max(self.last_lr * self.factor, self.min_lr)
if self.last_lr - new_lr > self.epsilon:
self.last_lr = new_lr
if self.verbose:
print('Epoch {}: {} set learning rate to {}.'.format(
self.last_epoch, self.__class__.__name__,
self.last_lr))
def _is_better(self, current, best):
print("mode", self.mode, 'threshold_mode', self.threshold_mode)
if self.mode == 'min' and self.threshold_mode == 'rel':
return current < best - best * self.threshold
elif self.mode == 'min' and self.threshold_mode == 'abs':
return current < best - self.threshold
elif self.mode == 'max' and self.threshold_mode == 'rel':
return current > best + best * self.threshold
else:
return current > best + self.threshold
class CosineAnnealingLR(_LRScheduler):
"""
Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
SGDR:
\begin{aligned}
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max}; \\
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
\end{aligned}
The algorithm can be described as following.
.. math::
\begin{aligned}
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
& T_{cur} \neq (2k+1)T_{max}; \\
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
& T_{cur} = (2k+1)T_{max}.
\end{aligned}
It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
Note that this only implements the cosine annealing part of SGDR, and not the restarts.
Args:
learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate.
eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``CosineAnnealingLR`` instance to schedule learning rate.
Examples:
.. code-block:: python
import paddle
import numpy as np
# train on default dygraph mode
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
scheduler.step()
# train on statich mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[-1, 4, 5])
y = paddle.static.data(name='y', shape=[-1, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
lr_var = sgd._global_learning_rate()
exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scheduler.step()
"""
def __init__(self,
learning_rate,
T_max,
eta_min=0,
last_epoch=-1,
verbose=False):
if not isinstance(T_max, int):
raise TypeError(
"The type of 'T_max' in 'CosineAnnealingLR' must be 'int', but received %s."
% type(T_max))
if not isinstance(eta_min, (float, int)):
raise TypeError(
"The type of 'eta_min' in 'CosineAnnealingLR' must be 'float, int', but received %s."
% type(eta_min))
self.T_max = T_max
self.eta_min = float(eta_min)
super(CosineAnnealingLR, self).__init__(learning_rate, last_epoch,
verbose)
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
math.pi / self.T_max)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
self.last_lr - self.eta_min) + self.eta_min
def _get_closed_form_lr(self):
return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
math.pi * self.last_epoch / self.T_max)) / 2
......@@ -21,6 +21,7 @@ __all__ = [
'load', 'data', 'InputSpec'
]
from . import nn
from .input import data #DEFINE_ALIAS
from .input import InputSpec #DEFINE_ALIAS
from ..fluid.executor import Executor #DEFINE_ALIAS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册