Commit 71843142 authored by mindspore-ci-bot, committed by Gitee

!2755 uniform learning rate behavior in optimizers

Merge pull request !2755 from wangnan39/uniform_lr_behavior_in_optimizers
@@ -231,8 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
         >>> cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
         [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
     """
-    validator.check_float_positive('min_lr', min_lr, None)
-    validator.check_float_legal_value('min_lr', min_lr, None)
+    if not isinstance(min_lr, float):
+        raise TypeError("min_lr must be float.")
+    validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
     validator.check_float_positive('max_lr', max_lr, None)
     validator.check_float_legal_value('max_lr', max_lr, None)
     validator.check_integer('total_step', total_step, 0, Rel.GT, None)
@@ -288,8 +289,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
     """
     validator.check_float_positive('learning_rate', learning_rate, None)
     validator.check_float_legal_value('learning_rate', learning_rate, None)
-    validator.check_float_positive('end_learning_rate', end_learning_rate, None)
-    validator.check_float_legal_value('end_learning_rate', end_learning_rate, None)
+    if not isinstance(end_learning_rate, float):
+        raise TypeError("end_learning_rate must be float.")
+    validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
     validator.check_float_positive('power', power, None)
     validator.check_float_legal_value('power', power, None)
     validator.check_integer('total_step', total_step, 0, Rel.GT, None)
@@ -311,11 +313,58 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
     return lr
def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
r"""
Get learning rate warming up.
For the i-th step, the formula of computing warmup_learning_rate[i] is:
.. math::
warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / warmup\_epoch
Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`
Args:
learning_rate (float): The initial value of learning rate.
total_step (int): The total number of training steps.
step_per_epoch (int): The number of steps per epoch.
warmup_epoch (int): The number of epochs over which the learning rate is warmed up.
Returns:
list[float]. The size of the list is `total_step`.
Examples:
>>> learning_rate = 0.1
>>> total_step = 6
>>> step_per_epoch = 2
>>> warmup_epoch = 2
>>> warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch)
[0.0, 0.0, 0.05, 0.05, 0.1, 0.1]
"""
if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
function = lambda x, y: (x, min(x, y))
lr = []
for i in range(total_step):
current_epoch = math.floor(i / step_per_epoch)
warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch)
lr.append(learning_rate * tmp_epoch / warmup_epoch)
return lr
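The example output above follows directly from the stated formula; as a sanity check, it can be reproduced with a few lines of plain Python that are independent of MindSpore:

import math

learning_rate, total_step, step_per_epoch, warmup_epoch = 0.1, 6, 2, 2
lrs = []
for i in range(total_step):
    current_epoch = math.floor(i / step_per_epoch)     # 0, 0, 1, 1, 2, 2
    tmp_epoch = min(current_epoch, warmup_epoch)       # clamp at warmup_epoch
    lrs.append(learning_rate * tmp_epoch / warmup_epoch)
print(lrs)  # [0.0, 0.0, 0.05, 0.05, 0.1, 0.1]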
 __all__ = [
     'piecewise_constant_lr',
     'exponential_decay_lr',
     'natural_exp_decay_lr',
     'inverse_decay_lr',
     'cosine_decay_lr',
-    'polynomial_decay_lr'
+    'polynomial_decay_lr',
+    'warmup_lr'
 ]
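Similarly, the `cosine_decay_lr` example shown in the first hunk above is just the documented cosine formula evaluated per step; a small plain-Python check (not MindSpore code) reproduces the docstring output exactly:

import math

min_lr, max_lr, total_step, step_per_epoch, decay_epoch = 0.01, 0.1, 6, 2, 2
lrs = []
for i in range(total_step):
    current_epoch = math.floor(i / step_per_epoch)     # 0, 0, 1, 1, 2, 2
    decayed = 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * current_epoch / decay_epoch))
    lrs.append(min_lr + decayed)
print(lrs)  # [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]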
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate schedule."""
import math
from ..common import dtype as mstype
from ..ops import operations as P
from .cell import Cell
from .._checkparam import Validator as validator
from .._checkparam import Rel
class LearningRateSchedule(Cell):
def __init__(self):
super(LearningRateSchedule, self).__init__()
def construct(self, global_step):
raise NotImplementedError
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name)
validator.check_float_positive('learning_rate', learning_rate, cls_name)
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
validator.check_float_positive('decay_rate', decay_rate, cls_name)
validator.check_float_legal_value('decay_rate', decay_rate, cls_name)
validator.check_value_type('is_stair', is_stair, [bool], cls_name)
class ExponentialDecayLR(LearningRateSchedule):
r"""
Calculate learning rate based on exponential decay function.
For the i-th step, the formula of computing decayed_learning_rate[i] is:
.. math::
decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{p}
Where :math:`p = \frac{current\_step}{decay\_steps}`. If `is_stair` is True, the formula
is :math:`p = floor(\frac{current\_step}{decay\_steps})`.
Args:
learning_rate (float): The initial value of learning rate.
decay_rate (float): The decay rate.
decay_steps (int): A value used to calculate decayed learning rate.
is_stair (bool): If true, the learning rate is decayed once every `decay_steps` steps. Default: False.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
>>> decay_steps = 4
>>> global_step = Tensor(2, mstype.int32)
>>> exponential_decay_lr = ExponentialDecayLR(learning_rate, decay_rate, decay_steps)
>>> exponential_decay_lr(global_step)
"""
def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
super(ExponentialDecayLR, self).__init__()
_check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name)
self.learning_rate = learning_rate
self.decay_rate = decay_rate
self.decay_steps = decay_steps
self.is_stair = is_stair
self.pow = P.Pow()
self.cast = P.Cast()
def construct(self, global_step):
p = self.cast(global_step, mstype.float32) / self.decay_steps
if self.is_stair:
p = P.Floor()(p)
return self.learning_rate * self.pow(self.decay_rate, p)
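For intuition, the decay formula can be checked against the example values in plain Python (the class itself runs as a graph Cell; this is only the arithmetic):

learning_rate, decay_rate, decay_steps, global_step = 0.1, 0.9, 4, 2

p = global_step / decay_steps                 # 0.5
print(learning_rate * decay_rate ** p)        # ~0.0949  (is_stair=False)
p_stair = global_step // decay_steps          # floor(0.5) == 0
print(learning_rate * decay_rate ** p_stair)  # 0.1      (is_stair=True)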
class NaturalExpDecayLR(LearningRateSchedule):
r"""
Calculate learning rate based on natural exponential decay function.
For the i-th step, the formula of computing decayed_learning_rate[i] is:
.. math::
decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * p}
Where :math:`p = \frac{current\_step}{decay\_steps}`. If `is_stair` is True, the formula
is :math:`p = floor(\frac{current\_step}{decay\_steps})`.
Args:
learning_rate (float): The initial value of learning rate.
decay_rate (float): The decay rate.
decay_steps (int): A value used to calculate decayed learning rate.
is_stair (bool): If true, the learning rate is decayed once every `decay_steps` steps. Default: False.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
>>> decay_steps = 4
>>> global_step = Tensor(2, mstype.int32)
>>> natural_exp_decay_lr = NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True)
>>> natural_exp_decay_lr(global_step)
"""
def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
super(NaturalExpDecayLR, self).__init__()
_check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name)
self.learning_rate = learning_rate
self.decay_rate = decay_rate
self.decay_steps = decay_steps
self.is_stair = is_stair
self.math_e = math.e
self.pow = P.Pow()
self.cast = P.Cast()
def construct(self, global_step):
p = self.cast(global_step, mstype.float32)
if self.is_stair:
p = P.FloorDiv()(p, self.decay_steps) * self.decay_steps
return self.learning_rate * self.pow(self.math_e, -self.decay_rate * p)
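Based on the documented formula, natural exponential decay with rate r coincides with `ExponentialDecayLR` using `decay_rate = e^{-r}`; a quick arithmetic check of that identity in plain Python (illustrative only, not the graph Cell):

import math

learning_rate, decay_rate, decay_steps, global_step = 0.1, 0.9, 4, 2
p = global_step / decay_steps
natural = learning_rate * math.e ** (-decay_rate * p)      # 0.1 * e^(-0.45) ~ 0.0638
equivalent = learning_rate * (math.e ** -decay_rate) ** p  # same value, up to float rounding
print(natural, equivalent)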
class InverseDecayLR(LearningRateSchedule):
r"""
Calculate learning rate based on inverse-time decay function.
For the i-th step, the formula of computing decayed_learning_rate[i] is:
.. math::
decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * p)
Where :math:`p = \frac{current\_step}{decay\_steps}`. If `is_stair` is True, the formula
is :math:`p = floor(\frac{current\_step}{decay\_steps})`.
Args:
learning_rate (float): The initial value of learning rate.
decay_rate (float): The decay rate.
decay_steps (int): A value used to calculate decayed learning rate.
is_stair (bool): If true, the learning rate is decayed once every `decay_steps` steps. Default: False.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> decay_rate = 0.9
>>> decay_steps = 4
>>> global_step = Tensor(2, mstype.int32)
>>> inverse_decay_lr = InverseDecayLR(learning_rate, decay_rate, decay_steps, True)
>>> inverse_decay_lr(global_step)
"""
def __init__(self, learning_rate, decay_rate, decay_steps, is_stair=False):
super(InverseDecayLR, self).__init__()
_check_inputs(learning_rate, decay_rate, decay_steps, is_stair, self.cls_name)
self.learning_rate = learning_rate
self.decay_rate = decay_rate
self.decay_steps = decay_steps
self.is_stair = is_stair
self.cast = P.Cast()
def construct(self, global_step):
p = self.cast(global_step, mstype.float32) / self.decay_steps
if self.is_stair:
p = P.Floor()(p)
return self.learning_rate / (1 + self.decay_rate * p)
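A quick numeric check of the inverse-time formula with the example values above (plain Python only):

learning_rate, decay_rate, decay_steps, global_step = 0.1, 0.9, 4, 2

p = global_step / decay_steps                       # 0.5
print(learning_rate / (1 + decay_rate * p))         # ~0.0690  (is_stair=False)
p_stair = global_step // decay_steps                # floor(0.5) == 0
print(learning_rate / (1 + decay_rate * p_stair))   # 0.1      (is_stair=True, as in the example)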
class CosineDecayLR(LearningRateSchedule):
r"""
Calculate learning rate based on cosine decay function.
For the i-th step, the formula of computing decayed_learning_rate[i] is:
.. math::
decayed\_learning\_rate[i] = min\_lr + 0.5 * (max\_lr - min\_lr) *
(1 + cos(\frac{tmp\_step}{decay\_steps}\pi))
Where :math:`tmp\_step=min(global\_step, decay\_steps)`.
Args:
min_lr (float): The minimum value of learning rate.
max_lr (float): The maximum value of learning rate.
decay_steps (int): A value used to calculate decayed learning rate.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> min_lr = 0.01
>>> max_lr = 0.1
>>> decay_steps = 4
>>> global_step = Tensor(2, mstype.int32)
>>> cosine_decay_lr = CosineDecayLR(min_lr, max_lr, decay_steps)
>>> cosine_decay_lr(global_step)
"""
def __init__(self, min_lr, max_lr, decay_steps):
super(CosineDecayLR, self).__init__()
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_float_positive('max_lr', max_lr, self.cls_name)
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
if min_lr >= max_lr:
raise ValueError('`max_lr` should be greater than `min_lr`.')
self.min_lr = min_lr
self.max_lr = max_lr
self.decay_steps = decay_steps
self.math_pi = math.pi
self.delta = 0.5 * (max_lr - min_lr)
self.cos = P.Cos()
self.min = P.Minimum()
self.cast = P.Cast()
def construct(self, global_step):
p = self.cast(self.min(global_step, self.decay_steps), mstype.float32)
return self.min_lr + self.delta * (1.0 + self.cos(self.math_pi * p / self.decay_steps))
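Note that `construct` clamps the step with `min(global_step, decay_steps)`, so the schedule bottoms out at `min_lr` and stays there. A plain-Python sketch of that behavior with the example values (illustrative only):

import math

min_lr, max_lr, decay_steps = 0.01, 0.1, 4

def cosine_lr(step):
    p = min(step, decay_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * p / decay_steps))

print([round(cosine_lr(s), 4) for s in range(8)])
# [0.1, 0.0868, 0.055, 0.0232, 0.01, 0.01, 0.01, 0.01]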
class PolynomialDecayLR(LearningRateSchedule):
r"""
Calculate learning rate based on polynomial decay function.
For the i-th step, the formula of computing decayed_learning_rate[i] is:
.. math::
decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
(1 - tmp\_step / tmp\_decay\_step)^{power} + end\_learning\_rate
Where :math:`tmp\_step=min(global\_step, decay\_steps)`.
If `update_decay_steps` is true, update the value of `tmp_decay_step` every `decay_steps`. The formula
is :math:`tmp\_decay\_step = decay\_step * ceil(global\_step / decay\_steps)`
Args:
learning_rate (float): The initial value of learning rate.
end_learning_rate (float): The end value of learning rate.
decay_steps (int): A value used to calculate decayed learning rate.
power (float): A value used to calculate decayed learning rate. This parameter should be greater than 0.
update_decay_steps (bool): If true, the learning rate is decayed once every `decay_steps` steps. Default: False.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> end_learning_rate = 0.01
>>> decay_steps = 4
>>> power = 0.5
>>> global_step = Tensor(2, mstype.int32)
>>> polynomial_decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
>>> polynomial_decay_lr(global_step)
"""
def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False):
super(PolynomialDecayLR, self).__init__()
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_legal_value('learning_rate', learning_rate, None)
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
self.cls_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
validator.check_float_positive('power', power, self.cls_name)
validator.check_float_legal_value('power', power, self.cls_name)
self.decay_steps = decay_steps
self.start_learning_rate = learning_rate
self.end_learning_rate = end_learning_rate
self.diff_learning_rate = learning_rate - end_learning_rate
self.power = power
self.update_decay_steps = update_decay_steps
self.pow = P.Pow()
self.ceil = P.Ceil()
self.min = P.Minimum()
self.max = P.Maximum()
def construct(self, global_step):
tmp_global_step = P.Cast()(global_step, mstype.float32)
tmp_decay_step = self.decay_steps
if self.update_decay_steps:
tmp_decay_step = tmp_decay_step * self.max(self.ceil(tmp_global_step / tmp_decay_step), 1)
else:
tmp_global_step = self.min(tmp_global_step, tmp_decay_step)
p = tmp_global_step / tmp_decay_step
lr = self.diff_learning_rate * self.pow(1.0 - p, self.power) + self.end_learning_rate
return lr
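To make the `update_decay_steps` branch concrete, here is the same arithmetic in plain Python with the example values (an illustrative sketch of the formula, not the graph Cell):

import math

learning_rate, end_learning_rate, decay_steps, power = 0.1, 0.01, 4, 0.5

def poly_lr(step, update_decay_steps=False):
    tmp_decay_steps = decay_steps
    if update_decay_steps:
        # stretch the decay window to the next multiple of decay_steps
        tmp_decay_steps = decay_steps * max(math.ceil(step / decay_steps), 1)
    else:
        step = min(step, tmp_decay_steps)
    p = step / tmp_decay_steps
    return (learning_rate - end_learning_rate) * (1.0 - p) ** power + end_learning_rate

print(round(poly_lr(2), 5))                           # 0.07364: (0.1 - 0.01) * sqrt(0.5) + 0.01
print(round(poly_lr(6), 5))                           # 0.01:    clamped once step >= decay_steps
print(round(poly_lr(6, update_decay_steps=True), 5))  # 0.055:   decay window stretched to 8 steps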
class WarmUpLR(LearningRateSchedule):
r"""
Get learning rate warming up.
For the i-th step, the formula of computing warmup_learning_rate[i] is:
.. math::
warmup\_learning\_rate[i] = learning\_rate * tmp\_step / warmup\_steps
Where :math:`tmp\_step=min(global\_step, warmup\_steps)`.
Args:
learning_rate (float): The initial value of learning rate.
warmup_steps (int): The warm up steps of learning rate.
Inputs:
Tensor. The current step number.
Returns:
Tensor. The learning rate value for the current step.
Examples:
>>> learning_rate = 0.1
>>> warmup_steps = 2
>>> global_step = Tensor(2, mstype.int32)
>>> warmup_lr = WarmUpLR(learning_rate, warmup_steps)
>>> warmup_lr(global_step)
"""
def __init__(self, learning_rate, warmup_steps):
super(WarmUpLR, self).__init__()
if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name)
self.warmup_steps = warmup_steps
self.learning_rate = learning_rate
self.min = P.Minimum()
self.cast = P.Cast()
def construct(self, global_step):
warmup_percent = self.cast(self.min(global_step, self.warmup_steps), mstype.float32) / self.warmup_steps
return self.learning_rate * warmup_percent
__all__ = [
'ExponentialDecayLR',
'NaturalExpDecayLR',
'InverseDecayLR',
'CosineDecayLR',
'PolynomialDecayLR',
'WarmUpLR'
]
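These schedule Cells are meant to be passed directly as an optimizer's `learning_rate`, which is the point of unifying the learning-rate behavior in this merge request. A minimal usage sketch; it assumes the classes are exported under `mindspore.nn`, that `Net` is any network with trainable parameters, and that the `Momentum` signature is as in current MindSpore:

import mindspore.nn as nn

net = Net()  # placeholder: any Cell with trainable parameters
# The optimizer evaluates the schedule at its global step, so the i-th update uses
# learning_rate * decay_rate ** floor(i / decay_steps).
decay_lr = nn.ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=4, is_stair=True)
optim = nn.Momentum(net.trainable_params(), learning_rate=decay_lr, momentum=0.9)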
@@ -20,7 +20,7 @@ The optimizer is used to calculate and update the gradients.
 """
 from .optimizer import Optimizer
 from .momentum import Momentum
-from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR
+from .adam import Adam, PSAdam, AdamWeightDecay
 from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
@@ -30,4 +30,4 @@ from .proximal_ada_grad import ProximalAdagrad
 from .lazyadam import LazyAdam

 __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
-           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
+           'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
This diff is collapsed.
@@ -24,9 +24,9 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
 _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")

-@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor",
+@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                     "Tensor", "Bool")
-def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
+def _tensor_run_opt_with_sparse(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment,
                                 ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
@@ -43,9 +43,9 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power,
     return success

-@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
+@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
                     "Tensor", "Bool")
-def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment, ps_parameter):
+def _tensor_run_opt(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter):
     """Apply ftrl optimizer to the weight parameter."""
     success = True
     if ps_parameter:
@@ -83,7 +83,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2,
     return success

-def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
+def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):
     """Check param."""
     validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
     validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
@@ -99,9 +99,6 @@ def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0,
     validator.check_value_type("use_locking", use_locking, [bool], prim_name)
-    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
-    validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)

 class FTRL(Optimizer):
     """
@@ -113,15 +110,34 @@ class FTRL(Optimizer):
     <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.

     Note:
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
+        on all of the parameters.
+        To improve parameter groups performance, the customized order of parameters can be supported.
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
-        The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU.
+        The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU.

     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
-            should be Parameter.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr", "weight_decay" and "order_params" are the keys can be parsed.
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Using different learning rate by separating parameters is currently not supported.
+            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+            - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
+              the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
+              in the value of 'order_params' should be in one of group parameters.
         initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
-        learning_rate (float): The learning rate value, should be positive. Default: 0.001.
+        learning_rate (float): The learning rate value, should be zero or positive, dynamic learning rate is currently
+            not supported. Default: 0.001.
         lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
             than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5.
         l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
@@ -139,23 +155,36 @@ class FTRL(Optimizer):
     Examples:
         >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = nn.FTRL(params=net.trainable_params())
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+        >>>                 {'params': no_conv_params},
+        >>>                 {'order_params': net.trainable_params()}]
+        >>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # The conv_params's parameters will use weight decay of 0.01.
+        >>> # The no_conv_params's parameters will use default weight decay of 0.0.
+        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
+        >>>
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
-        >>> opt = nn.FTRL(net.trainable_params())
-        >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
     """
     def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
                  use_locking=False, loss_scale=1.0, weight_decay=0.0):
-        super(FTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
-        if self.is_group:
-            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
+        super(FTRL, self).__init__(learning_rate, params, weight_decay, loss_scale=loss_scale)
+        if self.dynamic_lr or self.is_group_lr:
+            raise ValueError('Dynamic learning rate or group learning rate is currently not supported.')
+        _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
         self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
         self.linear = self.parameters.clone(prefix="linear", init='zeros')
         self.l1 = l1
         self.l2 = l2
         self.lr_power = lr_power
-        self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
+        if not self.is_group:
+            self.decay_flags = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
@@ -164,12 +193,11 @@ class FTRL(Optimizer):
         params = self.parameters
         moments = self.moments
         linear = self.linear
-        lr = self.learning_rate
-        if self.weight_decay > 0.0:
-            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+        grads = self.decay_weight(grads)
         grads = self.scale_grad(grads)
-        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
+        lr = self.get_lr()
+        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.l1, self.l2, self.lr_power, lr),
                             linear, grads, params, moments, self.ps_parameters)
         return success
@@ -180,7 +208,7 @@ class PSFTRL(Optimizer):
         super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
         if self.is_group:
             raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
+        _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
         self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
         self.linear = self.parameters.clone(prefix="linear", init='zeros')
         self.l1 = l1
...
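The practical effect of the FTRL changes above: weight decay (including per-group values) is now applied through the base optimizer's `decay_weight`, while dynamic or grouped learning rates are rejected in `__init__`. A hedged sketch of the caller-visible behavior (API usage mirrors the docstring example; `Net` is a placeholder):

import mindspore.nn as nn

net = Net()

# Fixed learning rate: supported; per-group weight decay is handled by the base class.
optim = nn.FTRL(net.trainable_params(), learning_rate=0.01, weight_decay=0.001)

# A list-style (dynamic) learning rate now fails fast with the ValueError raised above.
try:
    nn.FTRL(net.trainable_params(), learning_rate=[0.01, 0.005])
except ValueError as err:
    print(err)  # Dynamic learning rate or group learning rate is currently not supported.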
@@ -32,10 +32,9 @@ num_one = Tensor(np.ones([1]), mstype.float32)
 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")

-@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                     "Tensor", "Bool", "Bool")
-def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
-                   gradient, decay_flag, optim_filter):
+def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
     """
     Update parameters.
@@ -44,7 +43,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
         eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
         lr (Tensor): Learning rate.
-        weight_decay_tensor (Tensor): Weight decay. Should be in range [0.0, 1.0].
+        weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
         global_step (Tensor): Global step.
         param (Tensor): Parameters.
         m (Tensor): m value of parameters.
@@ -87,7 +86,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         w_norm = op_norm(param_fp32)
         g_norm = op_norm(gradient_fp32)
-        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
+        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
         zeros = F.zeros_like(w_norm)
         ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
         trust_ratio = op_select(
@@ -99,7 +98,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         update = next_mm / (op_sqrt(next_vv) + eps)
         if decay_flag:
-            update = update + op_mul(weight_decay_tensor, param_fp32)
+            update = update + op_mul(weight_decay, param_fp32)
         update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
@@ -116,10 +115,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")

-@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number",
                                 "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor,
-                                global_step, param, m, v, gradient, decay_flag):
+def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag):
     """
     Update parameters.
@@ -128,7 +126,7 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor,
         beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
         eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
         lr (Tensor): Learning rate.
-        weight_decay_tensor (Tensor): Weight decay. Should be in range [0.0, 1.0].
+        weight_decay (Number): Weight decay. Should be in range [0.0, 1.0].
         global_step (Tensor): Global step.
         param (Tensor): Parameters.
         m (Tensor): m value of parameters.
@@ -157,11 +155,10 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor,
     i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex)
     i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex)
     i1 = op_square(gradient_fp32)
-    add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1,
-                                  i9, beta2, x1, weight_decay_tensor, eps)
+    add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9, beta2, x1, weight_decay, eps)
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
+        update = update + op_mul(weight_decay, param_fp32)
     w_norm = op_norm(param_fp32)
     g_norm = op_norm(gradient_fp32)
@@ -171,38 +168,18 @@ def _update_run_op_graph_kernel(beta1, beta2, eps, lr, weight_decay_tensor,
     ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
     tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0)
-    next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update,
-                                      param, zeros, ones, tens)
+    next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, param, zeros, ones, tens)
     next_v = F.control_depend(add3, next_param)
     return next_v

-def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
-                       end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
-    """Check the type of inputs."""
-    validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
-    validator.check_number_range("start_learning_rate rate", start_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
-                                 prim_name)
-    validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
-    validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
-                                 prim_name)
-    validator.check_float_positive('power', power, prim_name)
-    validator.check_float_legal_value('power', power, prim_name)
-    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
+def _check_param_value(beta1, beta2, eps, prim_name):
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
-    validator.check_value_type(
-        "weight_dacay", weight_decay, [float], prim_name)
-    validator.check_number_range(
-        "beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    validator.check_number_range(
-        "beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
-    validator.check_number_range(
-        "eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
-    validator.check_number_range(
-        "weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
+    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)

 class Lamb(Optimizer):
@@ -213,16 +190,37 @@ class Lamb(Optimizer):
     optimization technique. Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76
     MINUTES <https://arxiv.org/abs/1904.00962>`_.

+    Note:
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
+        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
+        To improve parameter groups performance, the customized order of parameters can be supported.
+
     Args:
-        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
-            should be class mindspore.Parameter.
-        decay_steps (int): The steps of the lr decay. Should be equal to or greater than 1.
-        warmup_steps (int): The steps of lr warm up. Should be equal to or greater than 0. Default: 0.
-        start_learning_rate (float): A floating point value for the learning rate. Should be equal to
-            or greater than 0. Default: 0.1.
-        end_learning_rate (float): A floating point value for the end learning rate. Should be equal to
-            or greater than 0. Default: 0.0001.
-        power (float): The power of the polynomial. It must be positive. Default: 1.0.
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr", "weight_decay" and "order_params" are the keys can be parsed.
+            - params: Required. The value should be a list of `Parameter`.
+            - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+            - order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
+              the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
+              in the value of 'order_params' should be in one of group parameters.
+        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
+            When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
+            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
+            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
+            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
+            dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
+            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
         beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
             Should be in range (0.0, 1.0).
         beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
@@ -241,90 +239,84 @@ class Lamb(Optimizer):
     Examples:
         >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = nn.Lamb(params=net.trainable_params())
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR()
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+        >>>                 {'params': no_conv_params, 'lr': poly_decay_lr},
+        >>>                 {'order_params': net.trainable_params(0.01, 0.0001, 10, 0.5)}]
+        >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
+        >>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default
+        >>> # weight decay of 0.0.
+        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
+        >>>
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
-        >>> optim = nn.Lamb(params=net.trainable_params(), decay_steps=10)
-        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
     """

-    def __init__(self,
-                 params,
-                 decay_steps,
-                 warmup_steps=0,
-                 start_learning_rate=0.1,
-                 end_learning_rate=0.0001,
-                 power=1.0,
-                 beta1=0.9,
-                 beta2=0.999,
-                 eps=1e-6,
-                 weight_decay=0.0,
-                 decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
-        super(Lamb, self).__init__(0.0, params)
-        if self.is_group:
-            raise RuntimeError(
-                f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
-                           power, beta1, beta2, eps, weight_decay, self.cls_name)
+    def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
+        super(Lamb, self).__init__(learning_rate, params, weight_decay)
+        _check_param_value(beta1, beta2, eps, self.cls_name)

         # turn them to scalar when me support scalar/tensor mix operations
-        self.global_step = Parameter(initializer(0, [1]), name="global_step")
-        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
-        self.warmup_flag = False
-        if warmup_steps > 0:
-            self.warmup_flag = True
-        self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
-        self.start_learning_rate = Tensor(
-            np.array([start_learning_rate]).astype(np.float32))
-        self.end_learning_rate = Tensor(
-            np.array([end_learning_rate]).astype(np.float32))
-        self.diff_learning_rate = Tensor(
-            np.array([start_learning_rate - end_learning_rate]).astype(np.float32))
-        self.power = power
         self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
         self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
         self.eps = Tensor(np.array([eps]).astype(np.float32))
-        self.weight_decay_tensor = Tensor(
-            np.array([weight_decay]).astype(np.float32))
         self.params = self.parameters
         self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
         self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
-        self.decay_flag = tuple(decay_filter(x) for x in self.params)
+        if not self.dynamic_lr:
+            self.global_step = Parameter(initializer(0, [1]), name='global_step')
+            self.assignadd = P.AssignAdd()

         self.hyper_map = C.HyperMap()
-        self.min = P.Minimum()
-        self.pow = P.Pow()
-        self.greater = P.Greater()
-        self.one = Tensor(np.array([1.0]).astype(np.float32))
-        self.cast = P.Cast()
         self.enable_graph_kernel = context.get_context("enable_graph_kernel")

     def construct(self, gradients):
-        step = self.min(self.global_step, self.decay_steps)
-        p = step / self.decay_steps
-        lr = self.diff_learning_rate * \
-            self.pow(self.one - p, self.power) + self.end_learning_rate
-        if self.warmup_flag:
-            warmup_percent = self.global_step / self.warmup_steps
-            warmup_lr = self.start_learning_rate * warmup_percent
-            is_warmup = self.cast(self.greater(
-                self.warmup_steps, self.global_step), mstype.float32)
-            lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
+        lr = self.get_lr()
         if self.enable_graph_kernel:
-            optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
-                                                    self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor, self.global_step),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            if self.is_group:
+                if self.is_group_lr:
+                    optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps,
+                                                            self.global_step),
+                                                  lr, self.weight_decay, self.params, self.moments1, self.moments2,
+                                                  gradients, self.decay_flags)
+                else:
+                    optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps,
+                                                            self.global_step, lr),
+                                                  self.weight_decay, self.params, self.moments1, self.moments2,
+                                                  gradients, self.decay_flags)
+            else:
+                optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps,
+                                                        self.global_step, lr, self.weight_decay),
+                                              self.params, self.moments1, self.moments2, gradients, self.decay_flags)
         else:
-            optim_result = self.hyper_map(F.partial(_lamb_opt,
-                                                    self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor, self.global_step),
-                                          self.params, self.moments1, self.moments2, gradients,
-                                          self.decay_flag, self.optim_filter)
+            if self.is_group:
+                if self.is_group_lr:
+                    optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
+                                                            self.global_step),
+                                                  lr, self.weight_decay, self.params, self.moments1, self.moments2,
+                                                  gradients, self.decay_flags, self.optim_filter)
+                else:
+                    optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
+                                                            self.global_step, lr),
+                                                  self.weight_decay, self.params, self.moments1, self.moments2,
+                                                  gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                                                        self.global_step, lr, self.weight_decay),
                                              self.params, self.moments1, self.moments2, gradients,
                                              self.decay_flags, self.optim_filter)
         if self.use_parallel:
             optim_result = self.broadcast_params(optim_result)
-        added_global_step = self.global_step + self.one
-        F.control_depend(lr, added_global_step)
-        self.global_step = added_global_step
+        if not self.dynamic_lr:
+            F.control_depend(lr, self.assignadd(self.global_step, 1))

         return optim_result
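The Lamb refactor above drops the built-in warmup/polynomial-decay arguments; any decay is now supplied explicitly through `learning_rate`. A hedged migration sketch (the old call reflects the removed signature, the new one the added signature; `Net` is a placeholder and the schedule is assumed to be exported under `mindspore.nn`):

import mindspore.nn as nn

net = Net()

# Before: decay baked into the optimizer.
# optim = nn.Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0,
#                 start_learning_rate=0.1, end_learning_rate=0.0001, power=1.0)

# After: pass a float, a list, or a LearningRateSchedule as learning_rate.
poly_decay_lr = nn.PolynomialDecayLR(learning_rate=0.1, end_learning_rate=0.0001,
                                     decay_steps=10, power=1.0)
optim = nn.Lamb(net.trainable_params(), learning_rate=poly_decay_lr, weight_decay=0.01)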
@@ -38,14 +38,14 @@ def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_f
     return gradient

 def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
     validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
+    if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name:
+        raise TypeError("LARS can not be used with ", optimizer.cls_name)
     validator.check_value_type("epsilon", epsilon, [float], prim_name)
     validator.check_value_type("coefficient", coefficient, [float], prim_name)
     validator.check_value_type("use_clip", use_clip, [bool], prim_name)

 class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
@@ -81,45 +81,71 @@ class LARS(Optimizer):
         super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
         _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
         self.opt = optimizer
+        self.parameters = optimizer.parameters
+        self.use_clip = use_clip
+        self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
+        self.is_group = optimizer.is_group
+        self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
+        self.decay_flags = optimizer.decay_flags
+        self.reciprocal_scale = optimizer.reciprocal_scale
+        self.hyper_map = C.HyperMap()
         self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
         self.cast = P.Cast()
-        self.parameters = optimizer.parameters
-        if use_clip is True:
-            self.learning_rate = optimizer.learning_rate
+        if use_clip:
+            self.is_group_lr = optimizer.is_group_lr
             self.dynamic_lr = optimizer.dynamic_lr
-            self.gather = optimizer.gather
-            self.assignadd = optimizer.assignadd
+            self.origin_learning_rate = optimizer.learning_rate
             self.global_step = optimizer.global_step
-        else:
-            self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
-        self.reciprocal_scale = optimizer.reciprocal_scale
-        optimizer.reciprocal_scale = 1.0
-        self.is_group = optimizer.is_group
+            if self.is_group_lr and self.dynamic_lr:
+                raise ValueError('Grouped dynamic learning rate is currently not supported for the inputs optimizer ' \
+                                 'of lars.')
         if self.is_group:
             self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
-            optimizer.weight_decay = tuple(map(lambda x: 0.0, optimizer.weight_decay))
         else:
             self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
-            optimizer.weight_decay = 0.0
-        self.decay_flags = optimizer.decay_flags
-        self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
-        self.hyper_map = C.HyperMap()
+        optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags))
+        optimizer.reciprocal_scale = 1.0
         optimizer.exec_weight_decay = False
+        optimizer.weight_decay = 0.0

+    def _get_lr(self):
+        """Get the learning rate of current step."""
+        lr = self.origin_learning_rate
+        if self.dynamic_lr:
+            if self.is_group_lr:
+                lr = ()
+                for learning_rate in self.origin_learning_rate:
+                    current_dynamic_lr = learning_rate(self.global_step)
+                    lr += (current_dynamic_lr,)
+            else:
+                lr = self.origin_learning_rate(self.global_step)
+        return lr

     def construct(self, gradients):
         params = self.parameters
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, 0)
-            F.control_depend(lr, self.assignadd(self.global_step, 1))
+        if self.use_clip:
+            lr = self._get_lr()
         else:
             lr = self.learning_rate
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
         if self.is_group:
-            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
-                                    gradients, params, self.decay_flags, self.lars_flag)
+            if self.is_group_lr:
+                gradients = self.hyper_map(F.partial(_lars_opt, self.lars), lr, self.weight_decay,
+                                           gradients, params, self.decay_flags, self.lars_flag)
+            else:
+                gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
+                                           gradients, params, self.decay_flags, self.lars_flag)
         else:
-            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
-                                    gradients, params, self.decay_flags, self.lars_flag)
-        success = self.opt(grad_t)
+            gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
+                                       gradients, params, self.decay_flags, self.lars_flag)
+        success = self.opt(gradients)
         return success
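For context on the LARS changes (it now takes `decay_flags`, `reciprocal_scale`, and the learning rate straight from the wrapped optimizer, and `_check_param_value` rejects Adam-family and Lamb optimizers), a hedged usage sketch; the exact `Momentum` and `LARS` signatures are assumptions and `Net` is a placeholder:

import mindspore.nn as nn

net = Net()
base_opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, weight_decay=1e-4)
optim = nn.LARS(base_opt, epsilon=1e-5, coefficient=0.001)  # weight decay moves into the LARS update

# Wrapping Adam or Lamb is now rejected:
# nn.LARS(nn.Lamb(net.trainable_params(), learning_rate=0.1))  # raises TypeError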
...@@ -84,12 +84,11 @@ class LazyAdam(Optimizer): ...@@ -84,12 +84,11 @@ class LazyAdam(Optimizer):
:math:`\epsilon` represents `eps`. :math:`\epsilon` represents `eps`.
Note: Note:
The LazyAdam optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
The sparse behavior, to be notice, is not equivalent to the The sparse behavior, to be notice, is not equivalent to the
...@@ -113,13 +112,14 @@ class LazyAdam(Optimizer): ...@@ -113,13 +112,14 @@ class LazyAdam(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters. in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
Iterable or a Tensor and the dims of the Tensor is 1, When the learning_rate is a Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
use dynamic learning rate, then the i-th step will the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
take the i-th value as the learning rate. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
When the learning_rate is float or learning_rate is a Tensor according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
but the dims of the Tensor is 0, use fixed learning rate. dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
Other cases are not supported. Default: 1e-3. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default: beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
0.9. 0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default: beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
...@@ -153,9 +153,9 @@ class LazyAdam(Optimizer): ...@@ -153,9 +153,9 @@ class LazyAdam(Optimizer):
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params, 'lr': 0.01}, >>> {'params': no_conv_params, 'lr': 0.01},
>>> {'order_params': net.trainable_params()}] >>> {'order_params': net.trainable_params()}]
>>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0) >>> opt = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01. >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0. >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>> >>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
......
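A minimal sketch of the new LearningRateSchedule case described in the LazyAdam docstring above, assuming a hypothetical network class `Net` and reusing the `PolynomialDecayLR` cell that this change set already relies on elsewhere:

    import mindspore.nn as nn
    import mindspore.nn.learning_rate_schedule as lr_schedules

    net = Net()  # hypothetical network; any Cell with trainable parameters
    # A LearningRateSchedule cell is evaluated against the global step inside the optimizer,
    # so the learning rate decays during training without precomputing a per-step list.
    decay_lr = lr_schedules.PolynomialDecayLR(learning_rate=1e-3, end_learning_rate=1e-5,
                                              decay_steps=4000, power=1.0)
    optim = nn.LazyAdam(net.trainable_params(), learning_rate=decay_lr)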
...@@ -47,12 +47,9 @@ class Momentum(Optimizer): ...@@ -47,12 +47,9 @@ class Momentum(Optimizer):
Refer to the paper on the importance of initialization and momentum in deep learning for more details. Refer to the paper on the importance of initialization and momentum in deep learning for more details.
Note: Note:
The Momentum optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported. To improve parameter groups performance, the customized order of parameters can be supported.
...@@ -73,14 +70,13 @@ class Momentum(Optimizer): ...@@ -73,14 +70,13 @@ class Momentum(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters. in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
Iterable or a Tensor and the dims of the Tensor is 1, When the learning_rate is an Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
use dynamic learning rate, then the i-th step will the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
take the i-th value as the learning rate. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
When the learning_rate is float or learning_rate is a according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
Tensor but the dims of the Tensor is 0, use fixed learning dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
rate. Other cases are not supported. It should be equal to equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
or greater than 0.0.
momentum (float): Hyperparameter of type float, means momentum for the moving average. momentum (float): Hyperparameter of type float, means momentum for the moving average.
It should be at least 0.0. It should be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0. weight_decay (int, float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
......
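The Momentum docstring above gets the same learning_rate rewrite. A minimal sketch of the Iterable case, where the i-th step uses the i-th value (`Net` is a hypothetical placeholder):

    import mindspore.nn as nn

    net = Net()  # hypothetical network
    # One value per training step; the list should cover at least the planned number of steps.
    lr_each_step = [0.1, 0.1, 0.05, 0.05, 0.01, 0.01]
    optim = nn.Momentum(net.trainable_params(), learning_rate=lr_each_step, momentum=0.9)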
This diff is collapsed.
...@@ -32,7 +32,7 @@ def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient ...@@ -32,7 +32,7 @@ def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum): def _tensor_run_opt(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
"""Apply proximal_ada_grad optimizer to the weight parameter.""" """Apply proximal_ada_grad optimizer to the weight parameter."""
success = True success = True
success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient)) success = F.depend(success, opt(weight, accum, learning_rate, l1, l2, gradient))
...@@ -59,15 +59,42 @@ class ProximalAdagrad(Optimizer): ...@@ -59,15 +59,42 @@ class ProximalAdagrad(Optimizer):
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_. <http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
Note: Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network. The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU. behavior is currently performed on the CPU.
Args: Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params` params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
should be Parameter. the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value should be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" in the keys, the value should be the order of parameters and
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1. accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
learning_rate (float): The learning rate value, must be greater than or equal to zero. Default: 0.001. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
When the learning_rate is an Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.001.
l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
use_locking (bool): If True use locks for update operation. Default: False. use_locking (bool): If True use locks for update operation. Default: False.
...@@ -83,21 +110,31 @@ class ProximalAdagrad(Optimizer): ...@@ -83,21 +110,31 @@ class ProximalAdagrad(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.ProximalAdagrad(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
>>> {'params': no_conv_params, 'lr': 0.01},
>>> {'order_params': net.trainable_params()}]
>>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.ProximalAdagrad(net.trainable_params()) >>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
""" """
def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0, def __init__(self, params, accum=0.1, learning_rate=0.001, l1=0.0, l2=0.0,
use_locking=False, loss_scale=1.0, weight_decay=0.0): use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale) super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(accum, l1, l2, use_locking, self.cls_name) _check_param_value(accum, l1, l2, use_locking, self.cls_name)
self.accum = self.parameters.clone(prefix="accum", init=accum) self.accum = self.parameters.clone(prefix="accum", init=accum)
self.l1 = Tensor(l1, mstype.float32) self.l1 = Tensor(l1, mstype.float32)
self.l2 = Tensor(l2, mstype.float32) self.l2 = Tensor(l2, mstype.float32)
self.weight_decay = weight_decay
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking) self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking) self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking)
...@@ -107,7 +144,11 @@ class ProximalAdagrad(Optimizer): ...@@ -107,7 +144,11 @@ class ProximalAdagrad(Optimizer):
accum = self.accum accum = self.accum
grads = self.decay_weight(grads) grads = self.decay_weight(grads)
grads = self.scale_grad(grads) grads = self.scale_grad(grads)
lr = self.learning_rate lr = self.get_lr()
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2), if self.is_group_lr:
grads, params, accum) success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
grads, params, accum)
else:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr),
grads, params, accum)
return success return success
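The ProximalAdagrad changes above drop the old group-setting RuntimeError and dispatch on `is_group_lr` in `construct`. The following is a plain-Python analogy of the two `map_` branches, not MindSpore's HyperMap itself: with group learning rates, `lr` is a per-parameter sequence zipped with the gradients and weights, while a single learning rate is bound into the partial once.

    from functools import partial

    def _update(l1, l2, lr, grad, weight):
        # Stand-in for one ApplyProximalAdagrad call; a deliberately simplified update.
        return weight - lr * grad

    grads = [0.1, 0.2]
    weights = [1.0, 2.0]

    is_group_lr = True
    lr = [0.01, 0.001] if is_group_lr else 0.01

    if is_group_lr:
        new_weights = list(map(partial(_update, 0.0, 0.0), lr, grads, weights))
    else:
        new_weights = list(map(partial(_update, 0.0, 0.0, lr), grads, weights))
    print(new_weights)  # -> [0.999, 1.9998]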
...@@ -44,12 +44,9 @@ class RMSProp(Optimizer): ...@@ -44,12 +44,9 @@ class RMSProp(Optimizer):
Implements Root Mean Squared Propagation (RMSProp) algorithm. Implements Root Mean Squared Propagation (RMSProp) algorithm.
Note: Note:
The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported. To improve parameter groups performance, the customized order of parameters can be supported.
...@@ -109,13 +106,14 @@ class RMSProp(Optimizer): ...@@ -109,13 +106,14 @@ class RMSProp(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters. in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
Iterable or a Tensor and the dims of the Tensor is 1, When the learning_rate is an Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
use dynamic learning rate, then the i-th step will the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
take the i-th value as the learning rate. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
When the learning_rate is float or learning_rate is a Tensor according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
but the dims of the Tensor is 0, use fixed learning rate. dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
Other cases are not supported. Default: 0.1. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 0.1.
decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9.
momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or
greater than 0. Default: 0.0. greater than 0. Default: 0.0.
......
...@@ -40,14 +40,11 @@ class SGD(Optimizer): ...@@ -40,14 +40,11 @@ class SGD(Optimizer):
momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_. momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
Note: Note:
The SGD optimizer supports separating parameter groups. Different parameter groups can set different
`learning_rate` and `weight_decay`.
When separating parameter groups, the weight decay in each group will be applied on the parameters if the When separating parameter groups, the weight decay in each group will be applied on the parameters if the
value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported. To improve parameter groups performance, the customized order of parameters can be supported.
Args: Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
...@@ -66,14 +63,14 @@ class SGD(Optimizer): ...@@ -66,14 +63,14 @@ class SGD(Optimizer):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters. in the value of 'order_params' should be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or graph for the learning rate.
Iterable or a Tensor and the dims of the Tensor is 1, When the learning_rate is an Iterable or a Tensor with dimension of 1, use dynamic learning rate, then
use dynamic learning rate, then the i-th step will the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
take the i-th value as the learning rate. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
When the learning_rate is float or learning_rate is a Tensor according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor with
but the dims of the Tensor is 0, use fixed learning rate. dimension of 0, use fixed learning rate. Other cases are not supported. The float learning rate should be
Other cases are not supported. It should be equal to or equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
greater than 0. Default: 0.1. Default: 0.1.
momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0. momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0.
dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0. dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0. weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
......
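For the fixed-rate cases listed in the SGD docstring above, a short sketch (`Net` is hypothetical): a float, a 0-dim Tensor, or an int (converted to float) all select a constant learning rate.

    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    net = Net()  # hypothetical network
    opt_float = nn.SGD(net.trainable_params(), learning_rate=0.1)                           # float
    opt_tensor = nn.SGD(net.trainable_params(), learning_rate=Tensor(0.1, mstype.float32))  # 0-dim Tensor
    opt_int = nn.SGD(net.trainable_params(), learning_rate=1)                               # int, used as 1.0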
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# ============================================================================ # ============================================================================
"""Learning scheduler.""" """Learning scheduler."""
from math import ceil from math import ceil
import numpy as np import numpy as np
import mindspore.nn.learning_rate_schedule as lr_schedules
def square_root_schedule(lr, update_num, decay_start_step, def square_root_schedule(lr, update_num, decay_start_step,
warmup_steps=2000, warmup_steps=2000,
...@@ -105,3 +106,35 @@ def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup ...@@ -105,3 +106,35 @@ def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
return lrs return lrs
class BertLearningRate(lr_schedules.LearningRateSchedule):
"""
Implementation of a warmup-polydecay learning rate scheduler.
Args:
learning_rate (float): The initial value of learning rate.
end_learning_rate (float): The end value of learning rate.
warmup_steps (int): The warm up steps of learning rate.
decay_steps (int): A value used to calculate decayed learning rate.
power (float): A value used to calculate decayed learning rate.
Returns:
Tensor. The learning rate value for the current step.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
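A quick numeric check of what BertLearningRate.construct blends, re-implemented in plain Python under the assumption that WarmUpLR scales the rate linearly up to `warmup_steps` and PolynomialDecayLR follows the usual polynomial-decay formula; before `warmup_steps` the warmup value is selected, afterwards the decayed value.

    warmup_steps = 4
    learning_rate, end_learning_rate, decay_steps, power = 0.1, 0.01, 10, 1.0

    def blended_lr(global_step):
        # Assumed formulas for the two schedule cells combined in BertLearningRate.construct.
        warmup = learning_rate * min(global_step, warmup_steps) / warmup_steps
        decayed = ((learning_rate - end_learning_rate)
                   * (1 - min(global_step, decay_steps) / decay_steps) ** power
                   + end_learning_rate)
        is_warmup = 1.0 if warmup_steps > global_step else 0.0
        return (1.0 - is_warmup) * decayed + is_warmup * warmup

    print([round(blended_lr(step), 4) for step in range(6)])
    # [0.0, 0.025, 0.05, 0.075, 0.064, 0.055]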
...@@ -37,7 +37,7 @@ from src.transformer.infer_mass import infer ...@@ -37,7 +37,7 @@ from src.transformer.infer_mass import infer
from src.utils import LossCallBack from src.utils import LossCallBack
from src.utils import one_weight, zero_weight, weight_variable from src.utils import one_weight, zero_weight, weight_variable
from src.utils import square_root_schedule from src.utils import square_root_schedule
from src.utils.lr_scheduler import polynomial_decay_scheduler from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
parser = argparse.ArgumentParser(description='MASS train entry point.') parser = argparse.ArgumentParser(description='MASS train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.") parser.add_argument("--config", type=str, required=True, help="model config json file path.")
...@@ -178,10 +178,16 @@ def _build_training_pipeline(config: TransformerConfig, ...@@ -178,10 +178,16 @@ def _build_training_pipeline(config: TransformerConfig,
if config.optimizer.lower() == "adam": if config.optimizer.lower() == "adam":
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98) optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
elif config.optimizer.lower() == "lamb": elif config.optimizer.lower() == "lamb":
optimizer = Lamb(net_with_loss.trainable_params(), decay_steps=12000, lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
start_learning_rate=config.lr, end_learning_rate=config.min_lr, power=10.0, warmup_steps=config.warmup_steps)
power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01, decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
eps=1e-6) net_with_loss.trainable_params()))
other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
net_with_loss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
optimizer = Lamb(group_params, lr, eps=1e-6)
elif config.optimizer.lower() == "momentum": elif config.optimizer.lower() == "momentum":
optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9) optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
else: else:
......
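The Lamb branch above now builds the no-decay group by parameter name instead of passing a decay filter into the optimizer. A small sketch (hypothetical parameter names) of how the layernorm/bias filters split the parameters:

    names = ['encoder.layer0.dense.weight', 'encoder.layer0.dense.bias',
             'encoder.layer0.layernorm.gamma']
    decay = [n for n in names if 'layernorm' not in n.lower() and 'bias' not in n.lower()]
    no_decay = [n for n in names if 'layernorm' in n.lower() or 'bias' in n.lower()]
    print(decay)     # ['encoder.layer0.dense.weight']
    print(no_decay)  # ['encoder.layer0.dense.bias', 'encoder.layer0.layernorm.gamma']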
...@@ -147,7 +147,7 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): ...@@ -147,7 +147,7 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation):
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
Parameters for optimizer: Parameters for optimizer:
AdamWeightDecayDynamicLR: AdamWeightDecay:
decay_steps steps of the learning rate decay: N decay_steps steps of the learning rate decay: N
learning_rate value of learning rate: Q learning_rate value of learning rate: Q
end_learning_rate value of end learning rate: Q, must be positive end_learning_rate value of end learning rate: Q, must be positive
......
...@@ -23,12 +23,12 @@ from src.bert_for_finetune import BertFinetuneCell, BertCLS ...@@ -23,12 +23,12 @@ from src.bert_for_finetune import BertFinetuneCell, BertCLS
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_classification_dataset from src.dataset import create_classification_dataset
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
from src.utils import make_directory, LossCallBack, LoadNewestCkpt from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import log as logger from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
...@@ -42,27 +42,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin ...@@ -42,27 +42,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecay':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
decay_steps=steps_per_epoch * epoch_num, end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power, power=optimizer_cfg.AdamWeightDecay.power)
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), params = network.trainable_params()
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb': elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, decay_steps=steps_per_epoch * epoch_num,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), power=optimizer_cfg.Lamb.power)
decay_filter=optimizer_cfg.Lamb.decay_filter) optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum': elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum) momentum=optimizer_cfg.Momentum.momentum)
else: else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network # load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
......
...@@ -23,13 +23,13 @@ import argparse ...@@ -23,13 +23,13 @@ import argparse
from src.bert_for_finetune import BertFinetuneCell, BertNER from src.bert_for_finetune import BertFinetuneCell, BertNER
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_ner_dataset from src.dataset import create_ner_dataset
from src.utils import make_directory, LossCallBack, LoadNewestCkpt from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import log as logger from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
...@@ -44,27 +44,30 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin ...@@ -44,27 +44,30 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecay':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
decay_steps=steps_per_epoch * epoch_num, end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power, power=optimizer_cfg.AdamWeightDecay.power)
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), params = network.trainable_params()
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb': elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, decay_steps=steps_per_epoch * epoch_num,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), power=optimizer_cfg.Lamb.power)
decay_filter=optimizer_cfg.Lamb.decay_filter) optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum': elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum) momentum=optimizer_cfg.Momentum.momentum)
else: else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network # load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
......
...@@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode ...@@ -28,12 +28,12 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore import log as logger from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack from src.utils import LossCallBack, BertLearningRate
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
...@@ -109,24 +109,35 @@ def run_pretrain(): ...@@ -109,24 +109,35 @@ def run_pretrain():
netwithloss = BertNetworkWithLoss(bert_net_cfg, True) netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb': if cfg.optimizer == 'Lamb':
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count, lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay, warmup_steps=cfg.Lamb.warmup_steps,
eps=cfg.Lamb.eps) decay_steps=ds.get_dataset_size() * new_repeat_count,
power=cfg.Lamb.power)
params = netwithloss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params}]
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum': elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum) momentum=cfg.Momentum.momentum)
elif cfg.optimizer == 'AdamWeightDecayDynamicLR': elif cfg.optimizer == 'AdamWeightDecay':
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
decay_steps=ds.get_dataset_size() * new_repeat_count, end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate, warmup_steps=cfg.AdamWeightDecay.warmup_steps,
end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate, decay_steps=ds.get_dataset_size() * new_repeat_count,
power=cfg.AdamWeightDecayDynamicLR.power, power=cfg.AdamWeightDecay.power)
weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay, params = netwithloss.trainable_params()
eps=cfg.AdamWeightDecayDynamicLR.eps, decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
warmup_steps=cfg.AdamWeightDecayDynamicLR.warmup_steps) other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else: else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]". raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
format(cfg.optimizer)) format(cfg.optimizer))
callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()] callback = [TimeMonitor(ds.get_dataset_size()), LossCallBack()]
if args_opt.enable_save_ckpt == "true": if args_opt.enable_save_ckpt == "true":
......
...@@ -25,12 +25,12 @@ from src.dataset import create_squad_dataset ...@@ -25,12 +25,12 @@ from src.dataset import create_squad_dataset
from src import tokenization from src import tokenization
from src.create_squad_data import read_squad_examples, convert_examples_to_features from src.create_squad_data import read_squad_examples, convert_examples_to_features
from src.run_squad import write_predictions from src.run_squad import write_predictions
from src.utils import make_directory, LossCallBack, LoadNewestCkpt from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore import log as logger from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum from mindspore.nn.optim import AdamWeightDecay, Lamb, Momentum
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
...@@ -44,27 +44,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin ...@@ -44,27 +44,31 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
raise ValueError("Pretrain model missed, finetune task must load pretrain model!") raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size() steps_per_epoch = dataset.get_dataset_size()
# optimizer # optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': if optimizer_cfg.optimizer == 'AdamWeightDecay':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.AdamWeightDecay.learning_rate,
decay_steps=steps_per_epoch * epoch_num, end_learning_rate=optimizer_cfg.AdamWeightDecay.end_learning_rate,
learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, decay_steps=steps_per_epoch * epoch_num,
power=optimizer_cfg.AdamWeightDecayDynamicLR.power, power=optimizer_cfg.AdamWeightDecay.power)
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), params = network.trainable_params()
weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) other_params = list(filter(lambda x: x not in decay_params, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
elif optimizer_cfg.optimizer == 'Lamb': elif optimizer_cfg.optimizer == 'Lamb':
optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, lr_schedule = BertLearningRate(learning_rate=optimizer_cfg.Lamb.learning_rate,
start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, decay_steps=steps_per_epoch * epoch_num,
warmup_steps=int(steps_per_epoch * epoch_num * 0.1), power=optimizer_cfg.Lamb.power)
decay_filter=optimizer_cfg.Lamb.decay_filter) optimizer = Lamb(network.trainable_params(), learning_rate=lr_schedule)
elif optimizer_cfg.optimizer == 'Momentum': elif optimizer_cfg.optimizer == 'Momentum':
optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
momentum=optimizer_cfg.Momentum.momentum) momentum=optimizer_cfg.Momentum.momentum)
else: else:
raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]")
# load checkpoint into network # load checkpoint into network
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
......
...@@ -24,20 +24,22 @@ cfg = edict({ ...@@ -24,20 +24,22 @@ cfg = edict({
'scale_factor': 2, 'scale_factor': 2,
'scale_window': 1000, 'scale_window': 1000,
'optimizer': 'Lamb', 'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({ 'AdamWeightDecay': edict({
'learning_rate': 3e-5, 'learning_rate': 3e-5,
'end_learning_rate': 1e-10, 'end_learning_rate': 1e-10,
'power': 5.0, 'power': 5.0,
'weight_decay': 1e-5, 'weight_decay': 1e-5,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6, 'eps': 1e-6,
'warmup_steps': 10000, 'warmup_steps': 10000,
}), }),
'Lamb': edict({ 'Lamb': edict({
'start_learning_rate': 3e-5, 'learning_rate': 3e-5,
'end_learning_rate': 1e-10, 'end_learning_rate': 1e-10,
'power': 10.0, 'power': 10.0,
'warmup_steps': 10000, 'warmup_steps': 10000,
'weight_decay': 0.01, 'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6, 'eps': 1e-6,
}), }),
'Momentum': edict({ 'Momentum': edict({
......
...@@ -23,19 +23,20 @@ from .bert_model import BertConfig ...@@ -23,19 +23,20 @@ from .bert_model import BertConfig
optimizer_cfg = edict({ optimizer_cfg = edict({
'optimizer': 'Lamb', 'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({ 'AdamWeightDecay': edict({
'learning_rate': 2e-5, 'learning_rate': 2e-5,
'end_learning_rate': 1e-7, 'end_learning_rate': 1e-7,
'power': 1.0, 'power': 1.0,
'weight_decay': 1e-5, 'weight_decay': 1e-5,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6, 'eps': 1e-6,
}), }),
'Lamb': edict({ 'Lamb': edict({
'start_learning_rate': 2e-5, 'learning_rate': 2e-5,
'end_learning_rate': 1e-7, 'end_learning_rate': 1e-7,
'power': 1.0, 'power': 1.0,
'weight_decay': 0.01, 'weight_decay': 0.01,
'decay_filter': lambda x: False, 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
}), }),
'Momentum': edict({ 'Momentum': edict({
'learning_rate': 2e-5, 'learning_rate': 2e-5,
......
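The `decay_filter` entries added to these config dicts receive Parameter objects and inspect only `.name`. A sketch of how a training script applies them to build the weight-decay groups (a namedtuple stands in for Parameter here; the real scripts pass `network.trainable_params()`):

    from collections import namedtuple

    Param = namedtuple('Param', ['name'])
    params = [Param('bert.dense.weight'), Param('bert.dense.bias'), Param('bert.layernorm.gamma')]

    decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = [p for p in params if p not in decay_params]
    group_params = [{'params': decay_params, 'weight_decay': 0.01},
                    {'params': other_params, 'weight_decay': 0.0}]
    print([p.name for p in decay_params])  # ['bert.dense.weight']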
...@@ -23,6 +23,7 @@ from mindspore.ops import operations as P ...@@ -23,6 +23,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
class CrossEntropyCalculation(nn.Cell): class CrossEntropyCalculation(nn.Cell):
...@@ -123,3 +124,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre ...@@ -123,3 +124,25 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre
max_num = int(num) max_num = int(num)
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
return load_finetune_checkpoint_path return load_finetune_checkpoint_path
class BertLearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for Bert network.
"""
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
...@@ -30,7 +30,7 @@ verification_set = [ ...@@ -30,7 +30,7 @@ verification_set = [
'block': { 'block': {
'model': network, 'model': network,
'loss': SquaredLoss(), 'loss': SquaredLoss(),
'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), 'opt': Lamb(network.trainable_params(), 0.02, weight_decay=0.01),
'num_epochs': num_epochs, 'num_epochs': num_epochs,
'loss_upper_bound': 0.3, 'loss_upper_bound': 0.3,
}, },
......
...@@ -31,7 +31,7 @@ Example: ...@@ -31,7 +31,7 @@ Example:
'block': { 'block': {
'model': network, 'model': network,
'loss': SquaredLoss(), 'loss': SquaredLoss(),
'opt': Lamb(network.trainable_params(), decay_steps=num_epochs, warmup_steps=10, weight_decay=0.01), 'opt': Lamb(network.trainable_params(), learning_rate=0.02, weight_decay=0.01),
'num_epochs': num_epochs, 'num_epochs': num_epochs,
'loss_upper_bound': 0.3, 'loss_upper_bound': 0.3,
}, },
......
...@@ -22,8 +22,9 @@ import os ...@@ -22,8 +22,9 @@ import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn.optim import AdamWeightDecayDynamicLR from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from ...dataset_mock import MindData from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
...@@ -98,6 +99,25 @@ def get_config(version='base', batch_size=1): ...@@ -98,6 +99,25 @@ def get_config(version='base', batch_size=1):
return BertConfig(batch_size=batch_size) return BertConfig(batch_size=batch_size)
class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
def test_bert_train(): def test_bert_train():
""" """
the main function the main function
...@@ -123,7 +143,8 @@ def test_bert_train(): ...@@ -123,7 +143,8 @@ def test_bert_train():
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
net = ModelBert(netwithloss, optimizer=optimizer) net = ModelBert(netwithloss, optimizer=optimizer)
net.set_train() net.set_train()
build_construct_graph(net, *inputs, execute=False) build_construct_graph(net, *inputs, execute=False)
...@@ -147,7 +168,8 @@ def test_bert_withlossscale_train(): ...@@ -147,7 +168,8 @@ def test_bert_withlossscale_train():
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
net = ModelBert(netwithloss, optimizer=optimizer) net = ModelBert(netwithloss, optimizer=optimizer)
net.set_train() net.set_train()
build_construct_graph(net, *inputs, execute=True) build_construct_graph(net, *inputs, execute=True)
...@@ -173,7 +195,8 @@ def bert_withlossscale_manager_train(): ...@@ -173,7 +195,8 @@ def bert_withlossscale_manager_train():
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
net = ModelBert(netwithloss, optimizer=optimizer) net = ModelBert(netwithloss, optimizer=optimizer)
net.set_train() net.set_train()
build_construct_graph(net, *inputs, execute=True) build_construct_graph(net, *inputs, execute=True)
...@@ -200,7 +223,8 @@ def bert_withlossscale_manager_train_feed(): ...@@ -200,7 +223,8 @@ def bert_withlossscale_manager_train_feed():
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
net = ModelBert(netwithloss, optimizer=optimizer) net = ModelBert(netwithloss, optimizer=optimizer)
net.set_train() net.set_train()
build_construct_graph(net, *inputs, execute=True) build_construct_graph(net, *inputs, execute=True)
...@@ -24,7 +24,7 @@ cfg = edict({ ...@@ -24,7 +24,7 @@ cfg = edict({
'scale_factor': 2, 'scale_factor': 2,
'scale_window': 1000, 'scale_window': 1000,
'optimizer': 'Lamb', 'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({ 'AdamWeightDecay': edict({
'learning_rate': 3e-5, 'learning_rate': 3e-5,
'end_learning_rate': 1e-10, 'end_learning_rate': 1e-10,
'power': 5.0, 'power': 5.0,
...@@ -33,7 +33,7 @@ cfg = edict({ ...@@ -33,7 +33,7 @@ cfg = edict({
'warmup_steps': 10000, 'warmup_steps': 10000,
}), }),
'Lamb': edict({ 'Lamb': edict({
'start_learning_rate': 3e-5, 'learning_rate': 3e-5,
'end_learning_rate': 1e-10, 'end_learning_rate': 1e-10,
'power': 10.0, 'power': 10.0,
'warmup_steps': 10000, 'warmup_steps': 10000,
......
...@@ -32,7 +32,7 @@ cfg = edict({ ...@@ -32,7 +32,7 @@ cfg = edict({
'pre_training_ckpt': '/your/path/pre_training.ckpt', 'pre_training_ckpt': '/your/path/pre_training.ckpt',
'use_crf': False, 'use_crf': False,
'optimizer': 'Lamb', 'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({ 'AdamWeightDecay': edict({
'learning_rate': 2e-5, 'learning_rate': 2e-5,
'end_learning_rate': 1e-7, 'end_learning_rate': 1e-7,
'power': 1.0, 'power': 1.0,
...@@ -40,7 +40,7 @@ cfg = edict({ ...@@ -40,7 +40,7 @@ cfg = edict({
'eps': 1e-6, 'eps': 1e-6,
}), }),
'Lamb': edict({ 'Lamb': edict({
'start_learning_rate': 2e-5, 'learning_rate': 2e-5,
'end_learning_rate': 1e-7, 'end_learning_rate': 1e-7,
'power': 1.0, 'power': 1.0,
'decay_filter': lambda x: False, 'decay_filter': lambda x: False,
......
...@@ -29,9 +29,11 @@ from mindspore.nn.optim import Lamb ...@@ -29,9 +29,11 @@ from mindspore.nn.optim import Lamb
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.nn import learning_rate_schedule as lr_schedules
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from src.bert_model import BertConfig from src.bert_model import BertConfig
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"
...@@ -111,6 +113,25 @@ def weight_variable(shape): ...@@ -111,6 +113,25 @@ def weight_variable(shape):
return Tensor(ones) return Tensor(ones)
class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
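The cell above evaluates both the warmup and the decay branch on every step and blends them with a Greater/Cast mask so the whole schedule stays inside the graph. For readability, an approximate plain-Python sketch of the same arithmetic (it mirrors the WarmUpLR and PolynomialDecayLR formulas rather than calling them, and simply branches instead of masking):

def bert_lr_sketch(global_step, learning_rate, end_learning_rate,
                   warmup_steps, decay_steps, power):
    # Approximation of BertLearningRate.construct above.
    if warmup_steps > 0 and global_step < warmup_steps:
        return learning_rate * global_step / warmup_steps  # linear warmup
    tmp_step = min(global_step, decay_steps)
    return (learning_rate - end_learning_rate) * (1 - tmp_step / decay_steps) ** power + end_learning_rate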
class ModelCallback(Callback): class ModelCallback(Callback):
def __init__(self): def __init__(self):
super(ModelCallback, self).__init__() super(ModelCallback, self).__init__()
...@@ -134,9 +155,15 @@ def test_bert_tdt(): ...@@ -134,9 +155,15 @@ def test_bert_tdt():
ds = me_de_train_dataset() ds = me_de_train_dataset()
config = get_config(version='large', batch_size=16) config = get_config(version='large', batch_size=16)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), lr = BertLearningRate(decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), learning_rate=5e-5,
start_learning_rate=5e-5, end_learning_rate=1e-9, end_learning_rate=1e-9, power=10.0, warmup_steps=0)
power=10.0, warmup_steps=0, weight_decay=0.01) decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
optimizer = Lamb(group_params, lr)
scale_window = 3 scale_window = 3
scale_manager = DynamicLossScaleManager(262144, 2, scale_window) scale_manager = DynamicLossScaleManager(262144, 2, scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
......
...@@ -33,6 +33,7 @@ from mindspore.nn.optim import Lamb ...@@ -33,6 +33,7 @@ from mindspore.nn.optim import Lamb
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.model import Model from mindspore.train.model import Model
import mindspore.nn.learning_rate_schedule as lr_schedules
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
...@@ -125,6 +126,25 @@ def weight_variable(shape): ...@@ -125,6 +126,25 @@ def weight_variable(shape):
return Tensor(ones) return Tensor(ones)
class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
class ModelCallback(Callback): class ModelCallback(Callback):
def __init__(self): def __init__(self):
super(ModelCallback, self).__init__() super(ModelCallback, self).__init__()
...@@ -162,9 +182,16 @@ def test_bert_percision(): ...@@ -162,9 +182,16 @@ def test_bert_percision():
batch_size = 16 batch_size = 16
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count, lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
start_learning_rate=5e-5, end_learning_rate=1e-9, learning_rate=5e-5, end_learning_rate=1e-9,
power=10.0, warmup_steps=0, weight_decay=0.01) power=10.0, warmup_steps=0)
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
optimizer = Lamb(group_params, lr)
scale_window = 3 scale_window = 3
scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
...@@ -220,9 +247,18 @@ def test_bert_performance(): ...@@ -220,9 +247,18 @@ def test_bert_performance():
batch_size = 16 batch_size = 16
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*new_repeat_count,
start_learning_rate=5e-5, end_learning_rate=1e-9, lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
power=10.0, warmup_steps=0, weight_decay=0.01) learning_rate=5e-5, end_learning_rate=1e-9,
power=10.0, warmup_steps=0)
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params}]
optimizer = Lamb(group_params, lr)
scale_window = 3 scale_window = 3
scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window) scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
......
...@@ -20,8 +20,10 @@ import mindspore.nn as nn ...@@ -20,8 +20,10 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter, context from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR from mindspore.nn.optim import Adam, AdamWeightDecay
from mindspore.ops import operations as P from mindspore.ops import operations as P
import mindspore.nn.learning_rate_schedule as lr_schedules
from mindspore.nn.dynamic_lr import polynomial_decay_lr
context.set_context(enable_sparse=True) context.set_context(enable_sparse=True)
...@@ -112,6 +114,62 @@ def test_sparse_adam_compile(): ...@@ -112,6 +114,62 @@ def test_sparse_adam_compile():
_executor.compile(train_network, indices, label) _executor.compile(train_network, indices, label)
def test_adam_group1():
""" test_adam_group_lr_and_weight_decay """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = WithLossCell(net, loss)
all_params = net.trainable_params()
poly_decay_lr = polynomial_decay_lr(0.01, 0.0001, total_step=10, step_per_epoch=1, decay_epoch=3, power=1.0)
group_params = [{'params': [all_params[0]], 'lr': poly_decay_lr, 'weight_decay': 0.9},
{'params': [all_params[1]]}]
optimizer = nn.Adam(group_params, learning_rate=0.1)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
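For orientation, the per-group learning rate built in test_adam_group1 is an ordinary Python list with one value per step; under the dynamic_lr semantics documented in this change it starts at learning_rate and is held at end_learning_rate once decay_epoch is reached. A small sanity sketch under that assumption:

lrs = polynomial_decay_lr(0.01, 0.0001, total_step=10, step_per_epoch=1,
                          decay_epoch=3, power=1.0)
assert len(lrs) == 10                    # one value per step
assert abs(lrs[0] - 0.01) < 1e-9         # begins at learning_rate
assert abs(lrs[-1] - 0.0001) < 1e-9      # clamped to end_learning_rate after decay_epoch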
def test_adam_group2():
""" test_adam_group_lr_and_weight_decay """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = WithLossCell(net, loss)
all_params = net.trainable_params()
schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0)
group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9},
{'params': [all_params[1]]}]
optimizer = nn.Adam(group_params, learning_rate=schedule_lr)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_adamweightdecay_group():
""" test_adam_group_lr_and_weight_decay """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = WithLossCell(net, loss)
all_params = net.trainable_params()
schedule_lr = lr_schedules.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0)
group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9},
{'params': [all_params[1]]}]
optimizer = nn.AdamWeightDecay(group_params, learning_rate=schedule_lr)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_AdamWeightDecay_beta1(): def test_AdamWeightDecay_beta1():
net = Net() net = Net()
print("**********", net.get_parameters()) print("**********", net.get_parameters())
...@@ -131,20 +189,6 @@ def test_AdamWeightDecay_e(): ...@@ -131,20 +189,6 @@ def test_AdamWeightDecay_e():
AdamWeightDecay(net.get_parameters(), eps=-0.1, learning_rate=0.1) AdamWeightDecay(net.get_parameters(), eps=-0.1, learning_rate=0.1)
def test_AdamWeightDecayDynamicLR():
""" test_AdamWeightDecayDynamicLR """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
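The deleted test above exercised the schedule that used to be built into AdamWeightDecayDynamicLR. After this change the same pattern is written by handing a schedule cell to AdamWeightDecay; a hedged equivalent of the removed case is sketched below (the end_learning_rate and power values are illustrative, not the removed optimizer's exact defaults):

def test_adamweightdecay_with_schedule_sketch():
    """ illustrative replacement for the removed AdamWeightDecayDynamicLR case """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    decay_lr = lr_schedules.PolynomialDecayLR(0.1, 0.0001, 20, power=1.0)  # replaces decay_steps=20
    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=decay_lr)
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)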
def test_adam_mindspore_with_empty_params(): def test_adam_mindspore_with_empty_params():
net = nn.Flatten() net = nn.Flatten()
with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"):
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
""" test lamb """ """ test lamb """
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
...@@ -22,6 +21,27 @@ from mindspore.common.api import _executor ...@@ -22,6 +21,27 @@ from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Lamb from mindspore.nn.optim import Lamb
from mindspore.ops import operations as P from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
class LambLearningRate(LearningRateSchedule):
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
super(LambLearningRate, self).__init__()
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
def construct(self, global_step):
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
warmup_lr = self.warmup_lr(global_step)
decay_lr = self.decay_lr(global_step)
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
return lr
class Net(nn.Cell): class Net(nn.Cell):
...@@ -51,27 +71,49 @@ class NetWithoutWeight(nn.Cell): ...@@ -51,27 +71,49 @@ class NetWithoutWeight(nn.Cell):
return x return x
def test_lamb_compile(): def test_lamb_compile_dynamic_lr():
""" test_Lamb_compile """ """ test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32)) inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10) warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
optimizer = Lamb(net.trainable_params(), warmup_decay_lr)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)
def test_lamb_error(): def test_lamb_compile():
""" test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
with pytest.raises(TypeError): net.set_train()
Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0) loss = nn.SoftmaxCrossEntropyWithLogits()
with pytest.raises(TypeError): optimizer = Lamb(net.trainable_params(), 0.02, 0.9)
Lamb(net.get_parameters(), decay_steps=1.0)
with pytest.raises(ValueError): net_with_loss = WithLossCell(net, loss)
Lamb(net.get_parameters(), decay_steps=0) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lamb_group():
""" test_Lamb_group_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
all_params = net.trainable_params()
group_params = [{'params': [all_params[0]], 'lr': warmup_decay_lr, 'weight_decay': 0.9},
{'params': [all_params[1]]}]
optimizer = Lamb(group_params, 0.02)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
...@@ -18,7 +18,7 @@ import pytest ...@@ -18,7 +18,7 @@ import pytest
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay
class IterableObjc: class IterableObjc:
...@@ -81,10 +81,6 @@ class TestNullParam(): ...@@ -81,10 +81,6 @@ class TestNullParam():
with pytest.raises(ValueError): with pytest.raises(ValueError):
AdamWeightDecay(None) AdamWeightDecay(None)
def test_AdamWeightDecayDynamicLR_init(self):
with pytest.raises(ValueError):
AdamWeightDecayDynamicLR(None, 10)
def test_Sgd_init(self): def test_Sgd_init(self):
with pytest.raises(ValueError): with pytest.raises(ValueError):
SGD(None) SGD(None)
...@@ -101,10 +97,6 @@ class TestUnsupportParam(): ...@@ -101,10 +97,6 @@ class TestUnsupportParam():
with pytest.raises(TypeError): with pytest.raises(TypeError):
AdamWeightDecay(9) AdamWeightDecay(9)
def test_AdamWeightDecayDynamicLR_init(self):
with pytest.raises(TypeError):
AdamWeightDecayDynamicLR(0.5, 10)
def test_Sgd_init(self): def test_Sgd_init(self):
with pytest.raises(TypeError): with pytest.raises(TypeError):
paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x") paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x")
......
...@@ -37,6 +37,7 @@ class Net(nn.Cell): ...@@ -37,6 +37,7 @@ class Net(nn.Cell):
x = self.biasAdd(self.matmul(x, self.weight), self.bias) x = self.biasAdd(self.matmul(x, self.weight), self.bias)
return x return x
class NetWithSparseGatherV2(nn.Cell): class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """ """ NetWithSparseGatherV2 definition """
def __init__(self): def __init__(self):
......
...@@ -28,7 +28,7 @@ decay_epoch = 2 ...@@ -28,7 +28,7 @@ decay_epoch = 2
min_lr = 0.01 min_lr = 0.01
max_lr = 0.1 max_lr = 0.1
power = 0.5 power = 0.5
warmup_epoch = 2
class TestInputs: class TestInputs:
def test_milestone1(self): def test_milestone1(self):
...@@ -234,3 +234,8 @@ def test_polynomial_decay(): ...@@ -234,3 +234,8 @@ def test_polynomial_decay():
lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
True) True)
assert len(lr2) == total_step assert len(lr2) == total_step
def test_warmup():
lr1 = dr.warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch)
assert len(lr1) == total_step
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Test Dynamic Learning Rate """
import pytest
from mindspore import Tensor, Parameter
from mindspore.nn import learning_rate_schedule as lr_schedules
from mindspore.common.api import _executor
import mindspore.common.dtype as mstype
learning_rate = 0.1
end_learning_rate = 0.01
decay_rate = 0.9
decay_steps = 4
warmup_steps = 2
min_lr = 0.01
max_lr = 0.1
power = 0.5
global_step = Parameter(Tensor(2, mstype.int32), 'global_step')
class TestInit:
def test_learning_rate_type(self):
lr = True
with pytest.raises(TypeError):
lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps)
with pytest.raises(TypeError):
lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power)
def test_learning_rate_value(self):
lr = -1.0
with pytest.raises(ValueError):
lr_schedules.ExponentialDecayLR(lr, decay_rate, decay_steps)
with pytest.raises(ValueError):
lr_schedules.PolynomialDecayLR(lr, end_learning_rate, decay_steps, power)
def test_end_learning_rate_type(self):
lr = True
with pytest.raises(TypeError):
lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power)
def test_end_learning_rate_value(self):
lr = -1.0
with pytest.raises(ValueError):
lr_schedules.PolynomialDecayLR(learning_rate, lr, decay_steps, power)
def test_decay_rate_type(self):
rate = 'a'
with pytest.raises(TypeError):
lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps)
def test_decay_rate_value(self):
rate = -1.0
with pytest.raises(ValueError):
lr_schedules.ExponentialDecayLR(learning_rate, rate, decay_steps)
def test_decay_steps_type(self):
decay_steps_e = 'm'
with pytest.raises(TypeError):
lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e)
with pytest.raises(TypeError):
lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e)
with pytest.raises(TypeError):
lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power)
def test_decay_steps_value(self):
decay_steps_e = -2
with pytest.raises(ValueError):
lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps_e)
with pytest.raises(ValueError):
lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps_e)
with pytest.raises(ValueError):
lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps_e, power)
def test_is_stair(self):
is_stair = 1
with pytest.raises(TypeError):
lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, is_stair)
def test_min_lr_type(self):
min_lr1 = True
with pytest.raises(TypeError):
lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps)
def test_min_lr_value(self):
min_lr1 = -1.0
with pytest.raises(ValueError):
lr_schedules.CosineDecayLR(min_lr1, max_lr, decay_steps)
def test_max_lr_type(self):
max_lr1 = 'a'
with pytest.raises(TypeError):
lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps)
def test_max_lr_value(self):
max_lr1 = -1.0
with pytest.raises(ValueError):
lr_schedules.CosineDecayLR(min_lr, max_lr1, decay_steps)
def test_power(self):
power1 = True
with pytest.raises(TypeError):
lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power1)
def test_exponential_decay():
lr_schedule = lr_schedules.ExponentialDecayLR(learning_rate, decay_rate, decay_steps, True)
_executor.compile(lr_schedule, global_step)
def test_enatural_exp_decay():
lr_schedule = lr_schedules.NaturalExpDecayLR(learning_rate, decay_rate, decay_steps, True)
_executor.compile(lr_schedule, global_step)
def test_inverse_decay():
lr_schedule = lr_schedules.InverseDecayLR(learning_rate, decay_rate, decay_steps, True)
_executor.compile(lr_schedule, global_step)
def test_cosine_decay():
lr_schedule = lr_schedules.CosineDecayLR(min_lr, max_lr, decay_steps)
_executor.compile(lr_schedule, global_step)
def test_polynomial_decay():
lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
_executor.compile(lr_schedule, global_step)
def test_polynomial_decay2():
lr_schedule = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power, True)
_executor.compile(lr_schedule, global_step)
def test_warmup():
lr_schedule = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
_executor.compile(lr_schedule, global_step)
...@@ -152,7 +152,7 @@ def test_compile_fp16_overflow(): ...@@ -152,7 +152,7 @@ def test_compile_fp16_overflow():
net = NetFP16(16, 16) net = NetFP16(16, 16)
loss = MSELoss() loss = MSELoss()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer) train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
train_network.set_train() train_network.set_train()
......
...@@ -104,9 +104,11 @@ def test_group_dynamic_1(): ...@@ -104,9 +104,11 @@ def test_group_dynamic_1():
assert opt.is_group_params_ordered is True assert opt.is_group_params_ordered is True
for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()):
if param in conv_params: if param in conv_params:
assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
else: else:
assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
assert param.name == order_param.name assert param.name == order_param.name
...@@ -134,9 +136,11 @@ def test_group_dynamic_2(): ...@@ -134,9 +136,11 @@ def test_group_dynamic_2():
assert opt.dynamic_lr is True assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters): for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params: if param in conv_params:
assert np.all(lr.data.asnumpy() == Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy())
else: else:
assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy())
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt) train_network = TrainOneStepCell(net_with_loss, opt)
...@@ -157,7 +161,7 @@ def test_group_dynamic_no_same_size(): ...@@ -157,7 +161,7 @@ def test_group_dynamic_no_same_size():
def test_group_not_float_lr(): def test_group_not_float_lr():
net = LeNet5() net = LeNet5()
conv_lr = 1 conv_lr = np.array(1)
default_lr = 0.3 default_lr = 0.3
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
...@@ -169,7 +173,7 @@ def test_group_not_float_lr(): ...@@ -169,7 +173,7 @@ def test_group_not_float_lr():
def test_group_not_float_weight_decay(): def test_group_not_float_weight_decay():
net = LeNet5() net = LeNet5()
conv_weight_decay = 1 conv_weight_decay = np.array(1)
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
...@@ -238,11 +242,15 @@ def test_get_lr_parameter_with_group(): ...@@ -238,11 +242,15 @@ def test_get_lr_parameter_with_group():
assert opt.is_group_lr is True assert opt.is_group_lr is True
for param in opt.parameters: for param in opt.parameters:
lr = opt.get_lr_parameter(param) lr = opt.get_lr_parameter(param)
assert lr.name == 'lr_' + param.name if 'conv' in param.name:
cur_name = 'learning_rate_group_' + '0'
else:
cur_name = 'learning_rate_group_' + '1'
assert lr.name == cur_name
lr_list = opt.get_lr_parameter(conv_params) lr_list = opt.get_lr_parameter(conv_params)
for lr, param in zip(lr_list, conv_params): for lr, param in zip(lr_list, conv_params):
assert lr.name == 'lr_' + param.name assert lr.name == 'learning_rate_group_' + '0'
def test_get_lr_parameter_with_order_group(): def test_get_lr_parameter_with_order_group():
...@@ -256,7 +264,11 @@ def test_get_lr_parameter_with_order_group(): ...@@ -256,7 +264,11 @@ def test_get_lr_parameter_with_order_group():
assert opt.is_group_lr is True assert opt.is_group_lr is True
for param in opt.parameters: for param in opt.parameters:
lr = opt.get_lr_parameter(param) lr = opt.get_lr_parameter(param)
assert lr.name == 'lr_' + param.name if 'conv' in param.name:
cur_name = 'learning_rate_group_' + '0'
else:
cur_name = 'learning_rate'
assert lr.name == cur_name
def test_get_lr_parameter_with_no_group(): def test_get_lr_parameter_with_no_group():
...@@ -271,7 +283,7 @@ def test_get_lr_parameter_with_no_group(): ...@@ -271,7 +283,7 @@ def test_get_lr_parameter_with_no_group():
assert opt.is_group_lr is False assert opt.is_group_lr is False
for param in opt.parameters: for param in opt.parameters:
lr = opt.get_lr_parameter(param) lr = opt.get_lr_parameter(param)
assert lr.name == opt.learning_rate.name assert lr.name == 'learning_rate'
params_error = [1, 2, 3] params_error = [1, 2, 3]
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -305,7 +317,11 @@ def test_order_params_1(): ...@@ -305,7 +317,11 @@ def test_order_params_1():
assert decay_flags is False assert decay_flags is False
assert param.name == order_param.name assert param.name == order_param.name
assert lr.name == 'lr_' + param.name if 'conv' in param.name:
assert lr.name == 'learning_rate'
elif 'bias' in param.name:
assert lr.name == 'learning_rate_group_' + '1'
def test_order_params_2(): def test_order_params_2():
...@@ -323,8 +339,9 @@ def test_order_params_2(): ...@@ -323,8 +339,9 @@ def test_order_params_2():
assert opt.is_group is True assert opt.is_group is True
assert opt.is_group_lr is True assert opt.is_group_lr is True
assert opt.is_group_params_ordered is True assert opt.is_group_params_ordered is True
all_lr = opt.get_lr_parameter(fc1_params+conv_params)
for weight_decay, decay_flags, lr, param, order_param in zip( for weight_decay, decay_flags, lr, param, order_param in zip(
opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, fc1_params+conv_params): opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params):
if param in conv_params: if param in conv_params:
assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy()) assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
assert weight_decay == conv_weight_decay assert weight_decay == conv_weight_decay
...@@ -339,8 +356,10 @@ def test_order_params_2(): ...@@ -339,8 +356,10 @@ def test_order_params_2():
assert decay_flags is False assert decay_flags is False
assert param.name == order_param.name assert param.name == order_param.name
assert lr.name == 'lr_' + param.name if 'conv' in param.name:
assert lr.name == 'learning_rate'
elif 'fc1' in param.name:
assert lr.name == 'learning_rate_group_' + '0'
def test_get_order_params_with_not_same(): def test_get_order_params_with_not_same():
net = LeNet5() net = LeNet5()
......
...@@ -20,7 +20,7 @@ import mindspore.nn as nn ...@@ -20,7 +20,7 @@ import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import context from mindspore import context
...@@ -51,23 +51,8 @@ class Net(nn.Cell): ...@@ -51,23 +51,8 @@ class Net(nn.Cell):
return s return s
def test_AdamWeightDecayDynamicLR():
""" test_AdamWeightDecayDynamicLR """
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_AdamWeightDecay(): def test_AdamWeightDecay():
""" test_AdamWeightDecayDynamicLR """ """ test_AdamWeightDecay """
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
...@@ -89,7 +74,7 @@ def test_lamb_compile(): ...@@ -89,7 +74,7 @@ def test_lamb_compile():
net = Net() net = Net()
net.set_train() net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10) optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
...@@ -102,9 +87,9 @@ def test_edge_case(): ...@@ -102,9 +87,9 @@ def test_edge_case():
net = Net() net = Net()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
context.set_auto_parallel_context(parallel_mode="stand_alone") context.set_auto_parallel_context(parallel_mode="stand_alone")
Lamb(net.trainable_params(), decay_steps=10) Lamb(net.trainable_params(), learning_rate=0.1)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
Adam(net.trainable_params(), learning_rate=0.1) Adam(net.trainable_params(), learning_rate=0.1)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
context.set_auto_parallel_context(device_num=16) context.set_auto_parallel_context(device_num=16)
Lamb(net.trainable_params(), decay_steps=10) Lamb(net.trainable_params(), learning_rate=0.1)