From c95215bca05677d1b862233f2519ff34e2bf5d37 Mon Sep 17 00:00:00 2001
From: guohongzilong <2713219276@qq.com>
Date: Mon, 11 May 2020 17:37:37 +0800
Subject: [PATCH] separate lr groups and weight_decay groups

---
 mindspore/nn/optim/adam.py                  |  2 +-
 mindspore/nn/optim/momentum.py              |  2 +-
 mindspore/nn/optim/optimizer.py             | 15 ++++++++++++---
 mindspore/nn/optim/rmsprop.py               |  4 ++--
 mindspore/nn/optim/sgd.py                   |  2 +-
 .../test_optimize_with_parameter_groups.py  | 14 ++++++++------
 6 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index ef9676332..2138aed74 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -243,7 +243,7 @@ class Adam(Optimizer):
         self.beta1_power = beta1_power
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power,
                                                self.beta1, self.beta2, self.eps),
                                      lr, gradients, params, moment1, moment2)
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 7cfbf1118..166e8ae29 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -111,7 +111,7 @@ class Momentum(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
         else:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 671e92de3..6f7f60a21 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -94,6 +94,7 @@ class Optimizer(Cell):
         validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
 
         self.is_group = False
+        self.is_group_lr = False
         self.loss_scale = loss_scale
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
@@ -116,14 +117,17 @@
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
 
-        if self.is_group:
+        if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
+        else:
+            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+
+        if self.is_group:
             self.parameters = ParameterTuple(self.params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
@@ -207,6 +211,7 @@
         for group_param in parameters:
             lr_length = dynamic_lr_length
             if 'lr' in group_param.keys():
+                self.is_group_lr = True
                 self._get_single_lr(group_param['lr'])
                 if isinstance(group_param['lr'], Iterable):
                     lr_length = len(group_param['lr'])
@@ -247,6 +252,10 @@
             else:
                 weight_decay_ = weight_decay * self.loss_scale
 
+            for key in group_param.keys():
+                if key not in ('params', 'lr', 'weight_decay'):
+                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
+
             for param in group_param['params']:
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
@@ -261,7 +270,7 @@
         Returns:
             float, the learning rate of current step.
         """
-        if self.is_group:
+        if self.is_group_lr:
             lr = self.learning_rate
             if self.dynamic_lr:
                 lr = ()
diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py
index b96d9499b..bd4edac4d 100644
--- a/mindspore/nn/optim/rmsprop.py
+++ b/mindspore/nn/optim/rmsprop.py
@@ -176,7 +176,7 @@ class RMSProp(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
             else:
@@ -184,7 +184,7 @@
                                                    self.momentum, lr), params, self.mg, self.ms, self.moment,
                                          gradients)
         else:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, self.momentum),
                                          lr, params, self.ms, self.moment, gradients)
             else:
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index 0db58af85..c2575be6c 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -139,7 +139,7 @@ class SGD(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
         else:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
index 8dd98990f..24ee9254a 100644
--- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
+++ b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
@@ -65,12 +65,13 @@ def test_group_lr():
 
     opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
     assert opt.is_group is True
+    assert opt.is_group_lr is True
     assert opt.dynamic_lr is False
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(conv_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
         else:
-            assert lr.data == Tensor(default_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -96,9 +97,9 @@ def test_group_dynamic_1():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
         else:
-            assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -124,9 +125,9 @@ def test_group_dynamic_2():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
         else:
-            assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -184,6 +185,7 @@ def test_weight_decay():
 
     opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
    assert opt.is_group is True
+    assert opt.is_group_lr is False
     for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
         if param in conv_params:
             assert weight_decay == conv_weight_decay
-- 
GitLab
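
Note (not part of the patch): a minimal usage sketch of the grouped-parameter behaviour this change separates, assuming MindSpore's nn.Momentum optimizer and a pre-built nn.Cell named `net`; the group keys mirror the ones handled in _init_group_params above ('params', 'lr', 'weight_decay'), and any other key would now only trigger the new warning.

# Sketch only: passing parameter groups with a per-group lr and weight_decay,
# mirroring the unit tests in this patch. `net` is an assumed, existing nn.Cell.
import mindspore.nn as nn

conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))

group_params = [{'params': conv_params, 'lr': 0.02, 'weight_decay': 0.01},
                {'params': no_conv_params}]  # this group falls back to the defaults below

opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)
# After this patch: opt.is_group is True because groups were supplied, while
# opt.is_group_lr is True only because one group overrides 'lr'; a group that
# overrides only 'weight_decay' would leave is_group_lr False.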