From 177e18f3f46fe3fe5b14c149c23a16dd44a25dd0 Mon Sep 17 00:00:00 2001
From: simson <526422051@qq.com>
Date: Wed, 15 Jul 2020 16:45:48 +0800
Subject: [PATCH] modify the limit of loss scale

---
 mindspore/nn/optim/optimizer.py                     | 9 +++++----
 tests/ut/python/parallel/test_loss_and_optimizer.py | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index cdf1565f3..4364e8c5a 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -79,9 +79,10 @@ class Optimizer(Cell):
             the order will be followed in optimizer.
             There are no other keys in the `dict` and the parameters which in the value of 'order_params'
             should be in one of group parameters.
-        weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0.
+        weight_decay (float): A floating point value for the weight decay. It should be not less than 0 and not
+            greater than 1.
             If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
+        loss_scale (float): A floating point value for the loss scale. It should be not less than 1. If the
             type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

     Raises:
@@ -103,12 +104,12 @@ class Optimizer(Cell):
         if isinstance(loss_scale, int):
             loss_scale = float(loss_scale)
         validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
-        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
+        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)

         if isinstance(weight_decay, int):
             weight_decay = float(weight_decay)
         validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name)
-        validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name)

         self.is_group = False
         self.is_group_lr = False
diff --git a/tests/ut/python/parallel/test_loss_and_optimizer.py b/tests/ut/python/parallel/test_loss_and_optimizer.py
index 91be7682a..615f058dc 100644
--- a/tests/ut/python/parallel/test_loss_and_optimizer.py
+++ b/tests/ut/python/parallel/test_loss_and_optimizer.py
@@ -98,7 +98,7 @@ def test_momentum_with_loss_scale():

     net = Net(strategy1, strategy2, weight)

-    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=0.5)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1.0)

     net_with_loss = NetWithLoss(net, strategy3)

@@ -169,7 +169,7 @@ def test_momentum_with_loss_scale_and_dynamic_lr():

     net = Net(strategy1, strategy2, weight)
     lr = Tensor(np.ones([6]), dtype=ms.float32)
-    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=0.5)
+    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=1.0)

     net_with_loss = NetWithLoss(net, strategy3)

--
GitLab
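
For reference, a minimal standalone sketch of the tightened range checks this patch introduces. The helper names check_loss_scale and check_weight_decay are illustrative only, not MindSpore API; in the patch itself the equivalent work is done by validator.check_value_type and validator.check_number_range with Rel.INC_LEFT / Rel.INC_BOTH.

    # Illustrative sketch of the patched bounds; not MindSpore code.
    def check_loss_scale(loss_scale):
        """Rel.INC_LEFT over [1.0, inf): 1.0 itself is allowed, smaller values are rejected."""
        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)  # mirrors the int-to-float conversion in the patch
        if not isinstance(loss_scale, float):
            raise TypeError("loss_scale should be a float, but got {}".format(type(loss_scale)))
        if not loss_scale >= 1.0:
            raise ValueError("loss_scale should be >= 1.0, but got {}".format(loss_scale))
        return loss_scale

    def check_weight_decay(weight_decay):
        """Rel.INC_BOTH over [0.0, 1.0]: both endpoints are allowed."""
        if isinstance(weight_decay, int):
            weight_decay = float(weight_decay)
        if not isinstance(weight_decay, float):
            raise TypeError("weight_decay should be a float, but got {}".format(type(weight_decay)))
        if not 0.0 <= weight_decay <= 1.0:
            raise ValueError("weight_decay should be in [0.0, 1.0], but got {}".format(weight_decay))
        return weight_decay

    if __name__ == "__main__":
        assert check_loss_scale(1) == 1.0      # int input is converted; 1.0 is the new minimum
        assert check_weight_decay(0.5) == 0.5
        for bad in (0.5, 0.0):                 # loss scales accepted before this patch, rejected after
            try:
                check_loss_scale(bad)
            except ValueError:
                pass

The new lower bound presumably reflects how loss scaling is used in mixed-precision training: the loss is multiplied by loss_scale to keep small gradients representable, so a value below 1 would shrink the loss instead of guarding against underflow. This is also why the tests above move from loss_scale=0.5 to loss_scale=1.0.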