Unverified commit 6bf75c13, authored by Zhou Wei, committed by GitHub

cherry-pick: fix optimizer.state_dict and LRScheduler.state_dict to save/load dygraph (#25447)

Parent 316afbb2
@@ -78,9 +78,9 @@ def save_dygraph(state_dict, model_path):
     for k, v in state_dict.items():
         if isinstance(v, (Variable, core.VarBase)):
             model_dict[k] = v.numpy()
+            name_table[k] = v.name
         else:
             model_dict[k] = v
-            name_table[k] = v.name

     model_dict["StructuredToParameterName@@"] = name_table

     file_name = model_path + suffix
......
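For context: the old branch looked up `v.name` only on non-tensor entries, which breaks once the optimizer state_dict starts carrying plain Python values such as the nested scheduler state added by this commit. A minimal standalone sketch of the corrected bookkeeping (a hypothetical helper, not the actual Paddle source):

```python
# Hypothetical standalone sketch of the bookkeeping above, not the actual
# Paddle source: only tensor-like entries contribute to the
# "StructuredToParameterName@@" table; plain Python values are stored as-is.
def build_model_dict(state_dict, is_tensor):
    model_dict, name_table = {}, {}
    for k, v in state_dict.items():
        if is_tensor(v):                  # Variable / VarBase in the real code
            model_dict[k] = v.numpy()     # tensors are serialized as numpy arrays
            name_table[k] = v.name        # remember structured-key -> parameter-name
        else:
            model_dict[k] = v             # ints, floats, nested dicts pass through
    model_dict["StructuredToParameterName@@"] = name_table
    return model_dict
```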
@@ -15,6 +15,7 @@
 from __future__ import print_function

 import math
+import warnings

 from .. import unique_name
 from ..framework import Variable
@@ -66,6 +67,51 @@ class LearningRateDecay(object):
             persistable=False)
         return lr

+    def state_dict(self):
+        """
+        Returns the state of the scheduler as a :class:`dict`.
+
+        It is a subset of self.__dict__ .
+        """
+        self._state_keys()
+        state_dict = {}
+        for key in self.keys:
+            if key not in self.__dict__:
+                continue
+            value = self.__dict__[key]
+            if isinstance(value, Variable):
+                assert value.shape == [
+                    1
+                ], "shape of Variable in state_dict must be [1] {}".format(
+                    value.shape)
+                value = value.numpy()[0]
+            state_dict[key] = value
+
+        return state_dict
+
+    def _state_keys(self):
+        """
+        set the keys in self.__dict__ that are needed to be saved.
+        """
+        self.keys = ['step_num']
+
+    def set_dict(self, state_dict):
+        """
+        Loads the schedulers state.
+        """
+        self._state_keys()
+        for key in self.keys:
+            if key in state_dict:
+                self.__dict__[key] = state_dict[key]
+            else:
+                raise RuntimeError(
+                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
+                    format(key))
+        if len(state_dict) > len(self.keys):
+            warnings.warn(
+                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
+            )
+
     def step(self):
         raise NotImplementedError()
@@ -402,7 +448,7 @@ class PolynomialDecay(LearningRateDecay):
         learning_rate(Variable|float): The initial learning rate. If the type
             is Variable, it's a tensor with shape [1], the data type can be
             float32 or float64. It also can be set to python int number.
-        decay_steps(int32): The decay step size. It determines the decay cycle.
+        decay_steps(int): The decay step size. It determines the decay cycle.
         end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
         power(float, optional): Power of polynomial. The default value is 1.0.
         cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
@@ -784,7 +830,7 @@ class ReduceLROnPlateau(LearningRateDecay):
             raise ValueError(
                 'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
             )
-        self.decay_rate = decay_rate
+        self.decay_rate = self.create_lr_var(decay_rate)

         threshold_mode = threshold_mode.lower()
         if threshold_mode not in ['rel', 'abs']:
@@ -793,8 +839,10 @@ class ReduceLROnPlateau(LearningRateDecay):
         self.threshold_mode = threshold_mode
         check_type(learning_rate, 'learning_rate', (float, int, Variable),
                    'ReduceLROnPlateau')
-        if isinstance(learning_rate, (float, int)):
-            learning_rate = self.create_lr_var(learning_rate)
+        if not isinstance(learning_rate, (float, int, Variable)):
+            raise TypeError(
+                "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float, int, Variable', but received %s."
+                % type(learning_rate))

         self.learning_rate = learning_rate
         self.verbose = verbose
@@ -808,9 +856,17 @@ class ReduceLROnPlateau(LearningRateDecay):
         self.cooldown_counter = 0
         self.best_loss = None
         self.num_bad_epochs = 0
-        self.epoch = 0
+        self.epoch_num = 0
+
+    def _state_keys(self):
+        self.keys = [
+            'cooldown_counter', 'best_loss', 'num_bad_epochs', 'epoch_num',
+            'learning_rate'
+        ]

     def __call__(self):
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate

     def step(self, loss):
@@ -836,7 +892,7 @@ class ReduceLROnPlateau(LearningRateDecay):
             "should be (1L,), but the current loss.shape is {}. Maybe that " \
             "you should call fluid.layers.mean to process it first.".format(loss.shape)

-        self.epoch += 1
+        self.epoch_num += 1
         if self.cooldown_counter > 0:
             self.cooldown_counter -= 1
         else:
@@ -854,10 +910,11 @@ class ReduceLROnPlateau(LearningRateDecay):
                     self.decay_rate, self.min_lr)
                 if self.learning_rate - new_lr > self.eps:
                     if self.verbose:
+                        old_lr = self.learning_rate.numpy()[0] if isinstance(
+                            self.learning_rate,
+                            Variable) else self.learning_rate
                         print('Epoch {}: reducing learning rate from {} to {}.'.
-                              format(self.epoch,
-                                     self.learning_rate.numpy()[0],
-                                     new_lr.numpy()[0]))
+                              format(self.epoch_num, old_lr, new_lr.numpy()[0]))
                     self.learning_rate = new_lr

     def _is_better(self, current, best):
@@ -890,22 +947,28 @@ class _LearningRateEpochDecay(LearningRateDecay):
             raise TypeError(
                 "The type of 'learning_rate' must be 'float, int', but received %s."
                 % type(learning_rate))
-        if learning_rate >= 1.0:
-            raise ValueError("The initial learning rate")
+        if learning_rate < 0:
+            raise ValueError("Invalid learning rate: {}".format(learning_rate))

         self.base_lr = float(learning_rate)

         self.epoch_num = -1
+        self.dtype = dtype
         if dtype is None:
             self.dtype = "float32"
         self.learning_rate = self.create_lr_var(self.base_lr)

         self.epoch()

+    def _state_keys(self):
+        self.keys = ['epoch_num', 'learning_rate']
+
     def __call__(self):
         """
         Return last computed learning rate on current epoch.
         """
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate

     def epoch(self, epoch=None):
@@ -918,8 +981,6 @@ class _LearningRateEpochDecay(LearningRateDecay):
         self.epoch_num = epoch

         self.learning_rate = self.get_lr()
-        if isinstance(self.learning_rate, float):
-            self.learning_rate = self.create_lr_var(self.learning_rate)

     def get_lr(self):
         raise NotImplementedError
@@ -946,7 +1007,7 @@ class StepDecay(_LearningRateEpochDecay):
     Parameters:
         learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
-        step_size (int): Period of learning rate decay..
+        step_size (int): Period of learning rate decay.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
             It should be less than 1.0. Default: 0.1.
@@ -1024,7 +1085,7 @@ class MultiStepDecay(_LearningRateEpochDecay):
             learning_rate = 0.005

     Parameters:
-        learning_rate (float|int): The initial learning rate. It can be set to python float or int number. If it
+        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
         milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
             It should be less than 1.0. Default: 0.1.
......
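Taken together, the scheduler-side changes above give every `LearningRateDecay` subclass a `state_dict()`/`set_dict()` pair that round-trips through plain Python values. A minimal usage sketch, assuming the fluid dygraph API exercised by the new test further down (the scheduler choice and values here are illustrative only):

```python
import paddle.fluid as fluid

with fluid.dygraph.guard():
    scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
    scheduler.epoch()                    # advance one epoch

    state = scheduler.state_dict()       # plain values under 'epoch_num', 'learning_rate'

    # A freshly constructed scheduler can be restored from those values.
    restored = fluid.dygraph.StepDecay(0.5, step_size=3)
    restored.set_dict(state)             # raises RuntimeError if a required key is missing
```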
@@ -33,7 +33,7 @@ from .layers import ops
 from .regularizer import append_regularization_ops
 from .dygraph import base as imperative_base
 from .dygraph import no_grad
-from .dygraph.learning_rate_scheduler import LearningRateDecay
+from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
 from paddle.fluid import core
 from paddle.fluid.layers import tensor
 from functools import reduce
@@ -148,17 +148,17 @@ class Optimizer(object):
                     state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
         if isinstance(self._learning_rate, LearningRateDecay):
-            var_tmp = None
-            if framework.in_dygraph_mode():
-                var_temp = framework._varbase_creator(
-                    None, name='global_step', dtype='int32')
-            else:
-                var_temp = Variable(None, name='global_step', dtype='int32')
-            tensor.fill_constant(
-                [1], "int32", self._learning_rate.step_num, out=var_temp)
-            state_dict['global_step'] = var_temp
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                var_tmp = None
+                var_temp = framework._varbase_creator(
+                    None, name='global_step', dtype='int32')
+
+                tensor.fill_constant(
+                    [1], "int32", self._learning_rate.step_num, out=var_temp)
+
+                state_dict['global_step'] = var_temp
         return state_dict

     @framework.dygraph_only
@@ -192,30 +192,28 @@ class Optimizer(object):
         '''
         if isinstance(self._learning_rate, LearningRateDecay):
-            assert 'global_step' in state_dict, \
-                    'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
-            global_step = state_dict['global_step']
-
-            if isinstance(global_step, core.VarBase):
-                step_np = global_step
-                step_np = np.array(step_np.value().get_tensor())
-                assert step_np.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( step_np.shape )
-
-                self._learning_rate.step_num = int(step_np[0])
-            elif isinstance(global_step, Variable):
-                step_np = global_step.numpy()
-                assert step_np.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( step_np.shape )
-                self._learning_rate.step_num = step_np[0]
-            elif isinstance(global_step, np.ndarray):
-                assert global_step.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( global_step.shape )
-                self._learning_rate.step_num = global_step[0]
-            else:
-                raise RuntimeError(
-                    "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
-                    type(global_step))
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                assert 'global_step' in state_dict, \
+                        'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
+                global_step = state_dict['global_step']
+
+                if isinstance(global_step, Variable):
+                    step_np = global_step
+                    step_np = np.array(step_np.value().get_tensor())
+                    assert step_np.shape == (1,), \
+                            "global step shape is (1,), the shape is {}".format( step_np.shape )
+
+                    self._learning_rate.step_num = int(step_np[0])
+                elif isinstance(global_step, np.ndarray):
+                    assert global_step.shape == (1,), \
+                            "global step shape is (1,), the shape is {}".format( global_step.shape )
+                    self._learning_rate.step_num = global_step[0]
+                else:
+                    raise RuntimeError(
+                        "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
+                        type(global_step))

         self._accumulators_holder = state_dict
         for k, v in self._accumulators.items():
@@ -346,11 +344,14 @@ class Optimizer(object):
         """
         current_lr = self._global_learning_rate()
-        if current_lr:
+        if isinstance(current_lr, framework.Variable):
             return self._global_learning_rate().numpy()[0]

         if isinstance(self._learning_rate, float):
             return self._learning_rate
+        elif isinstance(self._learning_rate, _LearningRateEpochDecay):
+            step_lr = self._learning_rate()
+            return step_lr.numpy()[0]
         else:
             step_lr = self._learning_rate.step()
             if isinstance(step_lr, (float, int)):
......
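On the optimizer side, `state_dict()` now nests the scheduler state under the `"LR_Scheduler"` key (plus a `global_step` entry for step-based schedulers), and `set_dict()` hands that sub-dict back to the scheduler. A condensed save/load sketch distilled from the new test below; the file name, layer sizes, and scheduler choice are placeholders:

```python
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(10, 10)
    scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
    adam = fluid.optimizer.Adam(
        learning_rate=scheduler, parameter_list=linear.parameters())

    # The optimizer state now includes a nested "LR_Scheduler" sub-dict.
    fluid.dygraph.save_dygraph(adam.state_dict(), "save_path")

    # Rebuild the optimizer (e.g. in another process) and restore everything.
    _, opt_state = fluid.dygraph.load_dygraph("save_path")
    adam_restored = fluid.optimizer.Adam(
        learning_rate=fluid.dygraph.StepDecay(0.5, step_size=3),
        parameter_list=linear.parameters())
    adam_restored.set_dict(opt_state)    # restores accumulators and scheduler state
```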
@@ -276,8 +276,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
            self.opti_dict = adam.state_dict()
            self.base_opti = {}
            for k, v in self.opti_dict.items():
-                self.base_opti[v.name] = v.numpy()
-                self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
+                if isinstance(v, core.VarBase):
+                    self.base_opti[v.name] = v.numpy()
+                    self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
+                else:
+                    self.base_opti[k] = v

            fluid.save_dygraph(self.opti_dict, "./test_dy")
@@ -359,11 +362,12 @@ class TestDygraphPtbRnn(unittest.TestCase):
            opti_dict = adam.state_dict()
            # set to zero
            for k, v in opti_dict.items():
-                np_t = v.numpy()
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)

            if isinstance(adam._learning_rate, LearningRateDecay):
                adam._learning_rate.step_num = 0
@@ -373,8 +377,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
            opti_dict = adam.state_dict()
            for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])

            # check parameter
            state_dict = ptb_model.state_dict()
@@ -464,21 +471,24 @@ class TestDygraphPtbRnn(unittest.TestCase):
            opti_dict = adam.state_dict()
            # set to zero
            for k, v in opti_dict.items():
-                np_t = v.numpy()
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)

            if isinstance(adam._learning_rate, LearningRateDecay):
                adam._learning_rate.step_num = 0

            adam.set_dict(self.opti_dict)

            opti_dict = adam.state_dict()
            for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])

            # check parameter
            state_dict = ptb_model.state_dict()
@@ -569,12 +579,14 @@ class TestDygraphPtbRnn(unittest.TestCase):
            np_opti_dict = {}
            # set to zero
            for k, v in opti_dict.items():
-                np_t = v.numpy()
-                np_opti_dict[v.name] = np_t
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    np_opti_dict[v.name] = np_t
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
+                else:
+                    np_opti_dict[k] = v

            if isinstance(adam._learning_rate, LearningRateDecay):
                adam._learning_rate.step_num = 0
@@ -583,8 +595,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
            opti_dict = adam.state_dict()
            for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])

            # check parameter
            state_dict = ptb_model.state_dict()
@@ -825,7 +840,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
        np_state_dict = {}
        for k, v in self.opti_dict.items():
-            np_opti_dict[v.name] = v.numpy()
+            if isinstance(v, core.VarBase):
+                np_opti_dict[v.name] = v.numpy()
+            else:
+                np_opti_dict[k] = v

        for k, v in self.state_dict.items():
            np_state_dict[k] = v.numpy()
......
@@ -121,6 +121,104 @@ def lambda_decay(global_step, learning_rate, lr_lambda):
 class TestLearningRateDecayDygraph(unittest.TestCase):
+    def test_LR_state_dict(self):
+        with fluid.dygraph.guard():
+            x = np.random.uniform(-1, 1, [3, 10]).astype("float32")
+            linear = fluid.dygraph.Linear(10, 10)
+            input = fluid.dygraph.to_variable(x)
+
+            Exponential_scheduler = fluid.dygraph.ExponentialDecay(
+                learning_rate=0.1,
+                decay_steps=10000,
+                decay_rate=0.5,
+                staircase=True)
+            Step_scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
+            Reducelr_scheduler = fluid.dygraph.ReduceLROnPlateau(
+                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)
+
+            adam1 = fluid.optimizer.Adam(
+                learning_rate=Exponential_scheduler,
+                parameter_list=linear.parameters())
+            adam2 = fluid.optimizer.Adam(
+                learning_rate=Step_scheduler,
+                parameter_list=linear.parameters())
+            adam3 = fluid.optimizer.Adam(
+                learning_rate=Reducelr_scheduler,
+                parameter_list=linear.parameters())
+            print(adam3.state_dict())
+
+            for epoch in range(10):
+                out = linear(input)
+                loss = fluid.layers.reduce_mean(out)
+                loss.backward()
+                adam1.minimize(loss)
+                adam2.minimize(loss)
+                adam3.minimize(loss)
+                linear.clear_gradients()
+                Step_scheduler.epoch()
+                Reducelr_scheduler.step(loss)
+
+            fluid.dygraph.save_dygraph(linear.state_dict(), "save_path")
+
+            Exponential_scheduler_test = fluid.dygraph.ExponentialDecay(
+                learning_rate=0.1,
+                decay_steps=10000,
+                decay_rate=0.5,
+                staircase=True)
+            Step_scheduler_test = fluid.dygraph.StepDecay(0.5, step_size=3)
+            Reducelr_scheduler_test = fluid.dygraph.ReduceLROnPlateau(
+                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)
+
+            fluid.dygraph.save_dygraph(adam1.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Exponential_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.step_num,
+                             adam1._learning_rate.step_num,
+                             "epoch_num is different before and after set_dict")
+
+            fluid.dygraph.save_dygraph(adam2.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Step_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.epoch_num,
+                             adam2._learning_rate.epoch_num,
+                             "epoch_num is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate(),
+                adam2._learning_rate(),
+                "current learning rate is different before and after set_dict")
+
+            fluid.dygraph.save_dygraph(adam3.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Reducelr_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.best_loss,
+                             adam3._learning_rate.best_loss.numpy()[0],
+                             "best_loss is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate.cooldown_counter,
+                adam3._learning_rate.cooldown_counter,
+                "cooldown_counter is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate.num_bad_epochs,
+                adam3._learning_rate.num_bad_epochs,
+                "num_bad_epochs is different before and after set_dict")
+            self.assertEqual(adam_test._learning_rate.epoch_num,
+                             adam3._learning_rate.epoch_num,
+                             "epoch is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate(),
+                adam3._learning_rate(),
+                "current learning rate is different before and after set_dict")
+
     def test_NoamDecay(self):
         with fluid.dygraph.guard():
             d_model = 0.01
@@ -169,17 +267,22 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
            learning_rate = 0.5
            milestones = [2, 4, 8]
            decay_rate = 0.2
+            linear = fluid.dygraph.Linear(10, 10)
            scheduler = fluid.dygraph.MultiStepDecay(learning_rate, milestones,
                                                     decay_rate)

+            adam = fluid.optimizer.AdamOptimizer(
+                learning_rate=scheduler, parameter_list=linear.parameters())
            for epoch in range(10):
                right_result = multi_step_decay(epoch, learning_rate,
                                                milestones, decay_rate)
-                fluid_result = scheduler().numpy()[0]
+                fluid_result = adam.current_step_lr()
                scheduler.epoch()
                self.assertAlmostEqual(
                    right_result,
                    fluid_result,
-                    msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
+                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))

            with self.assertRaises(ValueError):
@@ -190,6 +293,12 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
                lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
                                                  1)

+            with self.assertRaises(TypeError):
+                lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])
+
+            with self.assertRaises(ValueError):
+                lr = fluid.dygraph.MultiStepDecay(-1, [20, 30, 50])
+
    def test_StepDecay(self):
        with fluid.dygraph.guard():
            learning_rate = 0.5
@@ -205,21 +314,14 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
                self.assertAlmostEqual(
                    right_result,
                    fluid_result,
-                    msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
+                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))

            with self.assertRaises(TypeError):
-                lr = fluid.dygraph.MultiStepDecay(learning_rate, "test", 0.1)
-
-            with self.assertRaises(ValueError):
-                lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
-                                                  1)
-
-            with self.assertRaises(TypeError):
-                lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])
+                lr = fluid.dygraph.StepDecay(learning_rate, "test", 0.1)

            with self.assertRaises(ValueError):
-                lr = fluid.dygraph.MultiStepDecay(2.0, [20, 30, 50])
+                lr = fluid.dygraph.StepDecay(learning_rate, 20, 2)

    def test_LambdaDecay(self):
        with fluid.dygraph.guard():
......