提交 74bbfa3c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3095 modify the limit of loss scale

Merge pull request !3095 from Simson/push-to-opensource
......@@ -79,9 +79,10 @@ class Optimizer(Cell):
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' should be in one of group parameters.
weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
weight_decay (float): A floating point value for the weight decay. It must be no less than 0 and no
greater than 1.
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
loss_scale (float): A floating point value for the loss scale. It must be no less than 1. If the
type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
Raises:
......@@ -103,12 +104,12 @@ class Optimizer(Cell):
if isinstance(loss_scale, int):
loss_scale = float(loss_scale)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
if isinstance(weight_decay, int):
weight_decay = float(weight_decay)
validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name)
self.is_group = False
self.is_group_lr = False
......
......@@ -98,7 +98,7 @@ def test_momentum_with_loss_scale():
net = Net(strategy1, strategy2, weight)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=0.5)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1.0)
net_with_loss = NetWithLoss(net, strategy3)
......@@ -169,7 +169,7 @@ def test_momentum_with_loss_scale_and_dynamic_lr():
net = Net(strategy1, strategy2, weight)
lr = Tensor(np.ones([6]), dtype=ms.float32)
optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=0.5)
optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=1.0)
net_with_loss = NetWithLoss(net, strategy3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册