提交 713ed15b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1332 Fix some bugs for optimizer.

Merge pull request !1332 from liuxiao/fix-bug-for-optimizer
...@@ -49,7 +49,15 @@ class Optimizer(Cell): ...@@ -49,7 +49,15 @@ class Optimizer(Cell):
applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.
Args: Args:
learning_rate (float): A floating point value for the learning rate. Should be greater than 0. learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Should be greater than 0.
If the type of `learning_rate` input is int, it will be
converted to float.
parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`, updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`,
the "params", "lr" and "weight_decay" are the keys can be parsed. the "params", "lr" and "weight_decay" are the keys can be parsed.
...@@ -96,6 +104,8 @@ class Optimizer(Cell): ...@@ -96,6 +104,8 @@ class Optimizer(Cell):
self.is_group = False self.is_group = False
self.is_group_lr = False self.is_group_lr = False
self.loss_scale = loss_scale self.loss_scale = loss_scale
if isinstance(learning_rate, int):
learning_rate = float(learning_rate)
if isinstance(learning_rate, float): if isinstance(learning_rate, float):
self.dynamic_lr = False self.dynamic_lr = False
self.gather = None self.gather = None
......
...@@ -106,22 +106,26 @@ class SGD(Optimizer): ...@@ -106,22 +106,26 @@ class SGD(Optimizer):
super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale) super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, int):
momentum = float(momentum)
if not isinstance(momentum, float): if not isinstance(momentum, float):
raise TypeError("momentum should be float number!") raise TypeError("momentum should be float number!")
if isinstance(momentum, float) and momentum < 0.0: if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
if not isinstance(dampening, float):
raise TypeError("dampening should be float number")
if isinstance(dampening, int): if isinstance(dampening, int):
dampening = float(dampening) dampening = float(dampening)
if not isinstance(dampening, float):
raise TypeError("dampening should be float number")
if dampening < 0.0: if dampening < 0.0:
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
self.dampening = dampening self.dampening = dampening
if isinstance(weight_decay, int):
weight_decay = float(weight_decay)
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
self.nesterov = nesterov self.nesterov = nesterov
......
...@@ -2591,8 +2591,7 @@ class Adam(PrimitiveWithInfer): ...@@ -2591,8 +2591,7 @@ class Adam(PrimitiveWithInfer):
Mean square gradients, has the same type as `var`. Mean square gradients, has the same type as `var`.
- **beta1_power** (float) - :math:`beta_1^t` in the updating formula. - **beta1_power** (float) - :math:`beta_1^t` in the updating formula.
- **beta2_power** (float) - :math:`beta_2^t` in the updating formula. - **beta2_power** (float) - :math:`beta_2^t` in the updating formula.
- **lr** (Union[float, Tensor, Iterable]) - :math:`l` in the updating formula. - **lr** (float) - :math:`l` in the updating formula.
Iterable type is used for the dynamic learning rate.
- **beta1** (float) - The exponential decay rate for the 1st moment estimates. - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimates. - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
- **epsilon** (float) - Term added to the denominator to improve numerical stability. - **epsilon** (float) - Term added to the denominator to improve numerical stability.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册