diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index ef96763323e201183d3c4af133df7023835ca207..2138aed7418687fd8474d68f2b6b7ed56f88527c 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -243,7 +243,7 @@ class Adam(Optimizer):
         self.beta1_power = beta1_power
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, self.beta2,
                                                self.eps), lr, gradients, params, moment1, moment2)
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 7cfbf1118333c60cb8d50611b018e6f01c8b4346..166e8ae29610bef4bd7b19aebc4fe0353881e9d2 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -111,7 +111,7 @@ class Momentum(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
         else:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 671e92de3af545c81d601f15e4d17620aeb46a8d..6f7f60a2166a7f23f35e21e3b7c6bdd0ae2354d2 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -94,6 +94,7 @@ class Optimizer(Cell):
         validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
 
         self.is_group = False
+        self.is_group_lr = False
         self.loss_scale = loss_scale
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
@@ -116,14 +117,17 @@
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
 
-        if self.is_group:
+        if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
+        else:
+            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+
+        if self.is_group:
             self.parameters = ParameterTuple(self.params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
@@ -207,6 +211,7 @@
         for group_param in parameters:
             lr_length = dynamic_lr_length
             if 'lr' in group_param.keys():
+                self.is_group_lr = True
                 self._get_single_lr(group_param['lr'])
                 if isinstance(group_param['lr'], Iterable):
                     lr_length = len(group_param['lr'])
@@ -247,6 +252,10 @@
             else:
                 weight_decay_ = weight_decay * self.loss_scale
 
+            for key in group_param.keys():
+                if key not in ('params', 'lr', 'weight_decay'):
+                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
+
             for param in group_param['params']:
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
@@ -261,7 +270,7 @@
         Returns:
             float, the learning rate of current step.
         """
-        if self.is_group:
+        if self.is_group_lr:
             lr = self.learning_rate
             if self.dynamic_lr:
                 lr = ()
diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py
index b96d9499b2a6599092caafe21c641972a05c35a4..bd4edac4d1bed04aec763487c8e0756a282b2e94 100644
--- a/mindspore/nn/optim/rmsprop.py
+++ b/mindspore/nn/optim/rmsprop.py
@@ -176,7 +176,7 @@ class RMSProp(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
             else:
@@ -184,7 +184,7 @@
                                                    self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
         else:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.ms, self.moment, gradients)
             else:
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index 0db58af8556210e60da296368e2c5ab055dbd690..c2575be6c8d9338a92b8c5802397407eddf57f19 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -139,7 +139,7 @@ class SGD(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
         else:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
index 8dd98990fa8188e60557c9270fa96639ad51a1de..24ee9254a99136ac367b706c3ba2b7b82ddb6ea5 100644
--- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
+++ b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
@@ -65,12 +65,13 @@ def test_group_lr():
 
     opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
     assert opt.is_group is True
+    assert opt.is_group_lr is True
     assert opt.dynamic_lr is False
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(conv_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
         else:
-            assert lr.data == Tensor(default_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -96,9 +97,9 @@ def test_group_dynamic_1():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
         else:
-            assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -124,9 +125,9 @@ def test_group_dynamic_2():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
         else:
-            assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -184,6 +185,7 @@ def test_weight_decay():
 
     opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
     assert opt.is_group is True
+    assert opt.is_group_lr is False
     for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
         if param in conv_params:
             assert weight_decay == conv_weight_decay
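
Usage note (not part of the patch): a minimal sketch of the grouped-parameter API whose behavior this change splits into `is_group` (any grouping) and `is_group_lr` (a group overrides 'lr'). Only the `Momentum(group_params, learning_rate=..., momentum=...)` call and the two flags mirror the tests above; the network and the 'weight' name filter are illustrative assumptions.

import mindspore.nn as nn
from mindspore.nn.optim import Momentum

# Illustrative network; any Cell whose trainable_params() can be split works the same way.
net = nn.Dense(16, 10)

# One group overrides 'lr'; the other inherits the default learning_rate passed to the optimizer.
weights = [p for p in net.trainable_params() if 'weight' in p.name]
biases = [p for p in net.trainable_params() if 'weight' not in p.name]
group_params = [{'params': weights, 'lr': 0.8},
                {'params': biases}]

opt = Momentum(group_params, learning_rate=0.1, momentum=0.9)
assert opt.is_group is True      # parameter groups are in use
assert opt.is_group_lr is True   # at least one group set its own 'lr'

If no group sets 'lr' (e.g. groups only override 'weight_decay'), is_group stays True while is_group_lr stays False, which is exactly what the test_weight_decay hunk asserts.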