Unverified commit 30aab177, authored by Zhou Wei, committed by GitHub

[2.0API]support 2.0 lr_scheduler for 2.0 optimizer (#26737)

* support 2.0 lr_scheduler for 2.0 optimizer

* fix unittest

* fix doc

* fix unittest

* fix sample code, fix unittest
Parent 29494d70
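This patch moves the 2.0 optimizer (and the tests that exercise it) from the fluid ``LearningRateDecay`` classes to the new ``paddle.optimizer`` ``_LRScheduler`` API. Below is a minimal sketch of the new usage pattern, pieced together from the tests in this diff; the toy model and the concrete boundaries/values are illustrative, not part of the patch:

    import numpy as np
    import paddle

    paddle.disable_static()

    # toy model and data, for illustration only
    linear = paddle.nn.Linear(10, 10)
    x = paddle.to_tensor(np.random.uniform(-1, 1, [10, 10]).astype("float32"))

    # the scheduler object itself is passed as learning_rate
    scheduler = paddle.optimizer.PiecewiseLR(boundaries=[2, 4], values=[0.1, 0.05, 0.01])
    adam = paddle.optimizer.Adam(learning_rate=scheduler,
                                 parameters=linear.parameters())

    for epoch in range(6):
        loss = paddle.reduce_mean(linear(x))
        loss.backward()
        adam.minimize(loss)
        linear.clear_gradients()
        scheduler.step()  # the schedule is now advanced explicitly by the user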
......@@ -456,8 +456,9 @@ class TestAdamOpV2(unittest.TestCase):
state_dict = adam.state_dict()
adam.set_state_dict(state_dict)
#learning_rate is Decay
learning_rate = fluid.dygraph.CosineDecay(0.1, 10000, 120)
#learning_rate is _LRScheduler
learning_rate = paddle.optimizer.CosineAnnealingLR(
learning_rate=0.1, T_max=10)
adam = paddle.optimizer.Adam(
learning_rate=learning_rate,
weight_decay=fluid.regularizer.L2Decay(0.001),
......@@ -498,15 +499,10 @@ class TestAdamOpV2(unittest.TestCase):
adam.set_lr(lr)
cur_lr = adam.get_lr()
assert (lr == cur_lr)
lr_var = paddle.create_global_var(shape=[1], value=lr, dtype='float32')
adam.set_lr(lr_var)
cur_lr = adam.get_lr()
assert (np.float32(lr) == cur_lr)
with self.assertRaises(TypeError):
lr = int(1)
adam.set_lr(lr)
lr_var = paddle.create_global_var(
shape=[1], value=lr, dtype='float32')
adam.set_lr(lr_var)
if __name__ == "__main__":
......
......@@ -200,7 +200,7 @@ class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase):
def get_optimizer_dygraph(self, parameter_list):
bd = [3, 6, 9]
optimizer = SGDOptimizer(
learning_rate=fluid.layers.piecewise_decay(
learning_rate=paddle.optimizer.PiecewiseLR(
boundaries=bd,
values=[0.1 * (0.1**i) for i in range(len(bd) + 1)]),
parameter_list=parameter_list)
......@@ -208,7 +208,7 @@ class TestImperativeOptimizerPiecewiseDecay(TestImperativeOptimizerBase):
def get_optimizer(self):
bd = [3, 6, 9]
optimizer = SGDOptimizer(learning_rate=fluid.layers.piecewise_decay(
optimizer = SGDOptimizer(learning_rate=paddle.optimizer.PiecewiseLR(
boundaries=bd, values=[0.1 * (0.1**i) for i in range(len(bd) + 1)]))
return optimizer
......@@ -381,9 +381,9 @@ class TestOptimizerLearningRate(unittest.TestCase):
bd = [2, 4, 6, 8]
value = [0.2, 0.4, 0.6, 0.8, 1.0]
scheduler = paddle.optimizer.PiecewiseLR(bd, value)
adam = paddle.optimizer.Adam(
fluid.dygraph.PiecewiseDecay(bd, value, 0),
parameters=linear.parameters())
scheduler, parameters=linear.parameters())
self.assertTrue(
np.allclose(
......@@ -393,8 +393,8 @@ class TestOptimizerLearningRate(unittest.TestCase):
for i in range(12):
adam.minimize(loss)
lr = adam.get_lr()
self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0))
scheduler.step()
def test_lr_decay_natural_exp(self):
with fluid.dygraph.guard():
......@@ -409,24 +409,21 @@ class TestOptimizerLearningRate(unittest.TestCase):
loss = fluid.layers.reduce_mean(b)
base_lr = 1.0
scheduler = paddle.optimizer.NaturalExpLR(1.0, gamma=0.5)
print("scheduler.last_lr", scheduler.last_lr)
adam = paddle.optimizer.Adam(
fluid.dygraph.NaturalExpDecay(
learning_rate=base_lr,
decay_steps=3,
decay_rate=0.5,
staircase=True),
parameters=linear.parameters())
scheduler, parameters=linear.parameters())
self.assertTrue(
np.allclose(
adam.get_lr(), 1.0, rtol=1e-06, atol=0.0))
ret = [1.0, 1.0, 1.0, np.exp(-0.5), np.exp(-0.5)]
for i in range(5):
ret = [1.0, np.exp(-0.5), np.exp(-1)]
for i in range(3):
adam.minimize(loss)
lr = adam.get_lr()
self.assertTrue(np.allclose(lr, ret[i], rtol=1e-06, atol=0.0))
scheduler.step()
def test_set_lr(self):
with fluid.dygraph.guard():
......@@ -451,20 +448,15 @@ class TestOptimizerLearningRate(unittest.TestCase):
np.allclose(
lr, lr_list[i], rtol=1e-06, atol=0.0))
lr_var = fluid.layers.create_global_var(
shape=[1], value=0.7, dtype='float32')
adam.set_lr(lr_var)
adam.minimize(loss)
lr = adam.get_lr()
self.assertTrue(np.allclose(lr, 0.7, rtol=1e-06, atol=0.0))
with self.assertRaises(TypeError):
lr_var = fluid.layers.create_global_var(
shape=[1], value=0.7, dtype='float32')
adam.set_lr(lr_var)
with self.assertRaises(RuntimeError):
adam = paddle.optimizer.Adam(
fluid.dygraph.NaturalExpDecay(
learning_rate=0.1,
decay_steps=3,
decay_rate=0.5,
staircase=True),
paddle.optimizer.NaturalExpLR(
learning_rate=0.1, gamma=0.5),
parameters=linear.parameters())
adam.set_lr(0.01)
......
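The updated NaturalExpLR expectations in this file (ret = [1.0, np.exp(-0.5), np.exp(-1)]) follow lr_t = learning_rate * exp(-gamma * t), with one scheduler.step() per iteration; this is the same closed form as the natural_exp_lr helper later in this diff. A quick numeric check of the asserted values:

    import math

    base_lr, gamma = 1.0, 0.5
    print([base_lr * math.exp(-gamma * t) for t in range(3)])
    # [1.0, 0.6065306597126334, 0.36787944117144233], i.e. [1.0, exp(-0.5), exp(-1)]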
......@@ -374,6 +374,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
adam._learning_rate.step_num = 0
para_state_dict, opti_state_dict = paddle.load("./test_dy")
print(opti_state_dict['LR_Scheduler'])
adam.set_dict(opti_state_dict)
opti_dict = adam.state_dict()
......
......@@ -239,10 +239,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
scheduler = paddle.optimizer.PiecewiseLR(
boundaries=bd, values=lr_arr)
adam = Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr_arr),
parameters=ptb_model.parameters())
learning_rate=scheduler, parameters=ptb_model.parameters())
dy_param_updated = dict()
dy_param_init = dict()
dy_loss = None
......@@ -268,7 +268,9 @@ class TestDygraphPtbRnn(unittest.TestCase):
dy_param_init[param.name] = param.numpy()
dy_loss.backward()
adam.minimize(dy_loss)
scheduler.step()
ptb_model.clear_gradients()
if i == batch_num - 1:
for param in ptb_model.parameters():
dy_param_updated[param.name] = param.numpy()
......@@ -283,7 +285,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
else:
self.base_opti[k] = v
fluid.save_dygraph(self.opti_dict, "./test_dy")
fluid.save_dygraph(self.opti_dict, "./test_dy_v2")
self.state_dict = ptb_model.state_dict()
......@@ -292,7 +294,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
np_t = v.numpy()
self.model_base[k] = np_t
paddle.save(self.state_dict, "./test_dy")
paddle.save(self.state_dict, "./test_dy_v2")
def testLoadAndSetVarBase(self):
seed = 90
......@@ -325,10 +327,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
scheduler = paddle.optimizer.PiecewiseLR(
boundaries=bd, values=lr_arr)
adam = Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr_arr),
parameters=ptb_model.parameters())
learning_rate=scheduler, parameters=ptb_model.parameters())
dy_param_updated = dict()
dy_param_init = dict()
dy_loss = None
......@@ -354,6 +356,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
dy_param_init[param.name] = param.numpy()
dy_loss.backward()
adam.minimize(dy_loss)
scheduler.step()
ptb_model.clear_gradients()
if i == batch_num - 1:
for param in ptb_model.parameters():
......@@ -370,10 +373,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
if isinstance(adam._learning_rate, LearningRateDecay):
adam._learning_rate.step_num = 0
para_state_dict, opti_state_dict = paddle.load("./test_dy")
para_state_dict, opti_state_dict = paddle.load("./test_dy_v2")
adam.set_state_dict(opti_state_dict)
opti_dict = adam.state_dict()
......@@ -434,10 +434,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
scheduler = paddle.optimizer.PiecewiseLR(
boundaries=bd, values=lr_arr)
adam = Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr_arr),
parameters=ptb_model.parameters())
learning_rate=scheduler, parameters=ptb_model.parameters())
dy_param_updated = dict()
dy_param_init = dict()
dy_loss = None
......@@ -463,6 +463,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
dy_param_init[param.name] = param.numpy()
dy_loss.backward()
adam.minimize(dy_loss)
scheduler.step()
ptb_model.clear_gradients()
if i == batch_num - 1:
for param in ptb_model.parameters():
......@@ -541,10 +542,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
scheduler = paddle.optimizer.PiecewiseLR(
boundaries=bd, values=lr_arr)
adam = Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr_arr),
parameters=ptb_model.parameters())
learning_rate=scheduler, parameters=ptb_model.parameters())
dy_param_updated = dict()
dy_param_init = dict()
dy_loss = None
......@@ -570,6 +571,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
dy_param_init[param.name] = param.numpy()
dy_loss.backward()
adam.minimize(dy_loss)
scheduler.step()
ptb_model.clear_gradients()
if i == batch_num - 1:
for param in ptb_model.parameters():
......@@ -745,7 +747,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
last_hidden = None
last_cell = None
state_dict, opti_dict = fluid.load_dygraph("./test_dy")
state_dict, opti_dict = fluid.load_dygraph("./test_dy_v2")
adam.set_state_dict(opti_dict)
ptb_model.set_dict(state_dict)
......@@ -825,9 +827,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
scheduler = paddle.optimizer.PiecewiseLR(
boundaries=bd, values=lr_arr)
adam = Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr_arr),
learning_rate=scheduler,
beta1=0.8,
beta2=0.6,
parameters=ptb_model.parameters())
......@@ -867,14 +870,16 @@ class TestDygraphPtbRnn(unittest.TestCase):
init_cell)
dy_loss.backward()
scheduler.step()
adam.minimize(dy_loss)
ptb_model.clear_gradients()
opti_dict = adam.state_dict()
for k, v in opti_dict.items():
if k == "global_step":
if k == "LR_Scheduler":
self.assertTrue(
np.array_equal(v.numpy(), self.base_opti[v.name] + 1))
np.array_equal(v['last_epoch'], self.base_opti[k][
'last_epoch'] + 1))
if k.find("beta1_pow_acc_0") > 0:
self.assertTrue(
......
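With an _LRScheduler attached, the optimizer state dict now carries an 'LR_Scheduler' entry (the scheduler's own state_dict) instead of the old 'global_step' tensor, which is what the checks above compare. A hedged save/restore sketch following the test pattern; the exact keys inside 'LR_Scheduler' depend on the scheduler implementation, and the linear model here stands in for the PTB model used by the test:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.disable_static()
    linear = paddle.nn.Linear(10, 10)

    scheduler = paddle.optimizer.PiecewiseLR(boundaries=[2, 4], values=[0.1, 0.05, 0.01])
    adam = paddle.optimizer.Adam(learning_rate=scheduler, parameters=linear.parameters())

    # run one step so the optimizer has state to save
    x = paddle.to_tensor(np.random.uniform(-1, 1, [10, 10]).astype("float32"))
    loss = paddle.reduce_mean(linear(x))
    loss.backward()
    adam.minimize(loss)
    scheduler.step()

    opti_dict = adam.state_dict()
    print(opti_dict['LR_Scheduler'])        # scheduler state, e.g. {'last_epoch': ..., 'last_lr': ...}
    fluid.save_dygraph(opti_dict, "./test_dy_v2")
    paddle.save(linear.state_dict(), "./test_dy_v2")

    para_state_dict, opti_state_dict = paddle.load("./test_dy_v2")
    adam.set_state_dict(opti_state_dict)    # restores the LR_Scheduler entry as well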
......@@ -523,491 +523,5 @@ class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
run_places(lr, start_lr, end_lr)
def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
var_list):
def is_better(current, best, m, n):
if m == 'min' and n == 'rel':
return current < best - best * threshold
elif m == 'min' and n == 'abs':
return current < best - threshold
elif m == 'max' and n == 'rel':
return current > best + best * threshold
else: # mode == 'max' and epsilon_mode == 'abs':
return current > best + threshold
if var_list[2] > 0:
var_list[2] -= 1
return var_list[1]
if is_better(loss, var_list[0], m, n):
var_list[0] = loss
var_list[3] = 0
else:
var_list[3] += 1
if var_list[3] > patience:
var_list[2] = cooldown
var_list[3] = 0
new_lr = var_list[1] * decay_rate
var_list[1] = new_lr if var_list[1] - new_lr > 1e-8 else var_list[1]
return var_list[1]
class TestReduceLROnPlateauDecay(unittest.TestCase):
def test_ReduceLR(self):
# the decay rate must be less than 1.0
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=2.0)
# the mode must be "min" or "max"
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, mode="test")
# the threshold_mode must be "rel" or "abs"
with self.assertRaises(ValueError):
paddle.optimizer.ReduceLROnPlateau(
learning_rate=1.0, threshold_mode="test")
with self.assertRaises(TypeError):
paddle.optimizer.ReduceLROnPlateau(learning_rate="test")
with self.assertRaises(TypeError):
paddle.optimizer.ReduceLROnPlateau(learning_rate=0.5).step("test")
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
for m, n in zip(['min', 'max', 'min', 'max'],
['rel', 'rel', 'abs', 'abs']):
kwargs = {
'learning_rate': 1.0,
'mode': m,
'factor': 0.5,
'patience': 3,
'threshold': 1e-4,
'threshold_mode': n,
'cooldown': 1,
'min_lr': 0,
'epsilon': 1e-8,
'verbose': False,
}
paddle.enable_static()
self._test_static(place, kwargs)
paddle.disable_static(place)
self._test_dygraph(place, kwargs)
paddle.enable_static()
def _test_static(self, place, kwargs):
paddle.enable_static()
best = float("-10000") if kwargs['mode'] == "max" else float("10000")
current_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, current_lr, cooldown_counter, num_bad_epochs]
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.layers.create_global_var(
[1], 1, 'float32', persistable=True)
paddle.increment(x)
loss = paddle.sin(x)
scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
adam = fluid.optimizer.Adam(learning_rate=scheduler)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
exe = fluid.Executor(place)
exe.run(start_prog)
for epoch in range(20):
for batch_id in range(1):
out, actual_lr = exe.run(main_prog,
fetch_list=[loss.name, lr_var.name])
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'],
kwargs['threshold_mode'], out[0], var_list)
scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
for epoch in range(10):
for batch_id in range(1):
out, actual_lr = exe.run(test_prog,
fetch_list=[loss.name, lr_var.name])
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'],
kwargs['threshold_mode'], out[0], var_list)
scheduler.step(out[0])
actual_lr = scheduler()
self.assertEqual(actual_lr, np.array(expected_lr))
def _test_dygraph(self, place, kwargs):
paddle.disable_static(place)
best = float("-10000") if kwargs['mode'] == "max" else float("10000")
current_lr = 1.0
cooldown_counter = 0
num_bad_epochs = 0
var_list = [best, current_lr, cooldown_counter, num_bad_epochs]
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(1):
x = paddle.to_tensor(epoch).astype('float32')
loss = paddle.sin(x)
loss.backward()
sgd.minimize(loss)
scheduler.step(loss)
# get lr from paddle
current_lr = scheduler()
# get lr form python
expected_lr = reduce_lr_on_plateau(
kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
kwargs['patience'], kwargs['mode'], kwargs['threshold_mode'],
loss, var_list)
self.assertEqual(current_lr, expected_lr)
state_dict = sgd.state_dict()
scheduler1 = paddle.optimizer.ReduceLROnPlateau(**kwargs)
sgd1 = paddle.optimizer.SGD(learning_rate=scheduler1,
parameter_list=linear.parameters())
sgd1.set_dict(state_dict)
self.assertEqual(scheduler.cooldown_counter,
scheduler1.cooldown_counter)
self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False):
if epoch_num == 0:
a = 1
else:
a = math.pow(epoch_num, -0.5)
b = math.pow(warmup_steps, -1.5) * epoch_num
return learning_rate * math.pow(d_model, -0.5) * min(a, b)
def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
return learning_rate * lr_lambda(epoch_num)
def piecewise_lr(epoch_num, boundaries, values, verbose=False):
assert len(boundaries) + 1 == len(values)
for i in range(len(boundaries)):
if epoch_num < boundaries[i]:
return values[i]
return values[len(values) - 1]
def exponential_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * gamma**epoch_num
def natural_exp_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate * math.exp(-1 * gamma * epoch_num)
def inverse_time_lr(epoch_num, learning_rate, gamma, verbose=False):
return learning_rate / (1 + gamma * epoch_num)
def polynomial_lr(epoch_num,
learning_rate,
decay_steps,
end_lr=0.0001,
power=1.0,
cycle=False,
verbose=False):
if cycle:
div = math.ceil(epoch_num / float(decay_steps))
if epoch_num == 0:
div = 1
decay_steps = decay_steps * div
else:
epoch_num = min(epoch_num, decay_steps)
return (learning_rate - end_lr) * (
(1 - float(epoch_num) / float(decay_steps))**power) + end_lr
def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
math.pi / self.T_max)) / 2
return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
self.last_lr - self.eta_min) + self.eta_min
cosine_annealing_lr_current = None
def cosine_annealing_lr(epoch_num,
learning_rate,
T_max,
eta_min=0,
verbose=False):
global cosine_annealing_lr_current
if epoch_num == 0:
cosine_annealing_lr_current = learning_rate
elif (epoch_num - 1 - T_max) % (2 * T_max) == 0:
cosine_annealing_lr_current = cosine_annealing_lr_current + (
learning_rate - eta_min) * (1 - math.cos(math.pi / float(T_max))
) / 2
else:
cosine_annealing_lr_current = (1 + math.cos(
math.pi * epoch_num / float(T_max))) / (1 + math.cos(math.pi * (
epoch_num - 1) / float(T_max))) * (cosine_annealing_lr_current -
eta_min) + eta_min
return cosine_annealing_lr_current
def linear_warmup_lr(epoch_num,
learning_rate,
warmup_steps,
start_lr,
end_lr,
verbose=False):
if epoch_num < warmup_steps:
return start_lr + (end_lr - start_lr) * (float(epoch_num) /
float(warmup_steps))
else:
return learning_rate
def multi_step_lr(epoch_num,
learning_rate,
milestones,
gamma=0.1,
verbose=False):
for i in range(len(milestones)):
if epoch_num < milestones[i]:
return learning_rate * (gamma**i)
return learning_rate * (gamma**len(milestones))
def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
return learning_rate * math.pow(gamma, epoch_num // step_size)
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.data(name='x', shape=[3, 4, 5])
y = fluid.data(name='y', shape=[3, 4, 5])
z = fluid.layers.fc(x, 100)
loss = fluid.layers.mean(z)
scheduler = paddle_api(**kwarg)
adam = fluid.optimizer.Adam(learning_rate=scheduler)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
num = 0
exe = fluid.Executor(place)
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
test_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
num += 1
if isinstance(place, fluid.CPUPlace):
compiled_train_prog = fluid.CompiledProgram(
main_prog).with_data_parallel(
loss_name=loss.name, places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_train_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_train_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
compiled_test_prog = fluid.CompiledProgram(
test_prog).with_data_parallel(
loss_name=loss.name,
share_vars_from=compiled_train_prog,
places=fluid.cpu_places(4))
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_test_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_test_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[1].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[2].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
out = np.array(scopes[3].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
scheduler.step()
num += 1
def _test_dygraph(self, python_func, paddle_api, kwarg, place):
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle_api(**kwarg)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
parameter_list=linear.parameters())
for epoch in range(20):
for batch_id in range(2):
x = paddle.to_tensor(x)
out = linear(x)
loss = paddle.reduce_mean(out)
out.backward()
sgd.minimize(loss)
linear.clear_gradients()
self.assertAlmostEqual(sgd.current_step_lr(),
python_func(epoch, **kwarg))
if paddle_api.__name__ != "CosineAnnealingLR":
scheduler.step()
else:
scheduler.step(epoch + 1)
def test_scheduler(self):
with self.assertRaises(NotImplementedError):
paddle.optimizer.lr_scheduler._LRScheduler().step()
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(
learning_rate="test", milestones=[1, 2, 3])
with self.assertRaises(TypeError):
paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones='test')
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[3, 2, 1])
with self.assertRaises(ValueError):
paddle.optimizer.MultiStepLR(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
func_api_kwargs = [(noam_lr, paddle.optimizer.NoamLR, {
"d_model": 0.01,
"warmup_steps": 100,
"verbose": False
}), (piecewise_lr, paddle.optimizer.PiecewiseLR, {
"boundaries": [3, 6, 9, 15, 20],
"values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"verbose": False
}), (natural_exp_lr, paddle.optimizer.NaturalExpLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": False
}), (inverse_time_lr, paddle.optimizer.InverseTimeLR, {
"learning_rate": 0.5,
"gamma": 0.1,
"verbose": True
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": False,
"verbose": False
}), (polynomial_lr, paddle.optimizer.PolynomialLR, {
"learning_rate": 0.5,
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": True,
"verbose": False
}), (linear_warmup_lr, paddle.optimizer.LinearLrWarmup, {
'learning_rate': 0.5,
'warmup_steps': 20,
'start_lr': 0,
'end_lr': 0.5,
"verbose": False
}), (exponential_lr, paddle.optimizer.ExponentialLR, {
"learning_rate": 0.5,
"gamma": 0.9,
"verbose": False
}), (multi_step_lr, paddle.optimizer.MultiStepLR, {
"learning_rate": 0.5,
"milestones": [3, 6, 9, 15, 20],
"gamma": 0.8,
"verbose": True
}), (step_lr, paddle.optimizer.StepLR, {
"learning_rate": 0.5,
"step_size": 2,
"gamma": 0.8,
"verbose": False
}), (lambda_lr, paddle.optimizer.LambdaLR, {
"learning_rate": 0.5,
"lr_lambda": lambda x: 0.95**x,
"verbose": False
}), (cosine_annealing_lr, paddle.optimizer.CosineAnnealingLR, {
"learning_rate": 0.5,
"T_max": 10,
"verbose": True
})]
for python_func, paddle_api, kwarg in func_api_kwargs:
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
paddle.enable_static()
self._test_static(python_func, paddle_api, kwarg, place)
paddle.disable_static(place)
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -35,12 +35,12 @@ from ..fluid.layers import ops
from ..fluid.regularizer import append_regularization_ops
from ..fluid.dygraph import base as imperative_base
from ..fluid.dygraph import no_grad
from ..fluid.dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
from ..fluid.wrapped_decorator import signature_safe_contextmanager
from .. import compat as cpt
from .lr_scheduler import _LRScheduler
__all__ = ['Optimizer']
......@@ -53,8 +53,8 @@ class Optimizer(object):
but need to use one of its implementations.
Args:
learning_rate (float|LearningRateDecay): The learning rate used to update ``Parameter``.
It can be a float value or a LearningRateDecay.
learning_rate (float|_LRScheduler): The learning rate used to update ``Parameter``.
It can be a float value or an instance of any subclass of ``_LRScheduler``.
parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
The default value is None in static mode, at this time all parameters will be updated.
......@@ -109,11 +109,6 @@ class Optimizer(object):
parameters) if parameters is not None else None
self._name = name
if framework.in_dygraph_mode():
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, LearningRateDecay):
raise TypeError(
"learning rate should be float or LearningRateDecay, got %s here"
% type(learning_rate))
if self._parameter_list is None:
raise AttributeError(
"parameters argument given to the Optimizer should not be None in dygraph mode."
......@@ -126,13 +121,10 @@ class Optimizer(object):
"The weight_decay[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
% weight_decay.__str__())
break
else:
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError(
"learning rate should be float or Tensor, got %s here" %
type(learning_rate))
if not isinstance(learning_rate, (float, _LRScheduler)):
raise TypeError(
"learning rate should be float or _LRScheduler, got %s here" %
type(learning_rate))
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
raise TypeError(
......@@ -150,9 +142,6 @@ class Optimizer(object):
# each program should have a independent learning rate
# program -> tensor(learning_rate)
self._learning_rate_map = dict()
if isinstance(self._learning_rate, framework.Variable):
self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate
# Dictionary of accumulators. Some optimizer subclasses need to
# allocate and manage extra tensors associated with the parameters
# to train. These tensors are called accumulators.
......@@ -167,7 +156,7 @@ class Optimizer(object):
@framework.dygraph_only
def state_dict(self):
'''
Get state dict information from optimizer. It contain all the tensor used by optimizer. For Adam optimizer, contains beta1, beta2, momentum etc. If LearningRateDecay have been used, global_step will be include in state dict.
Get state dict information from optimizer. It contains all the tensors used by the optimizer. For the Adam optimizer, it contains beta1, beta2, momentum, etc. If an _LRScheduler has been used, LR_Scheduler will be included in the state dict.
If the optimizer has never been called (the minimize function), the state_dict is empty.
Args:
......@@ -192,24 +181,14 @@ class Optimizer(object):
for para_name, var_tmp in v.items():
state_dict[var_tmp.name] = var_tmp
# global step if use lr decay
if isinstance(self._learning_rate, LearningRateDecay):
if isinstance(self._learning_rate, _LRScheduler):
state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
if not isinstance(self._learning_rate, _LearningRateEpochDecay):
var_tmp = None
var_temp = framework._varbase_creator(
None, name='global_step', dtype='int32')
tensor.fill_constant(
[1], "int32", self._learning_rate.step_num, out=var_temp)
state_dict['global_step'] = var_temp
return state_dict
@framework.dygraph_only
def set_state_dict(self, state_dict):
'''
Load optimizer state dict. For Adam optimizer, contains beta1, beta2, momentum etc. If LearningRateDecay have been used, global_step will be changed.
Load optimizer state dict. For the Adam optimizer, it contains beta1, beta2, momentum, etc. If an _LRScheduler has been used, its state will be restored as well.
Args:
state_dict(dict) : Dict contains all the Tensor needed by optimizer
......@@ -226,7 +205,7 @@ class Optimizer(object):
state_dict = emb.state_dict()
paddle.framework.save(state_dict, "paddle_dy")
adam = paddle.optimizer.Adam(learning_rate=paddle.nn.functional.noam_decay( 100, 10000),
adam = paddle.optimizer.Adam(learning_rate=paddle.optimizer.NoamLR( 100, 10000),
parameters=emb.parameters())
state_dict = adam.state_dict()
paddle.framework.save(state_dict, "paddle_dy")
......@@ -237,29 +216,8 @@ class Optimizer(object):
'''
if isinstance(self._learning_rate, LearningRateDecay):
self._learning_rate.set_dict(state_dict["LR_Scheduler"])
if not isinstance(self._learning_rate, _LearningRateEpochDecay):
assert 'global_step' in state_dict, \
'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
global_step = state_dict['global_step']
if isinstance(global_step, Variable):
step_np = global_step
step_np = np.array(step_np.value().get_tensor())
assert step_np.shape == (1,), \
"global step shape is (1,), the shape is {}".format( step_np.shape )
self._learning_rate.step_num = int(step_np[0])
elif isinstance(global_step, np.ndarray):
assert global_step.shape == (1,), \
"global step shape is (1,), the shape is {}".format( global_step.shape )
self._learning_rate.step_num = global_step[0]
else:
raise RuntimeError(
"Type not supprt, value in state dict must be [VarBase, Tensor, numpy], the type is ",
type(global_step))
if isinstance(self._learning_rate, _LRScheduler):
self._learning_rate.set_state_dict(state_dict["LR_Scheduler"])
self._accumulators_holder = state_dict
for k, v in self._accumulators.items():
......@@ -296,58 +254,49 @@ class Optimizer(object):
return self._opti_name_list
def _create_global_learning_rate(self):
if imperative_base.enabled():
# create learning rate tensor
if isinstance(self._learning_rate, float):
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype,
persistable=True)
# get learning rate Tensor from LearningRateDecay
elif isinstance(self._learning_rate, LearningRateDecay):
if isinstance(self._learning_rate, _LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype)
main_prog = framework.default_main_program()
main_prog.lr_sheduler = self._learning_rate
main_prog.lr_var = lr_var
self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate()
else:
raise TypeError(
"optimizer's learning rate must be float or LearningRateDecay"
)
else:
lr = self._global_learning_rate()
)] = lr_var
lr_value = float(self._learning_rate())
self.helper.set_variable_initializer(
lr_var, initializer=Constant(value=lr_value))
elif isinstance(self._learning_rate, float):
# only create global lr_var once
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
if not isinstance(self._learning_rate, float):
raise TypeError(
"learning rate Tensor is create outside optimizer,"
"can not create new learning rate Tensor for new program"
)
# create learning rate in the current main program
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype,
persistable=True)
self._learning_rate_map[framework.default_main_program(
)] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype,
persistable=True)
@framework.dygraph_only
def set_lr(self, value):
"""
:api_attr: imperative
Set the value of the learning rate manually in the optimizer. If the optimizer use LearningRateDecay,
Set the value of the learning rate manually in the optimizer. If the optimizer uses an _LRScheduler,
this API cannot be invoked, because it would lead to a conflict.
Args:
......@@ -378,53 +327,36 @@ class Optimizer(object):
# current lr is 0.5
# current lr is 0.6
# set learning rate manually by framework Tensor
lr_var = paddle.create_global_var(
shape=[1], value=0.7, dtype='float32')
adam.set_lr(lr_var)
lr = adam.get_lr()
print("current lr is {}".format(lr))
# Print:
# current lr is 0.7
"""
if not isinstance(value, (framework.Variable, float)):
if not isinstance(value, (int, float)):
raise TypeError(
"The type of 'value' in optimizer.set_lr must be (float, Tensor), but received %s."
"The type of 'value' in optimizer.set_lr must be float, but received %s."
% (type(value)))
if isinstance(self._learning_rate, LearningRateDecay):
if isinstance(self._learning_rate, _LRScheduler):
raise RuntimeError(
"optimizer's learning rate can't be LearningRateDecay when invoke this API, because this will lead to conflict."
"optimizer's learning rate can't be _LRScheduler when invoke this API, because this will lead to conflict."
)
if isinstance(value, float):
self._learning_rate = value
current_lr = self._global_learning_rate()
if current_lr is not None:
global_block = framework.default_main_program().global_block()
global_block.append_op(
type='fill_constant',
outputs={'Out': [current_lr]},
attrs={
'dtype': current_lr.dtype,
'shape': list(current_lr.shape),
'value': float(value)
},
stop_gradient=True)
else:
assert len(value.shape) == 1 and value.shape[
0] == 1, "optimizer's learning rate must be 1-D Tensor with shape[1]"
self._learning_rate_map[framework.default_main_program()] = value
self._learning_rate = float(value)
current_lr = self._global_learning_rate()
if current_lr is not None:
global_block = framework.default_main_program().global_block()
global_block.append_op(
type='fill_constant',
outputs={'Out': [current_lr]},
attrs={
'dtype': current_lr.dtype,
'shape': list(current_lr.shape),
'value': float(value)
},
stop_gradient=True)
@framework.dygraph_only
def get_lr(self):
"""
:api_attr: imperative
Get current step learning rate. The return value is all the same When LearningRateDecay is not used,
otherwise return the step learning rate.
Get the current step learning rate. The return value is always the same when _LRScheduler is not used;
otherwise, the current step learning rate is returned.
Returns:
float: The learning rate of the current step.
......@@ -434,14 +366,14 @@ class Optimizer(object):
import numpy as np
import paddle
# example1: LearningRateDecay is not used, return value is all the same
# example1: _LRScheduler is not used, the return value is always the same
paddle.disable_static()
emb = paddle.nn.Embedding([10, 10])
adam = paddle.optimizer.Adam(0.001, parameters = emb.parameters())
lr = adam.get_lr()
print(lr) # 0.001
# example2: PiecewiseDecay is used, return the step learning rate
# example2: PiecewiseLR is used, returns the step learning rate
paddle.disable_static()
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
......@@ -451,7 +383,8 @@ class Optimizer(object):
bd = [2, 4, 6, 8]
value = [0.2, 0.4, 0.6, 0.8, 1.0]
adam = paddle.optimizer.Adam(paddle.PiecewiseDecay(bd, value, 0),
scheduler = paddle.optimizer.PiecewiseLR(bd, value, 0)
adam = paddle.optimizer.Adam(scheduler,
parameters=linear.parameters())
# first step: learning rate is 0.2
......@@ -462,24 +395,14 @@ class Optimizer(object):
for i in range(12):
adam.step()
lr = adam.get_lr()
scheduler.step()
np.allclose(lr, ret[i], rtol=1e-06, atol=0.0) # True
"""
current_lr = self._global_learning_rate()
if isinstance(current_lr, framework.Variable):
return self._global_learning_rate().numpy()[0]
if isinstance(self._learning_rate, float):
return self._learning_rate
elif isinstance(self._learning_rate, _LearningRateEpochDecay):
step_lr = self._learning_rate()
return step_lr.numpy()[0]
else:
step_lr = self._learning_rate.step()
if isinstance(step_lr, (float, int)):
return step_lr
else:
return step_lr.numpy()[0]
return self._learning_rate()
def _global_learning_rate(self, program=None):
"""
......
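For reference, a short sketch of how the reworked set_lr / get_lr behave after this change; the printed values follow the new code paths above, and the model is illustrative:

    import paddle

    paddle.disable_static()
    linear = paddle.nn.Linear(10, 10)

    # float learning rate: set_lr accepts a Python number, get_lr returns it
    adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
    adam.set_lr(0.05)
    print(adam.get_lr())        # 0.05

    # a Tensor is no longer accepted by set_lr
    lr_var = paddle.create_global_var(shape=[1], value=0.7, dtype='float32')
    try:
        adam.set_lr(lr_var)
    except TypeError:
        pass

    # with an _LRScheduler, get_lr returns the scheduler value and set_lr is rejected
    scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.1, gamma=0.5)
    adam = paddle.optimizer.Adam(scheduler, parameters=linear.parameters())
    print(adam.get_lr())        # 0.1
    try:
        adam.set_lr(0.01)
    except RuntimeError:
        pass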