diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 6142b126cd81be9b05a2607851f74c851114d120..ef96763323e201183d3c4af133df7023835ca207 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -103,9 +103,9 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) -@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", +@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1, +def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2): """Apply adam optimizer to the weight parameter using Tensor.""" success = True @@ -136,9 +136,27 @@ class Adam(Optimizer): `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, :math:`\epsilon` represents `eps`. + Note: + The Adam optimizer supports separating parameter groups. Different parameter groups can set different + `learning_rate` and `weight_decay`. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be + applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr" and "weight_decay" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -161,8 +179,6 @@ class Adam(Optimizer): weight_decay (float): Weight decay (L2 penalty). Default: 0.0. loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default: 1.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 
@@ -172,15 +188,26 @@ class Adam(Optimizer): Examples: >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> #1) All parameters use the same learning rate and weight decay >>> optim = nn.Adam(params=net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, + >>> {'params': no_conv_params}] + >>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 + >>> # the no_conv_params's parameters don't set learning rate and weight decay. So they will use a + >>> # learning rate of 0.1 and a weight decay of 0.0. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, - use_nesterov=False, weight_decay=0.0, loss_scale=1.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) + use_nesterov=False, weight_decay=0.0, loss_scale=1.0): + super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale) _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) @@ -216,10 +243,14 @@ class Adam(Optimizer): self.beta1_power = beta1_power beta2_power = self.beta2_power * self.beta2 self.beta2_power = beta2_power - success = self.hyper_map(F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power, self.beta1, - self.beta2, self.eps), - gradients, params, moment1, moment2) - + if self.is_group: + success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, + self.beta2, self.eps), + lr, gradients, params, moment1, moment2) + else: + success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, + self.beta2, self.eps, lr), + gradients, params, moment1, moment2) return success @@ -262,6 +293,8 @@ class AdamWeightDecay(Optimizer): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): super(AdamWeightDecay, self).__init__(learning_rate, params) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) @@ -330,6 +363,8 @@ class AdamWeightDecayDynamicLR(Optimizer): decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name, warmup_steps=0): super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name) # turn them to scalar when
me support scalar/tensor mix operations diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index ccc1b3f10be9d688ff45ebc910ff9e878df16d1d..33edafa4e2d13fb7093b6767871955c8f36b2042 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -96,7 +96,8 @@ class FTRL(Optimizer): def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, use_locking=False, loss_scale=1.0, weight_decay=0.0): super(FTRL, self).__init__(learning_rate, params) - + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index 97a81a590b2c9e4d7060a57077fe66caccaece1f..b4d478f52ab38be605719c9f9dfa124dcd7b3240 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -183,6 +183,8 @@ class Lamb(Optimizer): decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name): super(Lamb, self).__init__(start_learning_rate, params) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, power, beta1, beta2, eps, weight_decay, self.cls_name) diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 67de590c5feb12952ffe51b42ef262d1a40ec856..7cfbf1118333c60cb8d50611b018e6f01c8b4346 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -23,7 +23,7 @@ momentum_opt = C.MultitypeFuncGraph("momentum_opt") @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): +def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment): """Apply momentum optimizer to the weight parameter using Tensor.""" success = True success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) @@ -36,9 +36,27 @@ class Momentum(Optimizer): Refer to the paper on the importance of initialization and momentum in deep learning for more details. + Note: + The Momentum optimizer supports separating parameter groups. Different parameter groups can set different + `learning_rate` and `weight_decay`. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be + applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` - should be class mindspore.Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr" and "weight_decay" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. 
If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -49,8 +67,6 @@ class Momentum(Optimizer): momentum (float): Hyperparameter of type float, means momentum for the moving average. weight_decay (float): Weight decay (L2 penalty). Default: 0.0. loss_scale (float): A floating point value for the loss scale. Default: 1.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'beta' not in x.name and 'gamma' not in x.name. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. @@ -63,13 +79,24 @@ class Momentum(Optimizer): Examples: >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> #1) All parameters use the same learning rate and weight decay >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, + >>> {'params': no_conv_params}] + >>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0) + >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 + >>> # the no_conv_params's parameters don't set learning rate and weight decay. So they will use a + >>> # learning rate of 0.1 and a weight decay of 0.0.
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) """ - def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) + def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0): + super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale) if isinstance(momentum, float) and momentum < 0.0: raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") @@ -84,5 +111,8 @@ class Momentum(Optimizer): gradients = self.decay_weight(gradients) gradients = self.scale_grad(gradients) lr = self.get_lr() - success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) + if self.is_group: + success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments) + else: + success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) return success diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 34abc2b1c2ab1eef765c0468006d60b150f9451d..671e92de3af545c81d601f15e4d17620aeb46a8d 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -28,7 +28,6 @@ from mindspore._checkparam import Rel from mindspore.common.tensor import Tensor from mindspore import log as logger - __all__ = ['Optimizer'] @@ -42,68 +41,96 @@ class Optimizer(Cell): This class defines the API to add Ops to train a model. Never use this class directly, but instead instantiate one of its subclasses. + Some optimizers support separating parameter groups. Different parameter groups can set different + `learning_rate` and `weight_decay`. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be + applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + Args: learning_rate (float): A floating point value for the learning rate. Should be greater than 0. - parameters (list): A list of parameter, which will be updated. The element in `parameters` - should be class mindspore.Parameter. + parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be + updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`, + the "params", "lr" and "weight_decay" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0. - If the type of `weight_decay` input is int, it will be convertd to float. Default: 0.0. + If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0. 
loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the - type of `loss_scale` input is int, it will be convertd to float. Default: 1.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda - x: 'beta' not in x.name and 'gamma' not in x.name. + type of `loss_scale` input is int, it will be converted to float. Default: 1.0. Raises: ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1. TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable. """ - def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): + def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0): super(Optimizer, self).__init__(auto_prefix=False) + if parameters and not isinstance(parameters, list): + parameters = list(parameters) + + if not parameters: + raise ValueError("Optimizer got an empty parameter list.") + + if not isinstance(parameters[0], (dict, Parameter)): + raise ValueError("Only a list of Parameter or dict can be supported.") + + if isinstance(loss_scale, int): + loss_scale = float(loss_scale) + validator.check_value_type("loss_scale", loss_scale, [float], None) + validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None) + + if isinstance(weight_decay, int): + weight_decay = float(weight_decay) + validator.check_value_type("weight_decay", weight_decay, [float], None) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) + + self.is_group = False + self.loss_scale = loss_scale if isinstance(learning_rate, float): self.dynamic_lr = False self.gather = None self.assignadd = None self.global_step = None - validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - learning_rate = Tensor(learning_rate, mstype.float32) + self.scalar_lr = learning_rate else: self.dynamic_lr = True self.gather = P.GatherV2() self.assignadd = P.AssignAdd() self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') - if isinstance(learning_rate, Iterable): - learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32)) - elif isinstance(learning_rate, Tensor): - if learning_rate.dim() > 1: - raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," - f"but got {learning_rate.dim()}.") - if learning_rate.dim() == 1 and learning_rate.size() < 2: - logger.warning("If want to use the dynamic learning rate, please make sure that the number " - "of elements in the list, tuple or tensor passed is greater than 1.") - else: - raise TypeError("Learning rate should be float, Tensor or Iterable.") - - if isinstance(weight_decay, int): - weight_decay = float(weight_decay) - validator.check_value_type("weight_decay", weight_decay, [float], None) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) - - if isinstance(loss_scale, int): - loss_scale = float(loss_scale) - validator.check_value_type("loss_scale", loss_scale, [float], None) - validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None) - - self.loss_scale = loss_scale - self.learning_rate = Parameter(learning_rate, name="learning_rate") - self.parameters = ParameterTuple(parameters) + self.scalar_lr = None + + learning_rate = 
self._get_single_lr(learning_rate) + if isinstance(parameters[0], dict): + self.is_group = True + self.params = [] + self.group_lr = [] + self.group_weight_decay = [] + self._init_group_params(parameters, learning_rate, weight_decay) + + if self.is_group: + self.learning_rate = ParameterTuple(self.group_lr) + self.parameters = ParameterTuple(self.params) + self.weight_decay = tuple(self.group_weight_decay) + decay_filter = lambda x: x > 0 + self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) + else: + self.learning_rate = Parameter(learning_rate, name="learning_rate") + self.parameters = ParameterTuple(parameters) + self.weight_decay = weight_decay * loss_scale + decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name + self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.reciprocal_scale = 1.0 / loss_scale - self.weight_decay = weight_decay * loss_scale - self.decay_flags = tuple(decay_filter(x) for x in self.parameters) - - if not self.parameters: - raise ValueError("optimizer got an empty parameter list.") + self.exec_weight_decay = any(self.decay_flags) + self.param_length = len(self.parameters) def decay_weight(self, gradients): """ @@ -118,9 +145,15 @@ class Optimizer(Cell): Returns: tuple[Tensor], The gradients after weight decay. """ - if self.weight_decay > 0: - params = self.parameters - gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients) + params = self.parameters + if self.is_group: + if self.exec_weight_decay: + gradients = self.hyper_map(F.partial(apply_decay), self.weight_decay, self.decay_flags, + params, gradients) + else: + if self.weight_decay > 0: + gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, + params, gradients) return gradients @@ -144,6 +177,83 @@ class Optimizer(Cell): return gradients + def _get_single_lr(self, learning_rate): + """Get learning rate in Tensor type.""" + if isinstance(learning_rate, float): + validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + lr = Tensor(learning_rate, mstype.float32) + elif isinstance(learning_rate, Iterable): + lr = Tensor(np.array(list(learning_rate)).astype(np.float32)) + elif isinstance(learning_rate, Tensor): + if learning_rate.dim() > 1: + raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," + f"but got {learning_rate.dim()}.") + if learning_rate.dim() == 1 and learning_rate.size() < 2: + logger.warning("If want to use the dynamic learning rate, please make sure that the number " + "of elements in the list, tuple or tensor passed is greater than 1.") + lr = learning_rate + else: + raise TypeError("Learning rate should be float, Tensor or Iterable.") + return lr + + def _init_group_params(self, parameters, learning_rate, weight_decay): + """Init learning rate or weight decay in group params.""" + origin_dynamic_lr = self.dynamic_lr + if self.dynamic_lr: + dynamic_lr_length = learning_rate.size() + else: + dynamic_lr_length = 0 + + for group_param in parameters: + lr_length = dynamic_lr_length + if 'lr' in group_param.keys(): + self._get_single_lr(group_param['lr']) + if isinstance(group_param['lr'], Iterable): + lr_length = len(group_param['lr']) + self.dynamic_lr = True + elif isinstance(group_param['lr'], Tensor): + lr_length = group_param['lr'].size() + self.dynamic_lr = True + if dynamic_lr_length not in (lr_length, 0): + raise ValueError("The dynamic learning rate in group should be the same size.") + 
dynamic_lr_length = lr_length + + if self.dynamic_lr and not origin_dynamic_lr: + self.gather = P.GatherV2() + self.assignadd = P.AssignAdd() + self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') + + params_store = [] + for group_param in parameters: + self.params += group_param['params'] + if 'lr' in group_param.keys(): + params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) + + if self.dynamic_lr and not params_dynamic_lr: + lr = Tensor(np.array([group_param['lr']] * dynamic_lr_length).astype(np.float32)) + else: + lr = self._get_single_lr(group_param['lr']) + else: + if self.dynamic_lr and not origin_dynamic_lr: + lr = Tensor(np.array([self.scalar_lr] * dynamic_lr_length).astype(np.float32)) + else: + lr = learning_rate + + if 'weight_decay' in group_param.keys(): + validator.check_float_legal_value('weight_decay', group_param['weight_decay'], None) + validator.check_number_range('weight_decay', group_param['weight_decay'], 0.0, float("inf"), + Rel.INC_LEFT, self.cls_name) + weight_decay_ = group_param['weight_decay'] * self.loss_scale + else: + weight_decay_ = weight_decay * self.loss_scale + + for param in group_param['params']: + if param in params_store: + raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") + params_store.append(param) + self.group_lr.append(Parameter(lr, name="lr_" + param.name)) + self.group_weight_decay.append(weight_decay_) + def get_lr(self): """ Get the learning rate of current step. @@ -151,11 +261,20 @@ class Optimizer(Cell): Returns: float, the learning rate of current step. """ - lr = self.learning_rate - if self.dynamic_lr: - lr = self.gather(self.learning_rate, self.global_step, 0) - F.control_depend(lr, self.assignadd(self.global_step, 1)) + if self.is_group: + lr = self.learning_rate + if self.dynamic_lr: + lr = () + for i in range(self.param_length): + current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) + lr += (current_dynamic_lr,) + F.control_depend(lr, self.assignadd(self.global_step, 1)) + else: + lr = self.learning_rate + if self.dynamic_lr: + lr = self.gather(self.learning_rate, self.global_step, 0) + F.control_depend(lr, self.assignadd(self.global_step, 1)) return lr def construct(self, *hyper_params): diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index b1271587b4ac98a748338ee9b14a25c20cd9714a..b96d9499b2a6599092caafe21c641972a05c35a4 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -22,17 +22,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") -@rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): +@rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad): """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" success = True success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) return success -@centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", +@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def 
_centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): +def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad): """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" success = True success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) @@ -44,6 +44,13 @@ class RMSProp(Optimizer): Implements Root Mean Squared Propagation (RMSProp) algorithm. Note: + The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different + `learning_rate` and `weight_decay`. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be + applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + Update `params` according to the RMSProp algorithm. The equation is as follows: @@ -84,8 +91,18 @@ class RMSProp(Optimizer): represents `gradients`. Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` - should be class mindspore.Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr" and "weight_decay" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -95,15 +112,13 @@ class RMSProp(Optimizer): Other cases are not supported. Default: 0.1. decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or - greater than 0.Default: 0.0. + greater than 0. Default: 0.0. epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default: 1e-10. use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False. centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False. loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0. weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0. - decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: - lambda x: 'beta' not in x.name and 'gamma' not in x.name. Inputs: - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 
@@ -113,14 +128,25 @@ class RMSProp(Optimizer): Examples: >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.RMSProp(params=net.trainable_params(), learning_rate=lr) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, + >>> {'params': no_conv_params}] + >>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 + >>> # the no_conv_params's parameters don't set learning rate and weight decay. So they will use a + >>> # learning rate of 0.1 and a weight decay of 0.0. + >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() - >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr) - >>> model = Model(net, loss, opt) + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, - use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): - super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) + use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): + super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) validator.check_value_type("decay", decay, [float], self.cls_name) validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_value_type("momentum", momentum, [float], self.cls_name) @@ -150,9 +176,18 @@ class RMSProp(Optimizer): gradients = self.scale_grad(gradients) lr = self.get_lr() if self.centered: - success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon, - self.momentum), params, self.mg, self.ms, self.moment, gradients) + if self.is_group: + success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum), lr, params, self.mg, self.ms, self.moment, gradients) + else: + success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum, lr), params, self.mg, self.ms, self.moment, gradients) + else: - success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, - self.momentum), params, self.ms, self.moment, gradients) + if self.is_group: + success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum), lr, params, self.ms, self.moment, gradients) + else: + success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, + self.momentum, lr), params, self.ms, self.moment, gradients) return success diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 388fe5db4780a146cbcae6012cc46ca10ff50bd6..0db58af8556210e60da296368e2c5ab055dbd690 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -24,7 +24,7 @@ sgd_opt = C.MultitypeFuncGraph("sgd_opt") @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat): +def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat): """Apply sgd optimizer to
the weight parameter using Tensor.""" success = True success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) @@ -39,9 +39,27 @@ class SGD(Optimizer): Nesterov momentum is based on the formula from paper `On the importance of initialization and momentum in deep learning `_. + Note: + The SGD optimizer supports separating parameter groups. Different parameter groups can set different + `learning_rate` and `weight_decay`. + + When separating parameter groups, the weight decay in each group will be applied on the parameters if the + value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be + applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + Args: - params (list[Parameter]): A list of parameter, which will be updated. The element in `params` - should be class mindspore.Parameter. + params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, + the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", + "lr" and "weight_decay" are the keys can be parsed. + + - params: Required. The value should be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in the API will be used. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the API will be used. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -67,9 +85,21 @@ class SGD(Optimizer): Examples: >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> #1) All parameters use the same learning rate and weight decay >>> optim = nn.SGD(params=net.trainable_params()) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, + >>> {'params': no_conv_params}] + >>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 + >>> # the no_conv_params's parameters don't set learning rate and weight decay. So they will use a + >>> # learning rate of 0.1 and a weight decay of 0.0.
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) """ def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, loss_scale=1.0): @@ -109,5 +139,8 @@ class SGD(Optimizer): gradients = self.decay_weight(gradients) gradients = self.scale_grad(gradients) lr = self.get_lr() - success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat) + if self.is_group: + success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat) + else: + success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat) return success diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 60718ec2b112ecde40f8fa6306643d9bbd1b5818..499d85b34b976938b207666c049231476373ae97 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -167,7 +167,7 @@ class TrainOneStepCell(Cell): super(TrainOneStepCell, self).__init__(auto_prefix=False) self.network = network self.network.add_flags(defer_inline=True) - self.weights = ParameterTuple(network.trainable_params()) + self.weights = optimizer.parameters self.optimizer = optimizer self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.sens = sens diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index d9321b1d2667c6dac871c5612d7228d19233835f..269f276376e6d17a7828e9807366c1dbc64dd573 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -50,7 +50,7 @@ class NetWithoutWeight(nn.Cell): def test_adamwithoutparam(): net = NetWithoutWeight() net.set_train() - with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"): + with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): AdamWeightDecay(net.trainable_params(), learning_rate=0.1) @@ -104,5 +104,5 @@ def test_AdamWeightDecayDynamicLR(): def test_adam_mindspore_flatten(): net = nn.Flatten() - with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"): + with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): AdamWeightDecay(net.get_parameters()) diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 89fb1d812b413691033b909e806da007d418cf2e..9f1ec9a36f0be81afbbe4539ab07e6afb1245238 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -69,19 +69,19 @@ class TestSGD(): class TestNullParam(): """ TestNullParam definition """ def test_optim_init(self): - with pytest.raises(TypeError): + with pytest.raises(ValueError): Optimizer(0.1, None) def test_AdamWightDecay_init(self): - with pytest.raises(TypeError): + with pytest.raises(ValueError): AdamWeightDecay(None) def test_AdamWeightDecayDynamicLR_init(self): - with pytest.raises(TypeError): + with pytest.raises(ValueError): AdamWeightDecayDynamicLR(None, 10) def test_Sgd_init(self): - with pytest.raises(TypeError): + with pytest.raises(ValueError): SGD(None) class TestUnsupportParam(): diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd98990fa8188e60557c9270fa96639ad51a1de --- /dev/null +++ 
b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py @@ -0,0 +1,210 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore.nn.optim import Momentum, SGD, RMSProp, Adam +from mindspore import context +from mindspore.common.api import _executor +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +from mindspore.nn import TrainOneStepCell, WithLossCell + +context.set_context(mode=context.GRAPH_MODE) + + +class LeNet5(nn.Cell): + """ LeNet5 definition """ + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.fc1 = nn.Dense(16 * 5 * 5, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = P.Flatten() + + def construct(self, x): + x = self.max_pool2d(self.relu(self.conv1(x))) + x = self.max_pool2d(self.relu(self.conv2(x))) + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def test_group_lr(): + inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([1, 10]).astype(np.float32)) + + net = LeNet5() + conv_lr = 0.8 + default_lr = 0.1 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': no_conv_params}] + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + + opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) + assert opt.is_group is True + assert opt.dynamic_lr is False + for lr, param in zip(opt.learning_rate, opt.parameters): + if param in conv_params: + assert lr.data == Tensor(conv_lr, mstype.float32) + else: + assert lr.data == Tensor(default_lr, mstype.float32) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, opt) + _executor.compile(train_network, inputs, label) + + +def test_group_dynamic_1(): + inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([1, 10]).astype(np.float32)) + + net = LeNet5() + conv_lr = 0.8 + default_lr = (0.1, 0.2, 0.3) + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': no_conv_params}] + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + + opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) + assert opt.is_group is True + assert opt.dynamic_lr is True + for lr, 
param in zip(opt.learning_rate, opt.parameters): + if param in conv_params: + assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32)) + else: + assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32)) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, opt) + _executor.compile(train_network, inputs, label) + + +def test_group_dynamic_2(): + inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([1, 10]).astype(np.float32)) + + net = LeNet5() + conv_lr = (0.1, 0.2, 0.3) + default_lr = 0.8 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': no_conv_params}] + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + + opt = RMSProp(group_params, learning_rate=default_lr) + assert opt.is_group is True + assert opt.dynamic_lr is True + for lr, param in zip(opt.learning_rate, opt.parameters): + if param in conv_params: + assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)) + else: + assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, opt) + _executor.compile(train_network, inputs, label) + + +def test_group_dynamic_no_same_size(): + net = LeNet5() + conv_lr = (0.1, 0.2, 0.3) + default_lr = (0.1, 0.2) + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': no_conv_params}] + with pytest.raises(ValueError): + Momentum(group_params, learning_rate=default_lr, momentum=0.9) + + +def test_group_not_float_lr(): + net = LeNet5() + conv_lr = 1 + default_lr = 0.3 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': no_conv_params}] + with pytest.raises(TypeError): + Momentum(group_params, learning_rate=default_lr, momentum=0.9) + + +def test_group_not_float_weight_decay(): + net = LeNet5() + conv_weight_decay = 1 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, + {'params': no_conv_params}] + with pytest.raises(TypeError): + Momentum(group_params, learning_rate=0.1, momentum=0.9) + + +def test_weight_decay(): + inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([1, 10]).astype(np.float32)) + + net = LeNet5() + conv_weight_decay = 0.8 + default_weight_decay = 0.0 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, + {'params': no_conv_params}] + net.set_train() + loss = nn.SoftmaxCrossEntropyWithLogits() + + opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay) + assert opt.is_group is True + for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, 
opt.parameters): + if param in conv_params: + assert weight_decay == conv_weight_decay + assert decay_flags is True + else: + assert weight_decay == default_weight_decay + assert decay_flags is False + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, opt) + _executor.compile(train_network, inputs, label) + + +def test_group_repeat_param(): + net = LeNet5() + conv_lr = 0.1 + default_lr = 0.3 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'params': conv_params, 'lr': default_lr}, + {'params': no_conv_params}] + with pytest.raises(RuntimeError): + Adam(group_params, learning_rate=default_lr)
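Reviewer note: the core of this change is `Optimizer._init_group_params`, which expands the per-group 'lr' and 'weight_decay' settings into per-parameter lists that the group-aware `hyper_map` calls then consume. Below is a minimal, self-contained sketch of that resolution logic, assuming plain Python floats and string names instead of MindSpore `Parameter`/`Tensor` objects; `resolve_group_params` and the parameter names are illustrative only and do not appear in the patch.

    # Illustrative sketch: mirrors the per-parameter expansion performed by
    # _init_group_params, using plain Python values instead of MindSpore types.
    def resolve_group_params(groups, default_lr, default_weight_decay):
        """Flatten group dicts into per-parameter (name, lr, weight_decay) tuples."""
        seen = set()
        resolved = []
        for group in groups:
            lr = group.get('lr', default_lr)
            weight_decay = group.get('weight_decay', default_weight_decay)
            for name in group['params']:
                if name in seen:
                    # the patch likewise raises RuntimeError for a parameter repeated across groups
                    raise RuntimeError(f"The {name} parameter has appeared in parameter groups.")
                seen.add(name)
                resolved.append((name, lr, weight_decay))
        return resolved

    if __name__ == '__main__':
        groups = [
            {'params': ['conv1.weight', 'conv2.weight'], 'lr': 0.01, 'weight_decay': 0.01},
            {'params': ['fc1.weight', 'fc1.bias']},
        ]
        for name, lr, wd in resolve_group_params(groups, default_lr=0.1, default_weight_decay=0.0):
            # conv parameters resolve to lr=0.01 / weight_decay=0.01; the rest fall back to the defaults
            print(name, lr, wd)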