Unverified commit 407de039, authored by Zhou Wei, committed by GitHub

[2.0API] Reconstruct all API related to LR Scheduler, unify dygraph and static (#26550)

* Reconstruct all API related to lr scheduler, unify dygraph and static

* Reconstruct all API related to lr scheduler, unify dygraph and static

* fix doc

* fix doc

* fix doc of lr_scheduler

* fix unittest and english doc

* fix english doc

* fix conflict

* fix doc
Parent: 6e823cfe
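Note (editor's addition, not part of the diff): the point of this PR is that one scheduler object now drives the learning rate in both dygraph and static graph, with the user calling scheduler.step() instead of feeding an LR value. Below is a minimal dygraph sketch; the StepLR constructor arguments and the exact interplay with fluid.optimizer.Adam are assumptions inferred from the exports and type checks in this diff, not the documented API.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        linear = fluid.dygraph.Linear(10, 10)
        # Assumed signature: multiply the LR by `gamma` every `step_size` calls to step().
        scheduler = paddle.optimizer.StepLR(learning_rate=0.1, step_size=3, gamma=0.5)
        adam = fluid.optimizer.Adam(learning_rate=scheduler,
                                    parameter_list=linear.parameters())
        for epoch in range(6):
            x = fluid.dygraph.to_variable(np.random.rand(4, 10).astype('float32'))
            loss = fluid.layers.mean(linear(x))
            loss.backward()
            adam.minimize(loss)
            linear.clear_gradients()
            scheduler.step()    # advance the schedule once per epoch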
@@ -850,6 +850,7 @@ class Executor(object):
     def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
                       return_numpy, return_merged):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         exe = program._executor
         # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram
         # instead of program. We will add support for checking Vars in Graph
@@ -893,6 +894,16 @@ class Executor(object):
                     res.append(res_dict)
                 exe.feed_tensors_into_local_scopes(res)
+        if hasattr(program._program, 'lr_sheduler'):
+            lr_sheduler = program._program.lr_sheduler
+            assert isinstance(lr_sheduler, _LRScheduler), "must be _LRScheduler"
+            lr_value = lr_sheduler()
+            lr_var = program._program.global_block().vars[lr_sheduler._var_name]
+            lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype)
+            exe.feed_and_split_tensor_into_local_scopes({
+                lr_sheduler._var_name: lr_tensor
+            })
         fetch_var_names = list(map(_to_name_str, fetch_list))
         tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
         return as_numpy(tensors) if return_numpy else tensors
@@ -1222,7 +1233,7 @@ class Executor(object):
     def _run_program(self, program, feed, fetch_list, feed_var_name,
                      fetch_var_name, scope, return_numpy, use_program_cache):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         if feed is None:
             feed = {}
         elif isinstance(feed, (list, tuple)):
@@ -1278,6 +1289,16 @@ class Executor(object):
                 fetch_var_name=fetch_var_name)
         self._feed_data(program, feed, feed_var_name, scope)
+        if hasattr(program, 'lr_sheduler'):
+            assert isinstance(program.lr_sheduler,
+                              _LRScheduler), "must be _LRScheduler"
+            lr_sheduler = program.lr_sheduler
+            lr_value = lr_sheduler()
+            lr_var = program.global_block().vars[lr_sheduler._var_name]
+            data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
+            tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
+            tensor.set(data, self.place)
         if not use_program_cache:
             self._default_executor.run(program.desc, scope, 0, True, True,
                                        fetch_var_name)
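Note (editor's addition, not part of the diff): the two executor hunks above do the same job for the parallel and single-program run paths: if the program carries an lr_sheduler attribute, the executor evaluates the scheduler and writes the result into the persistable learning_rate variable before running, so no manual feed is needed and multi-card local scopes all see the updated value. A rough standalone sketch of that per-run update; the helper name and the use of scope.var(...).get_tensor() are illustrative assumptions, not the executor's actual internals.

    import numpy as np

    def feed_current_lr(program, scope, place):
        # Hypothetical helper mirroring the hunks above.
        if not hasattr(program, 'lr_sheduler'):
            return
        scheduler = program.lr_sheduler                # an _LRScheduler instance
        lr_value = scheduler()                         # current scalar learning rate
        data = np.array([lr_value]).astype('float32')  # real code casts to the LR variable's dtype
        tensor = scope.var(scheduler._var_name).get_tensor()
        tensor.set(data, place)                        # overwrite the LR tensor in place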
......
@@ -4450,6 +4450,8 @@ class Program(object):
             p._current_role = self._current_role
             p.__op_role_var = self.__op_role_var
             p._appending_grad_times = self._appending_grad_times
+            if hasattr(self, 'lr_sheduler'):
+                p.lr_sheduler = self.lr_sheduler
             #NOTE(zhiqiu): we sync the cloned program, to update its program by
             # its desc.
......
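Note (editor's addition, not part of the diff): with the hunk above, Program.clone() carries the attached scheduler over to the clone, so an inference program cloned from the training program keeps referring to the same schedule. A small sketch, assuming this attribute copy applies to the usual clone(for_test=True) path:

    import paddle.fluid as fluid

    main_prog = fluid.default_main_program()
    # ... build the network and call optimizer.minimize(loss) under this program ...
    test_prog = main_prog.clone(for_test=True)
    if hasattr(main_prog, 'lr_sheduler'):
        # the clone references the very same scheduler object
        assert test_prog.lr_sheduler is main_prog.lr_sheduler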
@@ -68,14 +68,16 @@ class Optimizer(object):
                  regularization=None,
                  grad_clip=None,
                  name=None):
+        # Because of the loop import, so place it in the function body
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         self._parameter_list = list(
             parameter_list) if parameter_list is not None else None
         self._name = name
         if framework.in_dygraph_mode():
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, LearningRateDecay):
+            if not isinstance(learning_rate,
+                              (float, LearningRateDecay, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or LearningRateDecay, got %s here"
+                    "learning rate should be float or _LRScheduler, got %s here"
                     % type(learning_rate))
             if self._parameter_list is None:
                 raise AttributeError(
@@ -90,11 +92,11 @@ class Optimizer(object):
                             % regularization.__str__())
                         break
         else:
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, framework.Variable):
+            if not isinstance(learning_rate,
+                              (float, framework.Variable, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or Variable, got %s here" %
-                    type(learning_rate))
+                    "learning rate should be float or _LRScheduler, got %s here"
+                    % type(learning_rate))
         if grad_clip is not None:
             if not isinstance(grad_clip, GradientClipBase):
@@ -144,11 +146,15 @@ class Optimizer(object):
                 state_dict = adam.state_dict()
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         state_dict = {}
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
+        if isinstance(self._learning_rate, _LRScheduler):
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+            return state_dict
         if isinstance(self._learning_rate, LearningRateDecay):
             state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -192,6 +198,9 @@ class Optimizer(object):
             adam.set_dict(opti_state_dict)
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
         if isinstance(self._learning_rate, LearningRateDecay):
             self._learning_rate.set_dict(state_dict["LR_Scheduler"])
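Note (editor's addition, not part of the diff): with an _LRScheduler as the learning rate, the optimizer's state_dict() now stores the scheduler state under the "LR_Scheduler" key and set_dict() restores it, mirroring the existing LearningRateDecay handling. A hedged sketch, reusing the assumed StepLR construction from the first sketch:

    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        linear = fluid.dygraph.Linear(10, 10)
        scheduler = paddle.optimizer.StepLR(learning_rate=0.1, step_size=3, gamma=0.5)  # assumed signature
        adam = fluid.optimizer.Adam(learning_rate=scheduler,
                                    parameter_list=linear.parameters())

        state = adam.state_dict()
        assert "LR_Scheduler" in state    # scheduler state travels with the optimizer state
        adam.set_dict(state)              # and is restored through the same entry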
@@ -252,6 +261,30 @@ class Optimizer(object):
         return self._opti_name_list
     def _create_global_learning_rate(self):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            lr_var = self._global_learning_rate()
+            # only create global lr_var once
+            if not isinstance(lr_var, framework.Variable):
+                lr_name = unique_name.generate('learning_rate')
+                self._learning_rate._var_name = lr_name
+                lr_var = self.helper.create_global_variable(
+                    name=lr_name,
+                    shape=[1],
+                    persistable=True,
+                    stop_gradient=True,
+                    dtype='float32' if self._dtype is None else self._dtype)
+                main_prog = framework.default_main_program()
+                main_prog.lr_sheduler = self._learning_rate
+                main_prog.lr_var = lr_var
+                self._learning_rate_map[framework.default_main_program(
+                )] = lr_var
+            lr_value = float(self._learning_rate())
+            self.helper.set_variable_initializer(
+                lr_var, initializer=Constant(value=lr_value))
+            return
         if imperative_base.enabled():
             # create learning rate Variable
             if isinstance(self._learning_rate, float):
......
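Note (editor's addition, not part of the diff): in static mode the optimizer now creates a single persistable learning_rate variable, records the scheduler on the main program (main_prog.lr_sheduler), and the executor hunks above refresh that variable on every run; the user only calls scheduler.step(). An end-to-end sketch; the scheduler class and its arguments are assumed as in the earlier sketches, and paddle.enable_static() is included only in case dygraph is the default in this branch.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()  # assumption: may be unnecessary if static mode is already the default
    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name='x', shape=[None, 10], dtype='float32')
        loss = fluid.layers.mean(fluid.layers.fc(x, size=1))
        scheduler = paddle.optimizer.StepLR(learning_rate=0.1, step_size=3, gamma=0.5)  # assumed signature
        fluid.optimizer.Adam(learning_rate=scheduler).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    for epoch in range(6):
        # no learning-rate feed: the executor reads scheduler() and sets the LR tensor itself
        exe.run(main_prog,
                feed={'x': np.random.rand(4, 10).astype('float32')},
                fetch_list=[loss])
        scheduler.step()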
@@ -19,7 +19,10 @@ __all__ = [
     'ExponentialMovingAverage', 'Ftrl', 'FtrlOptimizer', 'LambOptimizer',
     'LarsMomentum', 'LarsMomentumOptimizer', 'LookaheadOptimizer',
     'ModelAverage', 'Momentum', 'MomentumOptimizer', 'PipelineOptimizer',
-    'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer'
+    'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer',
+    '_LRScheduler', 'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR',
+    'PolynomialLR', 'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR',
+    'LambdaLR', 'ReduceLROnPlateau', 'CosineAnnealingLR'
 ]
@@ -36,3 +39,7 @@ from .adam import Adam
 from .adamw import AdamW
 from .adamax import Adamax
 from .rmsprop import RMSProp
+from . import lr_scheduler
+from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \
+    LinearLrWarmup, ExponentialLR, MultiStepLR, StepLR, LambdaLR, ReduceLROnPlateau, CosineAnnealingLR
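Note (editor's addition, not part of the diff): the new scheduler classes are re-exported at the paddle.optimizer level, so they can be constructed straight from that namespace. The constructor arguments below are guesses patterned on the class names; the collapsed lr_scheduler.py holds the real signatures.

    import paddle

    # Assumed constructors; consult lr_scheduler.py for the actual ones.
    piecewise = paddle.optimizer.PiecewiseLR(boundaries=[100, 200],
                                             values=[0.1, 0.05, 0.01])
    multistep = paddle.optimizer.MultiStepLR(learning_rate=0.1,
                                             milestones=[30, 60], gamma=0.1)
    lambda_lr = paddle.optimizer.LambdaLR(learning_rate=0.1,
                                          lr_lambda=lambda epoch: 0.95 ** epoch)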
This diff has been collapsed.
@@ -21,6 +21,7 @@ __all__ = [
     'load', 'data', 'InputSpec'
 ]
+from . import nn
 from .input import data  #DEFINE_ALIAS
 from .input import InputSpec  #DEFINE_ALIAS
 from ..fluid.executor import Executor  #DEFINE_ALIAS
......