diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 05560d9739a897bf3d5b66ed1294df2179288e53..788a7d27543cd17d342dcb495983bef0afb99be6 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -49,7 +49,15 @@ class Optimizer(Cell):
     applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
 
     Args:
-        learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+                                                        Iterable or a Tensor and the dims of the Tensor is 1,
+                                                        use dynamic learning rate, then the i-th step will
+                                                        take the i-th value as the learning rate.
+                                                        When the learning_rate is float or learning_rate is a Tensor
+                                                        but the dims of the Tensor is 0, use fixed learning rate.
+                                                        Other cases are not supported. Should be greater than 0.
+                                                        If the type of `learning_rate` input is int, it will be
+                                                        converted to float.
         parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
             updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of
             `dict`, the "params", "lr" and "weight_decay" are the keys can be parsed.
@@ -96,6 +104,8 @@ class Optimizer(Cell):
         self.is_group = False
         self.is_group_lr = False
         self.loss_scale = loss_scale
+        if isinstance(learning_rate, int):
+            learning_rate = float(learning_rate)
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
             self.gather = None
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index b802c8e7a50bed1c4164ca6de3ccafca27e91f29..bf492445502cd5773702f3ff1f6039bc6fff2eb6 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -106,22 +106,26 @@ class SGD(Optimizer):
         super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
 
+        if isinstance(momentum, int):
+            momentum = float(momentum)
         if not isinstance(momentum, float):
             raise TypeError("momentum should be float number!")
 
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
 
-        if not isinstance(dampening, float):
-            raise TypeError("dampening should be float number")
-
         if isinstance(dampening, int):
             dampening = float(dampening)
 
+        if not isinstance(dampening, float):
+            raise TypeError("dampening should be float number")
+
         if dampening < 0.0:
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening
 
+        if isinstance(weight_decay, int):
+            weight_decay = float(weight_decay)
+
         validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
         self.nesterov = nesterov
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 10ac0f8d29d910dee993431007d7de3af8da5a1c..a4e998589e121c7be18c2d41d3674260458a05dc 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2591,8 +2591,7 @@ class Adam(PrimitiveWithInfer):
           Mean square gradients, has the same type as `var`.
         - **beta1_power** (float) - :math:`beta_1^t` in the updating formula.
         - **beta2_power** (float) - :math:`beta_2^t` in the updating formula.
-        - **lr** (Union[float, Tensor, Iterable]) - :math:`l` in the updating formula.
-          Iterable type is used for the dynamic learning rate.
+        - **lr** (float) - :math:`l` in the updating formula.
         - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
         - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
         - **epsilon** (float) - Term added to the denominator to improve numerical stability.
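
The changes above share one pattern: integer hyperparameters (learning_rate in Optimizer; momentum, dampening and weight_decay in SGD) are coerced to float before the float-only type checks run, and the dampening type check is moved after the int conversion so an integer dampening no longer raises TypeError; the Adam primitive docstring now documents lr as a plain float. Below is a minimal standalone sketch of this coercion-then-validation order; the helper name _coerce_to_float and the combined non-negativity check are illustrative assumptions, not part of the patched MindSpore code.

def _coerce_to_float(name, value):
    """Convert an int hyperparameter to float, then apply the float-only checks."""
    if isinstance(value, int):
        # Convert first, as the patched SGD.__init__ does, so an int such as
        # momentum=0 or dampening=1 passes the isinstance(value, float) check below.
        value = float(value)
    if not isinstance(value, float):
        raise TypeError("{} should be float number!".format(name))
    if value < 0.0:
        raise ValueError("{} should be at least 0.0, but got {}".format(name, value))
    return value


# Integer inputs are accepted and silently promoted to float.
print(_coerce_to_float("momentum", 0))       # 0.0
print(_coerce_to_float("dampening", 1))      # 1.0
print(_coerce_to_float("weight_decay", 0))   # 0.0

With this ordering, a construction such as nn.SGD(net.trainable_params(), learning_rate=1, momentum=0) should no longer fail with a TypeError, since both integers are promoted to float before validation.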