diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 103a996443e30ddd6691c3f07d47d4cbc7db03ed..06549901e17b167e214298d40d994e3c3cbec6d7 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -42,7 +42,6 @@ from .framework import disable_static # noqa: F401 from .framework import enable_static # noqa: F401 from .framework import in_dynamic_mode # noqa: F401 from .fluid.dataset import * # noqa: F401, F403 -from .fluid.lazy_init import LazyGuard # noqa: F401 from .framework.dtype import iinfo # noqa: F401 from .framework.dtype import finfo # noqa: F401 @@ -437,6 +436,7 @@ import paddle.text # noqa: F401 import paddle.vision # noqa: F401 from .tensor.random import check_shape # noqa: F401 +from .nn.initializer.lazy_init import LazyGuard # noqa: F401 # CINN has to set a flag to include a lib if is_compiled_with_cinn(): diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index f3be337fedb77315ddbff492216591817cc7b8ef..4889bf5f701f6b1a03d3535b16f24ee60ce116c5 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -113,7 +113,7 @@ class LocalSGDOptimizer(MetaOptimizerBase): p2s = self.create_snapshot_vars(main_block.program) with program_guard(main_block.program, startup_program): - step = paddle.fluid.layers.autoincreased_step_counter(begin=1) + step = paddle.optimizer.lr.autoincreased_step_counter(begin=1) k_steps = paddle.static.create_global_var( name="k_steps", shape=[1], @@ -330,7 +330,7 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): p2s = self.create_snapshot_vars(main_block.program) with program_guard(main_block.program, startup_program): - step = paddle.fluid.layers.autoincreased_step_counter(begin=1) + step = paddle.optimizer.lr.autoincreased_step_counter(begin=1) k_steps = paddle.static.create_global_var( name="k_steps", diff --git a/python/paddle/distributed/passes/ps_server_pass.py b/python/paddle/distributed/passes/ps_server_pass.py index 4e4377f328f3dd42132e85bba969f0ab361fff78..c68746366f48daeb9532aeb252b88db7f4f77128 100755 --- a/python/paddle/distributed/passes/ps_server_pass.py +++ b/python/paddle/distributed/passes/ps_server_pass.py @@ -15,17 +15,15 @@ import logging import paddle -from paddle.fluid.layers.learning_rate_scheduler import ( - exponential_decay, - inverse_time_decay, - noam_decay, -) from paddle.optimizer.lr import ( ExponentialDecay, InverseTimeDecay, LRScheduler, NaturalExpDecay, NoamDecay, + exponential_decay, + inverse_time_decay, + noam_decay, ) from ..ps.utils.public import ( diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 6eb88d8f8ef3dae50ad6ec25fcc77d55210b65a1..5eead87a995c94d9abc92efe9d946ffc31f2ae7b 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -21,7 +21,6 @@ from .framework import ( default_main_program, _current_expected_place, ) -from .lazy_init import lazy_init_helper from .framework import program_guard import numpy as np from .core import VarDesc diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index c5eb01ff763835d28762b4be97a0b3dcf027da47..9c6ce9aed0892393a480c1d5ab3d4e8481b213bc 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -12,17 +12,12 @@ # See the License for the specific language governing permissions and 
# limitations under the License. -from . import nn -from .nn import * from . import io from .io import * from . import math_op_patch from .math_op_patch import * -from .learning_rate_scheduler import * from ..layer_helper import LayerHelper __all__ = [] -__all__ += nn.__all__ __all__ += io.__all__ -__all__ += learning_rate_scheduler.__all__ diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py deleted file mode 100644 index 59f25c63b744a8b9bd865ea2b0ae0d28da59c1d4..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ /dev/null @@ -1,604 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -When training a model, it's often useful to decay the -learning rate during training process, this is called -learning_rate_decay. There are many strategies to do -this, this module will provide some classical method. -User can also implement their own learning_rate_decay -strategy according to this module. -""" - -import math -import numbers - -import paddle -from . import nn -from ..framework import ( - default_main_program, - Parameter, - unique_name, - name_scope, - in_dygraph_mode, -) -from ..framework import Variable -from ..dygraph import learning_rate_scheduler as imperate_lr -from ..data_feeder import check_variable_and_dtype, check_type - -__all__ = [ - 'exponential_decay', - 'natural_exp_decay', - 'inverse_time_decay', - 'polynomial_decay', - 'piecewise_decay', - 'noam_decay', - 'cosine_decay', - 'linear_lr_warmup', -] - - -def _decay_step_counter(begin=0): - # the first global step is zero in learning rate decay - global_step = nn.autoincreased_step_counter( - counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1 - ) - global_step = paddle.cast(global_step, 'float32') - return global_step - - -def noam_decay(d_model, warmup_steps, learning_rate=1.0): - """ - - Noam decay method. The numpy implementation of noam decay as follows. - - .. code-block:: python - - import paddle.fluid as fluid - import numpy as np - # set hyper parameters - base_lr = 0.01 - d_model = 2 - current_steps = 20 - warmup_steps = 200 - # compute - lr_value = base_lr * np.power(d_model, -0.5) * np.min([ - np.power(current_steps, -0.5), - np.power(warmup_steps, -1.5) * current_steps]) - - Please reference `attention is all you need - `_. - - Args: - d_model(Variable): The dimensionality of input and output of model. - - warmup_steps(Variable): A super parameter. - - learning_rate(Variable|float|int): The initial learning rate. If the type - is Variable, it's a tensor with shape [1], the data type can be - float32 or float64. It also can be set to python int number. Default 1.0 - - Returns: - The decayed learning rate. - Examples: - .. 
code-block:: python - - import paddle.fluid as fluid - warmup_steps = 100 - learning_rate = 0.01 - lr = fluid.layers.learning_rate_scheduler.noam_decay( - 1/(warmup_steps *(learning_rate ** 2)), - warmup_steps, - learning_rate) - """ - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.NoamDecay( - d_model, warmup_steps, learning_rate=learning_rate - ) - return decay - else: - global_step = _decay_step_counter(1) - - a = global_step**-0.5 - b = (warmup_steps**-1.5) * global_step - lr_value = learning_rate * (d_model**-0.5) * paddle.minimum(a, b) - - return lr_value - - -def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): - """ - - Applies exponential decay to the learning rate. - - When training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, the learning rate will be decayed by - 'decay_rate' every 'decay_steps' steps. - - Decayed learning rate calculates as follows: - - >>> if staircase == True: - >>> decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps) - >>> else: - >>> decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) - - Args: - learning_rate(Variable|float): The initial learning rate. It should be a Variable - or a float - decay_steps(int): The learning rate decay steps. See the decay computation above. - decay_rate(float): The learning rate decay rate. See the decay computation above. - staircase(bool): If True, decay the learning rate at discrete intervals, which - means the learning rate will be decayed by `decay_rate` every - `decay_steps`. If False, learning rate will be decayed continuously - and following the formula above. Default: False - - Returns: - Variable: The decayed learning rate. The data type is float32. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - - paddle.enable_static() - base_lr = 0.1 - sgd_optimizer = fluid.optimizer.SGD( - learning_rate=fluid.layers.exponential_decay( - learning_rate=base_lr, - decay_steps=10000, - decay_rate=0.5, - staircase=True)) - - """ - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.ExponentialDecay( - learning_rate, decay_rate - ) - return decay - else: - global_step = _decay_step_counter() - - div_res = global_step / decay_steps - if staircase: - div_res = paddle.floor(div_res) - decayed_lr = learning_rate * (decay_rate**div_res) - - return decayed_lr - - -def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): - """ - - Applies natural exponential decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, the learning rate will be decayed by - natural exponential power 'decay_rate' every 'decay_steps' steps. - - Decayed learning rate calculates as follows: - - >>> if not staircase: - >>> decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps)) - >>> else: - >>> decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps)) - - Args: - learning_rate(Variable|float): The initial learning rate. It should be a Variable - or a float - decay_steps(int): The learning rate decay steps. See the decay computation above. - decay_rate(float): The learning rate decay rate. See the decay computation above. 
- staircase(bool): If True, decay the learning rate at discrete intervals, which - means the learning rate will be decayed by natural exponential power - `decay_rate` every `decay_steps`. If False, learning rate will be - decayed continuously and following the formula above. Default: False - - Returns: - The decayed learning rate. The data type is float32. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - - paddle.enable_static() - base_lr = 0.1 - sgd_optimizer = fluid.optimizer.SGD( - learning_rate=fluid.layers.natural_exp_decay( - learning_rate=base_lr, - decay_steps=10000, - decay_rate=0.5, - staircase=True)) - - """ - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.NaturalExpDecay( - learning_rate, decay_rate - ) - return decay - else: - global_step = _decay_step_counter() - - div_res = global_step / decay_steps - if staircase: - div_res = paddle.floor(div_res) - decayed_lr = learning_rate * paddle.exp(-1 * decay_rate * div_res) - - return decayed_lr - - -def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): - """ - - Applies inverse time decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, an inverse decay function will be - applied to the initial learning rate. - - Decayed learning rate calculates as follows: - - >>> if staircase == True: - >>> decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) - >>> else: - >>> decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step) - - Args: - learning_rate(Variable|float): The initial learning rate. It should be a Variable - or a float - decay_steps(int): The learning rate decay steps. See the decay computation above. - decay_rate(float): The learning rate decay rate. See the decay computation above. - staircase(bool): If True, decay the learning rate at discrete intervals, which - means the learning rate will be decayed by `decay_rate` times - every `decay_steps`. If False, learning rate will be decayed - continuously and following the formula above. Default: False - - Returns: - Variable: The decayed learning rate. The data type is float32. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - base_lr = 0.1 - sgd_optimizer = fluid.optimizer.SGD( - learning_rate=fluid.layers.inverse_time_decay( - learning_rate=base_lr, - decay_steps=10000, - decay_rate=0.5, - staircase=True)) - """ - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.InverseTimeDecay( - learning_rate, decay_rate - ) - return decay - else: - global_step = _decay_step_counter() - - div_res = global_step / decay_steps - if staircase: - div_res = paddle.floor(div_res) - - decayed_lr = learning_rate / (1 + decay_rate * div_res) - - return decayed_lr - - -def polynomial_decay( - learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False -): - """ - Applies polynomial decay to the initial learning rate. - - .. 
code-block:: text - - if cycle: - decay_steps = decay_steps * ceil(global_step / decay_steps) - else: - global_step = min(global_step, decay_steps) - decayed_learning_rate = (learning_rate - end_learning_rate) * - (1 - global_step / decay_steps) ^ power + end_learning_rate - - Args: - learning_rate(Variable|float32): A scalar float32 value or a Variable. This - will be the initial learning rate during training. - decay_steps(int32): A Python `int32` number. - end_learning_rate(float): A Python `float` number. - power(float): A Python `float` number. - cycle(bool): If set true, decay the learning rate every decay_steps. - - Returns: - Variable: The decayed learning rate - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - start_lr = 0.01 - total_step = 5000 - end_lr = 0 - lr = fluid.layers.polynomial_decay( - start_lr, total_step, end_lr, power=1) - - """ - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.PolynomialDecay( - learning_rate, decay_steps, end_learning_rate, power, cycle - ) - return decay - else: - global_step = _decay_step_counter() - - if cycle: - div_res = paddle.ceil(global_step / decay_steps) - zero_var = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=0.0 - ) - one_var = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=1.0 - ) - - div_val = paddle.static.nn.cond( - global_step == zero_var, lambda: one_var, lambda: div_res - ) - paddle.assign(div_val, output=div_res) - - decay_steps = decay_steps * div_res - else: - decay_steps_var = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=float(decay_steps) - ) - global_step = paddle.minimum(x=global_step, y=decay_steps_var) - - decayed_lr = (learning_rate - end_learning_rate) * ( - (1 - global_step / decay_steps) ** power - ) + end_learning_rate - return decayed_lr - - -def piecewise_decay(boundaries, values): - """ - - Applies piecewise decay to the initial learning rate. - - The algorithm can be described as the code below. - - .. code-block:: text - - boundaries = [10000, 20000] - values = [1.0, 0.5, 0.1] - if step < 10000: - learning_rate = 1.0 - elif 10000 <= step < 20000: - learning_rate = 0.5 - else: - learning_rate = 0.1 - Args: - boundaries: A list of steps numbers. - values: A list of learning rate values that will be picked during - different step boundaries. - - Returns: - The decayed learning rate. - - Examples: - .. 
code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - boundaries = [10000, 20000] - values = [1.0, 0.5, 0.1] - optimizer = paddle.optimizer.Momentum( - momentum=0.9, - learning_rate=paddle.optimizer.lr.PiecewiseDecay(boundaries, values), - weight_decay=paddle.regularizer.L2Decay(1e-4)) - - - """ - with default_main_program()._lr_schedule_guard(): - if len(values) - len(boundaries) != 1: - raise ValueError("len(values) - len(boundaries) should be 1") - - if in_dygraph_mode(): - decay = paddle.optimizer.lr.PiecewiseDecay(boundaries, values) - return decay - else: - global_step = _decay_step_counter() - - lr = paddle.static.create_global_var( - shape=[1], - value=0.0, - dtype='float32', - persistable=True, - name="learning_rate", - ) - with paddle.static.nn.control_flow.Switch() as switch: - for i in range(len(boundaries)): - boundary_val = paddle.tensor.fill_constant( - shape=[1], - dtype='float32', - value=float(boundaries[i]), - force_cpu=True, - ) - with switch.case(global_step < boundary_val): - paddle.tensor.fill_constant( - shape=[1], - dtype="float32", - value=float(values[i]), - out=lr, - ) - with switch.default(): - paddle.tensor.fill_constant( - shape=[1], - dtype="float32", - value=float(values[len(values) - 1]), - out=lr, - ) - return lr - - -def cosine_decay(learning_rate, step_each_epoch, epochs): - r""" - - Applies cosine decay to the learning rate. - - when training a model, it is often recommended to lower the learning rate as the - training progresses. By using this function, the learning rate will be decayed by - following cosine decay strategy. - - .. math:: - - decayed\_lr = learning\_rate * 0.5 * (math.cos * (epoch * \\frac{math.pi}{epochs} ) + 1) - - Args: - learning_rate(Variable|float): The initial learning rate. - step_each_epoch(int): the number of steps in an epoch. - epochs(int): the number of epochs. - - Returns: - Variable: The decayed learning rate. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - base_lr = 0.1 - lr = fluid.layers.cosine_decay( - learning_rate = base_lr, step_each_epoch=10000, epochs=120) - """ - check_type( - learning_rate, 'learning_rate', (float, Variable), 'cosine_decay' - ) - - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - decay = paddle.optimizer.lr.CosineAnnealingDecay( - learning_rate, epochs - ) - return decay - else: - global_step = _decay_step_counter() - - cur_epoch = paddle.floor(global_step / step_each_epoch) - decayed_lr = ( - learning_rate - * 0.5 - * (paddle.cos(cur_epoch * math.pi / epochs) + 1) - ) - return decayed_lr - - -def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): - """ - - This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling. - For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ - - When global_step < warmup_steps, learning rate is updated as: - - .. code-block:: text - - linear_step = end_lr - start_lr - lr = start_lr + linear_step * (global_step / warmup_steps) - - where start_lr is the initial learning rate, and end_lr is the final learning rate; - - When global_step >= warmup_steps, learning rate is updated as: - - .. code-block:: text - - lr = learning_rate - - where lr is the learning_rate after warm-up. - - Args: - learning_rate (Variable|float): Learning_rate after warm-up, it could be 1D-Tensor or single value with the data type of float32. 
- warmup_steps (int): Steps for warm up. - start_lr (float): Initial learning rate of warm up. - end_lr (float): Final learning rate of warm up. - - Returns: - Variable: Warm-up learning rate with the same data type as learning_rate. - - - Examples: - - .. code-block:: python - - import paddle.fluid as fluid - - boundaries = [100, 200] - lr_steps = [0.1, 0.01, 0.001] - learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor - #learning_rate = 0.1 #case2, single-value - warmup_steps = 50 - start_lr = 1. / 3. - end_lr = 0.1 - decayed_lr = fluid.layers.linear_lr_warmup(learning_rate, - warmup_steps, start_lr, end_lr) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - out, = exe.run(fetch_list=[decayed_lr.name]) - print(out) - # case1: [0.33333334] - # case2: [0.33333334] - """ - dtype = 'float32' - if isinstance(learning_rate, Variable): - dtype = learning_rate.dtype - - linear_step = float(end_lr) - float(start_lr) - with default_main_program()._lr_schedule_guard(): - if in_dygraph_mode(): - lr = paddle.optimizer.lr.LinearWarmup( - learning_rate, warmup_steps, start_lr, end_lr - ) - return lr - else: - lr = paddle.static.create_global_var( - shape=[1], - value=0.0, - dtype=dtype, - persistable=True, - name="learning_rate_warmup", - ) - - global_step = _decay_step_counter() - if not isinstance(learning_rate, Variable): - learning_rate = paddle.tensor.fill_constant( - shape=[1], dtype=dtype, value=float(learning_rate) - ) - lr_val = paddle.static.nn.case( - pred_fn_pairs=[ - ( - global_step < warmup_steps, - lambda: start_lr - + linear_step * (global_step / float(warmup_steps)), - ) - ], - default=lambda: learning_rate, - ) - paddle.assign(lr_val, lr) - return lr diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py deleted file mode 100644 index a4a770a97829a06c9ff5957235d5dea2e1c6134d..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/layers/nn.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -All layers just related to the neural network. -""" -import os -import inspect -import warnings - -import numpy as np - -import paddle -from ..layer_helper import LayerHelper -from ..framework import ( - Variable, - OpProtoHolder, - dygraph_only, - _dygraph_tracer, - default_main_program, - _create_tensor, - static_only, - _global_flags, - in_dygraph_mode, -) -from ..framework import _current_expected_place -from .. import dygraph_utils -from ..param_attr import ParamAttr -from .layer_function_generator import ( - autodoc, - templatedoc, - _generate_doc_string_, -) - -from .. import unique_name -from .. 
import core -from ...utils import deprecated -from ..data_feeder import ( - convert_dtype, - check_variable_and_dtype, - check_type, - check_dtype, -) -from paddle.utils import deprecated -from paddle import _C_ops, _legacy_C_ops -from collections.abc import Iterable - - -__all__ = [ - 'autoincreased_step_counter', -] - - -def autoincreased_step_counter(counter_name=None, begin=1, step=1): - """ - :api_attr: Static Graph - - Create an auto-increase variable. which will be automatically increased - by 1 in every iteration. By default, the first return of this counter is 1, - and the step size is 1. - - Args: - counter_name(str, optional): The counter name. Default '@STEP_COUNTER@'. - begin(int, optional): The first return value of this counter. Default 1. - step(int, optional): The step size. Default 1. - - Returns: - Variable: The auto-increased Variable with data type int64. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - global_step = fluid.layers.autoincreased_step_counter( - counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) - """ - helper = LayerHelper('global_step_counter') - if counter_name is None: - counter_name = '@STEP_COUNTER@' - counter, is_new_var = helper.create_or_get_global_variable( - name=counter_name, - dtype='int64', - shape=[1], - persistable=True, - belong_to_optimizer=True, - ) - if is_new_var: - helper.set_variable_initializer( - counter, - initializer=paddle.nn.initializer.ConstantInitializer( - value=begin - 1, force_cpu=True - ), - ) - helper.main_program.global_block()._prepend_op( - type='increment', - inputs={'X': [counter]}, - outputs={'Out': [counter]}, - attrs={'step': float(step)}, - ) - counter.stop_gradient = True - - return counter diff --git a/python/paddle/incubate/distributed/fleet/parameter_server/ir/public.py b/python/paddle/incubate/distributed/fleet/parameter_server/ir/public.py index 8fc55869f54f3f7b8a41e983e89a14ca7c4ec1c2..75d65dc079e096b4b4946c0c530b961d61084eb8 100755 --- a/python/paddle/incubate/distributed/fleet/parameter_server/ir/public.py +++ b/python/paddle/incubate/distributed/fleet/parameter_server/ir/public.py @@ -1409,8 +1409,6 @@ def _get_lr_scheduler_program(lr_scheduler, lr_param_dict, lr_decay_steps): InverseTimeDecay, NaturalExpDecay, NoamDecay, - ) - from paddle.static.learning_rate_scheduler import ( exponential_decay, inverse_time_decay, natural_exp_decay, diff --git a/python/paddle/nn/initializer/initializer.py b/python/paddle/nn/initializer/initializer.py index 7d04e8d7cbc71d01b5516392d984fadb016986d4..9d5880aa0956133f06cea5fc9bec35b7035c1ad4 100644 --- a/python/paddle/nn/initializer/initializer.py +++ b/python/paddle/nn/initializer/initializer.py @@ -18,7 +18,7 @@ import math import numpy as np from ...fluid.framework import default_main_program, in_dygraph_mode -from ...fluid.lazy_init import lazy_init_helper +from .lazy_init import lazy_init_helper __all__ = [] @@ -42,7 +42,7 @@ class Initializer: return self._lazy_init(param, block) def forward(self, param, block=None): - """Add corresponding initialization operations to the network""" + """Add corresponding initialization operations to the network.""" raise NotImplementedError() def _lazy_init(self, param, block=None): diff --git a/python/paddle/fluid/lazy_init.py b/python/paddle/nn/initializer/lazy_init.py similarity index 99% rename from python/paddle/fluid/lazy_init.py rename to python/paddle/nn/initializer/lazy_init.py index 
36f36161e6f27d60fb203f3ab1b8c6887c5d50e0..e2321f682f77ec1faa8b7f7556c1160879408f58 100644 --- a/python/paddle/fluid/lazy_init.py +++ b/python/paddle/nn/initializer/lazy_init.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import framework +from ...fluid import framework __all__ = ["LazyGuard"] diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index e628509e52afc1b94ce3171b22eeb15a10581501..0e7b8fe7353396f726b547dc147ffc9c1e073fe0 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -17,8 +17,16 @@ import warnings import numpy +import paddle from paddle import Tensor from paddle.fluid import core +from paddle.fluid.data_feeder import check_type +from paddle.fluid.framework import ( + Variable, + default_main_program, + in_dygraph_mode, +) +from paddle.fluid.layer_helper import LayerHelper __all__ = [ # noqa 'LRScheduler', @@ -2227,3 +2235,599 @@ class CyclicLR(LRScheduler): lr = self.base_lr + base_height * self.scale_fn(eval(self.scale_mode)) return lr + + +def autoincreased_step_counter(counter_name=None, begin=1, step=1): + """ + :api_attr: Static Graph + + Create an auto-increase variable. which will be automatically increased + by 1 in every iteration. By default, the first return of this counter is 1, + and the step size is 1. + + Args: + counter_name(str, optional): The counter name. Default '@STEP_COUNTER@'. + begin(int, optional): The first return value of this counter. Default 1. + step(int, optional): The step size. Default 1. + + Returns: + Variable: The auto-increased Variable with data type int64. + + Examples: + .. code-block:: python + + import paddle + paddle.enable_static() + global_step = paddle.optimizer.lr.autoincreased_step_counter( + counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) + """ + helper = LayerHelper('global_step_counter') + if counter_name is None: + counter_name = '@STEP_COUNTER@' + counter, is_new_var = helper.create_or_get_global_variable( + name=counter_name, + dtype='int64', + shape=[1], + persistable=True, + belong_to_optimizer=True, + ) + if is_new_var: + helper.set_variable_initializer( + counter, + initializer=paddle.nn.initializer.ConstantInitializer( + value=begin - 1, force_cpu=True + ), + ) + helper.main_program.global_block()._prepend_op( + type='increment', + inputs={'X': [counter]}, + outputs={'Out': [counter]}, + attrs={'step': float(step)}, + ) + counter.stop_gradient = True + + return counter + + +def _decay_step_counter(begin=0): + # the first global step is zero in learning rate decay + global_step = autoincreased_step_counter( + counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1 + ) + global_step = paddle.cast(global_step, 'float32') + return global_step + + +def noam_decay(d_model, warmup_steps, learning_rate=1.0): + """ + + Noam decay method. The numpy implementation of noam decay as follows. + + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + # set hyper parameters + base_lr = 0.01 + d_model = 2 + current_steps = 20 + warmup_steps = 200 + # compute + lr_value = base_lr * np.power(d_model, -0.5) * np.min([ + np.power(current_steps, -0.5), + np.power(warmup_steps, -1.5) * current_steps]) + + Please reference `attention is all you need + `_. + + Args: + d_model(Variable): The dimensionality of input and output of model. + + warmup_steps(Variable): A super parameter. + + learning_rate(Variable|float|int): The initial learning rate. 
If the type + is Variable, it's a tensor with shape [1], the data type can be + float32 or float64. It also can be set to python int number. Default 1.0 + + Returns: + The decayed learning rate. + Examples: + .. code-block:: python + + import paddle + warmup_steps = 100 + learning_rate = 0.01 + lr = paddle.optimizer.lr.noam_decay( + 1/(warmup_steps *(learning_rate ** 2)), + warmup_steps, + learning_rate) + """ + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = paddle.optimizer.lr.NoamDecay( + d_model, warmup_steps, learning_rate=learning_rate + ) + return decay + else: + global_step = _decay_step_counter(1) + + a = global_step**-0.5 + b = (warmup_steps**-1.5) * global_step + lr_value = learning_rate * (d_model**-0.5) * paddle.minimum(a, b) + + return lr_value + + +def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): + """ + + Applies exponential decay to the learning rate. + + When training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, the learning rate will be decayed by + 'decay_rate' every 'decay_steps' steps. + + Decayed learning rate calculates as follows: + + >>> if staircase == True: + >>> decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps) + >>> else: + >>> decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) + + Args: + learning_rate(Variable|float): The initial learning rate. It should be a Variable + or a float + decay_steps(int): The learning rate decay steps. See the decay computation above. + decay_rate(float): The learning rate decay rate. See the decay computation above. + staircase(bool): If True, decay the learning rate at discrete intervals, which + means the learning rate will be decayed by `decay_rate` every + `decay_steps`. If False, learning rate will be decayed continuously + and following the formula above. Default: False + + Returns: + Variable: The decayed learning rate. The data type is float32. + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + base_lr = 0.1 + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=paddle.optimizer.lr.exponential_decay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True)) + + """ + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = ExponentialDecay(learning_rate, decay_rate) + return decay + else: + global_step = _decay_step_counter() + + div_res = global_step / decay_steps + if staircase: + div_res = paddle.floor(div_res) + decayed_lr = learning_rate * (decay_rate**div_res) + + return decayed_lr + + +def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): + """ + + Applies natural exponential decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, the learning rate will be decayed by + natural exponential power 'decay_rate' every 'decay_steps' steps. + + Decayed learning rate calculates as follows: + + >>> if not staircase: + >>> decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps)) + >>> else: + >>> decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps)) + + Args: + learning_rate(Variable|float): The initial learning rate. It should be a Variable + or a float + decay_steps(int): The learning rate decay steps. See the decay computation above. 
+ decay_rate(float): The learning rate decay rate. See the decay computation above. + staircase(bool): If True, decay the learning rate at discrete intervals, which + means the learning rate will be decayed by natural exponential power + `decay_rate` every `decay_steps`. If False, learning rate will be + decayed continuously and following the formula above. Default: False + + Returns: + The decayed learning rate. The data type is float32. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle + + paddle.enable_static() + base_lr = 0.1 + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=paddle.optimizer.lr.natural_exp_decay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True)) + + """ + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = NaturalExpDecay(learning_rate, decay_rate) + return decay + else: + global_step = _decay_step_counter() + + div_res = global_step / decay_steps + if staircase: + div_res = paddle.floor(div_res) + decayed_lr = learning_rate * paddle.exp(-1 * decay_rate * div_res) + + return decayed_lr + + +def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): + """ + + Applies inverse time decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, an inverse decay function will be + applied to the initial learning rate. + + Decayed learning rate calculates as follows: + + >>> if staircase == True: + >>> decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) + >>> else: + >>> decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step) + + Args: + learning_rate(Variable|float): The initial learning rate. It should be a Variable + or a float + decay_steps(int): The learning rate decay steps. See the decay computation above. + decay_rate(float): The learning rate decay rate. See the decay computation above. + staircase(bool): If True, decay the learning rate at discrete intervals, which + means the learning rate will be decayed by `decay_rate` times + every `decay_steps`. If False, learning rate will be decayed + continuously and following the formula above. Default: False + + Returns: + Variable: The decayed learning rate. The data type is float32. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle + paddle.enable_static() + base_lr = 0.1 + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=paddle.optimizer.lr.inverse_time_decay( + learning_rate=base_lr, + decay_steps=10000, + decay_rate=0.5, + staircase=True)) + """ + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = InverseTimeDecay(learning_rate, decay_rate) + return decay + else: + global_step = _decay_step_counter() + + div_res = global_step / decay_steps + if staircase: + div_res = paddle.floor(div_res) + + decayed_lr = learning_rate / (1 + decay_rate * div_res) + + return decayed_lr + + +def polynomial_decay( + learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False +): + """ + Applies polynomial decay to the initial learning rate. + + .. 
code-block:: text + + if cycle: + decay_steps = decay_steps * ceil(global_step / decay_steps) + else: + global_step = min(global_step, decay_steps) + decayed_learning_rate = (learning_rate - end_learning_rate) * + (1 - global_step / decay_steps) ^ power + end_learning_rate + + Args: + learning_rate(Variable|float32): A scalar float32 value or a Variable. This + will be the initial learning rate during training. + decay_steps(int32): A Python `int32` number. + end_learning_rate(float): A Python `float` number. + power(float): A Python `float` number. + cycle(bool): If set true, decay the learning rate every decay_steps. + + Returns: + Variable: The decayed learning rate + + Examples: + .. code-block:: python + + import paddle + start_lr = 0.01 + total_step = 5000 + end_lr = 0 + lr = paddle.optimizer.lr.polynomial_decay( + start_lr, total_step, end_lr, power=1) + + """ + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = PolynomialDecay( + learning_rate, decay_steps, end_learning_rate, power, cycle + ) + return decay + else: + global_step = _decay_step_counter() + + if cycle: + div_res = paddle.ceil(global_step / decay_steps) + zero_var = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=0.0 + ) + one_var = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=1.0 + ) + + div_val = paddle.static.nn.cond( + global_step == zero_var, lambda: one_var, lambda: div_res + ) + paddle.assign(div_val, output=div_res) + + decay_steps = decay_steps * div_res + else: + decay_steps_var = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=float(decay_steps) + ) + global_step = paddle.minimum(x=global_step, y=decay_steps_var) + + decayed_lr = (learning_rate - end_learning_rate) * ( + (1 - global_step / decay_steps) ** power + ) + end_learning_rate + return decayed_lr + + +def piecewise_decay(boundaries, values): + """ + + Applies piecewise decay to the initial learning rate. + + The algorithm can be described as the code below. + + .. code-block:: text + + boundaries = [10000, 20000] + values = [1.0, 0.5, 0.1] + if step < 10000: + learning_rate = 1.0 + elif 10000 <= step < 20000: + learning_rate = 0.5 + else: + learning_rate = 0.1 + Args: + boundaries: A list of steps numbers. + values: A list of learning rate values that will be picked during + different step boundaries. + + Returns: + The decayed learning rate. + + Examples: + .. 
code-block:: python + + import paddle + paddle.enable_static() + boundaries = [10000, 20000] + values = [1.0, 0.5, 0.1] + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=paddle.optimizer.lr.PiecewiseDecay(boundaries, values), + weight_decay=paddle.regularizer.L2Decay(1e-4)) + + + """ + with default_main_program()._lr_schedule_guard(): + if len(values) - len(boundaries) != 1: + raise ValueError("len(values) - len(boundaries) should be 1") + + if in_dygraph_mode(): + decay = PiecewiseDecay(boundaries, values) + return decay + else: + global_step = _decay_step_counter() + + lr = paddle.static.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate", + ) + with paddle.static.nn.control_flow.Switch() as switch: + for i in range(len(boundaries)): + boundary_val = paddle.tensor.fill_constant( + shape=[1], + dtype='float32', + value=float(boundaries[i]), + force_cpu=True, + ) + with switch.case(global_step < boundary_val): + paddle.tensor.fill_constant( + shape=[1], + dtype="float32", + value=float(values[i]), + out=lr, + ) + with switch.default(): + paddle.tensor.fill_constant( + shape=[1], + dtype="float32", + value=float(values[len(values) - 1]), + out=lr, + ) + return lr + + +def cosine_decay(learning_rate, step_each_epoch, epochs): + r""" + + Applies cosine decay to the learning rate. + + when training a model, it is often recommended to lower the learning rate as the + training progresses. By using this function, the learning rate will be decayed by + following cosine decay strategy. + + .. math:: + + decayed\_lr = learning\_rate * 0.5 * (math.cos * (epoch * \\frac{math.pi}{epochs} ) + 1) + + Args: + learning_rate(Variable|float): The initial learning rate. + step_each_epoch(int): the number of steps in an epoch. + epochs(int): the number of epochs. + + Returns: + Variable: The decayed learning rate. + + Examples: + .. code-block:: python + + import paddle + base_lr = 0.1 + lr = paddle.optimizer.lr.cosine_decay( + learning_rate = base_lr, step_each_epoch=10000, epochs=120) + """ + check_type( + learning_rate, 'learning_rate', (float, Variable), 'cosine_decay' + ) + + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + decay = CosineAnnealingDecay(learning_rate, epochs) + return decay + else: + global_step = _decay_step_counter() + + cur_epoch = paddle.floor(global_step / step_each_epoch) + decayed_lr = ( + learning_rate + * 0.5 + * (paddle.cos(cur_epoch * math.pi / epochs) + 1) + ) + return decayed_lr + + +def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): + """ + + This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling. + For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ + + When global_step < warmup_steps, learning rate is updated as: + + .. code-block:: text + + linear_step = end_lr - start_lr + lr = start_lr + linear_step * (global_step / warmup_steps) + + where start_lr is the initial learning rate, and end_lr is the final learning rate; + + When global_step >= warmup_steps, learning rate is updated as: + + .. code-block:: text + + lr = learning_rate + + where lr is the learning_rate after warm-up. + + Args: + learning_rate (Variable|float): Learning_rate after warm-up, it could be 1D-Tensor or single value with the data type of float32. + warmup_steps (int): Steps for warm up. 
+ start_lr (float): Initial learning rate of warm up. + end_lr (float): Final learning rate of warm up. + + Returns: + Variable: Warm-up learning rate with the same data type as learning_rate. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + + boundaries = [100, 200] + lr_steps = [0.1, 0.01, 0.001] + learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor + #learning_rate = 0.1 #case2, single-value + warmup_steps = 50 + start_lr = 1. / 3. + end_lr = 0.1 + decayed_lr = fluid.layers.linear_lr_warmup(learning_rate, + warmup_steps, start_lr, end_lr) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out, = exe.run(fetch_list=[decayed_lr.name]) + print(out) + # case1: [0.33333334] + # case2: [0.33333334] + """ + dtype = 'float32' + if isinstance(learning_rate, Variable): + dtype = learning_rate.dtype + + linear_step = float(end_lr) - float(start_lr) + with default_main_program()._lr_schedule_guard(): + if in_dygraph_mode(): + lr = LinearWarmup(learning_rate, warmup_steps, start_lr, end_lr) + return lr + else: + lr = paddle.static.create_global_var( + shape=[1], + value=0.0, + dtype=dtype, + persistable=True, + name="learning_rate_warmup", + ) + + global_step = _decay_step_counter() + if not isinstance(learning_rate, Variable): + learning_rate = paddle.tensor.fill_constant( + shape=[1], dtype=dtype, value=float(learning_rate) + ) + lr_val = paddle.static.nn.case( + pred_fn_pairs=[ + ( + global_step < warmup_steps, + lambda: start_lr + + linear_step * (global_step / float(warmup_steps)), + ) + ], + default=lambda: learning_rate, + ) + paddle.assign(lr_val, lr) + return lr diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 1a55745a81b5254fe57644e25ae653174dd60642..d8247cf6561bb8f9b1a4c22dd804b941fd364c5d 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -72,8 +72,6 @@ from .nn.control_flow import Print # noqa: F401 from ..fluid.param_attr import WeightNormParamAttr # noqa: F401 from ..fluid.optimizer import Optimizer # noqa: F401 -from ..fluid.layers import exponential_decay # noqa: F401 -from ..fluid.layers import learning_rate_scheduler # noqa: F401 from .nn.metric import auc # noqa: F401 from .nn.metric import accuracy # noqa: F401 @@ -135,5 +133,4 @@ __all__ = [ # noqa 'create_parameter', 'set_ipu_shard', 'ctr_metric_bundle', - 'exponential_decay', ] diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index f04dc277e4e0da5a8f328d2a84df97c02bfd1731..41ec17f0567a67af0b9b6cd595e2c37bfbbc55b7 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -24,7 +24,7 @@ from paddle.common_ops_import import ( check_type, check_variable_and_dtype, ) -from paddle.fluid import core, layers, unique_name +from paddle.fluid import core, unique_name from paddle.fluid.data_feeder import check_dtype from paddle.fluid.framework import ( Program, @@ -4210,7 +4210,7 @@ class ExponentialMovingAverage: Update Exponential Moving Average. Should only call this method in train program. 
""" - global_step = layers.autoincreased_step_counter( + global_step = paddle.optimizer.lr.autoincreased_step_counter( counter_name=self._step_counter_name ) param_master_emas = [] diff --git a/test/collective/fleet/parallel_dygraph_se_resnext.py b/test/collective/fleet/parallel_dygraph_se_resnext.py index 05e9088c9c98042ada3d88b3e90f192190849c36..c24e4e7ebef3d3e0c3f4b89eccbc2f862d56395d 100644 --- a/test/collective/fleet/parallel_dygraph_se_resnext.py +++ b/test/collective/fleet/parallel_dygraph_se_resnext.py @@ -67,7 +67,7 @@ def optimizer_setting(params, parameter_list=None): ) else: optimizer = paddle.optimizer.Momentum( - learning_rate=fluid.layers.cosine_decay( + learning_rate=paddle.optimizer.lr.cosine_decay( learning_rate=lr, step_each_epoch=step, epochs=num_epochs ), momentum=momentum_rate, diff --git a/test/legacy_test/dist_se_resnext.py b/test/legacy_test/dist_se_resnext.py index ddc79809e80a027de60f59a2cb55aa6cc80b7fdf..f7b31d315722f91928f6f617648f345a5bf607fe 100644 --- a/test/legacy_test/dist_se_resnext.py +++ b/test/legacy_test/dist_se_resnext.py @@ -248,7 +248,7 @@ class DistSeResneXt2x2(TestDistRunnerBase): else: optimizer = ( paddle.distributed.fleet.meta_optimizers.DGCMomentumOptimizer( - learning_rate=fluid.layers.piecewise_decay( + learning_rate=paddle.optimizer.lr.piecewise_decay( boundaries=bd, values=lr ), momentum=0.9, diff --git a/test/legacy_test/test_dist_transpiler.py b/test/legacy_test/test_dist_transpiler.py index 8f5565ee7b73db2666dc5e1c7d4477f9aabff4e9..b3a2f95aef78cac0286ee70372079b0925a9fd4d 100644 --- a/test/legacy_test/test_dist_transpiler.py +++ b/test/legacy_test/test_dist_transpiler.py @@ -477,7 +477,7 @@ class TestLRDecayConditional(TranspilerTest): cost = paddle.nn.functional.square_error_cost(input=y_predict, label=y) avg_cost = paddle.mean(cost) sgd_optimizer = paddle.optimizer.SGD( - learning_rate=fluid.layers.piecewise_decay( + learning_rate=paddle.optimizer.lr.piecewise_decay( [10000, 20000], [1.0, 0.5, 1.0] ) ) @@ -581,7 +581,7 @@ class TestL2DecayWithPiecewise(TranspilerTest): bd = [1, 10, 20, 30] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] sgd_optimizer = paddle.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( + learning_rate=paddle.optimizer.lr.piecewise_decay( boundaries=bd, values=lr ), momentum=0.9, diff --git a/test/legacy_test/test_imperative_ocr_attention_model.py b/test/legacy_test/test_imperative_ocr_attention_model.py index 30c00600ae7723153b77b477324b4a2afb070bcd..8b07c7652fad909773031b29dad0399f34e2e692 100644 --- a/test/legacy_test/test_imperative_ocr_attention_model.py +++ b/test/legacy_test/test_imperative_ocr_attention_model.py @@ -451,7 +451,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ocr_attention = OCRAttention() if Config.learning_rate_decay == "piecewise_decay": - learning_rate = fluid.layers.piecewise_decay( + learning_rate = paddle.optimizer.lr.piecewise_decay( [50000], [Config.LR, Config.LR * 0.01] ) else: @@ -527,7 +527,7 @@ class TestDygraphOCRAttention(unittest.TestCase): ocr_attention = OCRAttention() if Config.learning_rate_decay == "piecewise_decay": - learning_rate = fluid.layers.piecewise_decay( + learning_rate = paddle.optimizer.lr.piecewise_decay( [50000], [Config.LR, Config.LR * 0.01] ) else: diff --git a/test/legacy_test/test_imperative_resnet.py b/test/legacy_test/test_imperative_resnet.py index 4bb78f64d31258e0396d94c2a9328c6c139591d6..41e270c67958a54268630355646ae4a391a20f47 100644 --- a/test/legacy_test/test_imperative_resnet.py +++ 
b/test/legacy_test/test_imperative_resnet.py @@ -67,7 +67,7 @@ def optimizer_setting(params, parameter_list=None): # TODO(minqiyang): Add learning rate scheduler support to dygraph mode # optimizer = fluid.optimizer.Momentum( # learning_rate=params["lr"], - # learning_rate=fluid.layers.piecewise_decay( + # learning_rate=paddle.optimizer.lr.piecewise_decay( # boundaries=bd, values=lr), # momentum=0.9, # regularization=paddle.regularizer.L2Decay(1e-4)) diff --git a/test/legacy_test/test_imperative_resnet_sorted_gradient.py b/test/legacy_test/test_imperative_resnet_sorted_gradient.py index 98bdd0c8ccb075692380f100428a91951e6efb5e..ba71a803fb650560248aec7214d901030499a22e 100644 --- a/test/legacy_test/test_imperative_resnet_sorted_gradient.py +++ b/test/legacy_test/test_imperative_resnet_sorted_gradient.py @@ -63,7 +63,7 @@ def optimizer_setting(params, parameter_list=None): # TODO(minqiyang): Add learning rate scheduler support to dygraph mode # optimizer = fluid.optimizer.Momentum( # learning_rate=params["lr"], - # learning_rate=fluid.layers.piecewise_decay( + # learning_rate=paddle.optimizer.lr.piecewise_decay( # boundaries=bd, values=lr), # momentum=0.9, # regularization=paddle.regularizer.L2Decay(1e-4)) diff --git a/test/legacy_test/test_imperative_transformer_sorted_gradient.py b/test/legacy_test/test_imperative_transformer_sorted_gradient.py index 2d724c080cb7ebbd841d2ce700b7234227d4a358..f85986fab3584c84e96ab5503e3fd107d3e7082b 100644 --- a/test/legacy_test/test_imperative_transformer_sorted_gradient.py +++ b/test/legacy_test/test_imperative_transformer_sorted_gradient.py @@ -1137,7 +1137,7 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): is_sparse=is_sparse, ) if sync: - lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( + lr_decay = paddle.optimizer.lr.noam_decay( ModelHyperParams.d_model, TrainTaskConfig.warmup_steps ) with fluid.default_main_program()._lr_schedule_guard(): diff --git a/test/legacy_test/test_learning_rate_scheduler.py b/test/legacy_test/test_learning_rate_scheduler.py index 0ff27eec5ad2067e601095cc93ce6fdc62dccff1..8898fb59b87b155f070452bd5d328297e4d4ae50 100644 --- a/test/legacy_test/test_learning_rate_scheduler.py +++ b/test/legacy_test/test_learning_rate_scheduler.py @@ -20,7 +20,7 @@ import numpy as np import paddle from paddle import fluid -from paddle.fluid import core, framework, layers +from paddle.fluid import core, framework def exponential_decay( @@ -239,7 +239,9 @@ class TestLearningRateDecayDygraph(unittest.TestCase): d_model = 0.01 warmup_steps = 200 learning_rate = 2.0 - lr = fluid.layers.noam_decay(d_model, warmup_steps, learning_rate) + lr = paddle.optimizer.lr.noam_decay( + d_model, warmup_steps, learning_rate + ) for step in range(5): step += 1 right_result = noam_decay( @@ -278,7 +280,7 @@ class TestLearningRateDecayDygraph(unittest.TestCase): np.testing.assert_allclose(t, right_result[i], rtol=1e-05) with self.assertRaises(TypeError): - lr = fluid.layers.linear_lr_warmup( + lr = paddle.optimizer.lr.linear_lr_warmup( learning_rate="fake_lr", warmup_steps=2, start_lr=0.0, @@ -443,39 +445,59 @@ class TestLearningRateDecay(unittest.TestCase): common_kwargs_false["staircase"] = False decay_fns = [ - (exponential_decay, layers.exponential_decay, common_kwargs_true), - (exponential_decay, layers.exponential_decay, common_kwargs_false), - (natural_exp_decay, layers.natural_exp_decay, common_kwargs_true), - (natural_exp_decay, layers.natural_exp_decay, common_kwargs_false), - (inverse_time_decay, layers.inverse_time_decay, 
common_kwargs_true), + ( + exponential_decay, + paddle.optimizer.lr.exponential_decay, + common_kwargs_true, + ), + ( + exponential_decay, + paddle.optimizer.lr.exponential_decay, + common_kwargs_false, + ), + ( + natural_exp_decay, + paddle.optimizer.lr.natural_exp_decay, + common_kwargs_true, + ), + ( + natural_exp_decay, + paddle.optimizer.lr.natural_exp_decay, + common_kwargs_false, + ), + ( + inverse_time_decay, + paddle.optimizer.lr.inverse_time_decay, + common_kwargs_true, + ), ( inverse_time_decay, - layers.inverse_time_decay, + paddle.optimizer.lr.inverse_time_decay, common_kwargs_false, ), ( polynomial_decay, - layers.polynomial_decay, + paddle.optimizer.lr.polynomial_decay, {"learning_rate": 1.0, "decay_steps": 5, "cycle": True}, ), ( polynomial_decay, - layers.polynomial_decay, + paddle.optimizer.lr.polynomial_decay, {"learning_rate": 1.0, "decay_steps": 5, "cycle": False}, ), ( piecewise_decay, - layers.piecewise_decay, + paddle.optimizer.lr.piecewise_decay, {"boundaries": [3, 6, 9], "values": [0.1, 0.2, 0.3, 0.4]}, ), ( cosine_decay, - layers.cosine_decay, + paddle.optimizer.lr.cosine_decay, {"learning_rate": 0.1, "step_each_epoch": 100, "epochs": 120}, ), ( noam_decay, - layers.noam_decay, + paddle.optimizer.lr.noam_decay, {"d_model": 0.01, "warmup_steps": 200, "learning_rate": 2.0}, ), ] @@ -507,7 +529,7 @@ class TestLinearWamrupLearningRateDecay(unittest.TestCase): end_lr = 0.1 with fluid.program_guard(main_prog, startup_prog): - decayed_lr = layers.linear_lr_warmup( + decayed_lr = paddle.optimizer.lr.linear_lr_warmup( fluid_decay_fn(**kwargs), warmup_steps, start_lr, end_lr ) @@ -548,7 +570,7 @@ class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase): warmup_steps = 10 with fluid.program_guard(main_prog, startup_prog): - decayed_lr = layers.linear_lr_warmup( + decayed_lr = paddle.optimizer.lr.linear_lr_warmup( lr, warmup_steps, start_lr, end_lr )
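
Taken together, these hunks move the legacy learning-rate helpers (`noam_decay`, `exponential_decay`, `piecewise_decay`, `cosine_decay`, `linear_lr_warmup`, `autoincreased_step_counter`, ...) from `paddle.fluid.layers` into `paddle.optimizer.lr`, and relocate `LazyGuard` from `paddle.fluid.lazy_init` to `paddle.nn.initializer.lazy_init` while keeping the `paddle.LazyGuard` re-export. A minimal usage sketch of call sites once this patch is applied; the layer sizes, boundaries, and values below are illustrative placeholders, not taken from the diff:

.. code-block:: python

    import paddle

    # In dynamic-graph mode the relocated helpers return the matching
    # LRScheduler object, so they plug directly into paddle.optimizer.
    scheduler = paddle.optimizer.lr.piecewise_decay(
        boundaries=[3, 6, 9], values=[0.1, 0.05, 0.01, 0.005]
    )
    linear = paddle.nn.Linear(10, 1)
    sgd = paddle.optimizer.SGD(
        learning_rate=scheduler, parameters=linear.parameters()
    )

    # LazyGuard keeps its public entry point even though the implementation
    # now lives in python/paddle/nn/initializer/lazy_init.py.
    with paddle.LazyGuard():
        lazy_linear = paddle.nn.Linear(10, 1)  # parameters are initialized lazily

Static-graph callers keep working through the same module: as the updated meta-optimizers and tests above show, they now call `paddle.optimizer.lr.autoincreased_step_counter` and the `paddle.optimizer.lr.*_decay` functions instead of their removed `paddle.fluid.layers` counterparts.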