diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 6f7f60a2166a7f23f35e21e3b7c6bdd0ae2354d2..d931e5a52f34b62a5d5cc1d4e2da860abfe5d048 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -257,6 +257,7 @@ class Optimizer(Cell):
                     logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
 
             for param in group_param['params']:
+                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
                 params_store.append(param)
@@ -286,6 +287,36 @@ class Optimizer(Cell):
                 F.control_depend(lr, self.assignadd(self.global_step, 1))
         return lr
 
+    def get_lr_parameter(self, param):
+        """
+        Get the learning rate of the parameter.
+
+        Args:
+            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.
+
+        Returns:
+            Parameter, a single `Parameter` or a `list[Parameter]` according to the input type.
+        """
+        if not isinstance(param, (Parameter, list)):
+            raise TypeError(f"The 'param' only supports 'Parameter' or 'list' type.")
+
+        if isinstance(param, list):
+            lr = []
+            for p in param:
+                validator.check_value_type("parameter", p, [Parameter], self.cls_name)
+                if self.is_group_lr:
+                    index = self.parameters.index(p)
+                    lr.append(self.learning_rate[index])
+                else:
+                    lr.append(self.learning_rate)
+        else:
+            if self.is_group_lr:
+                index = self.parameters.index(param)
+                lr = self.learning_rate[index]
+            else:
+                lr = self.learning_rate
+        return lr
+
     def construct(self, *hyper_params):
         raise NotImplementedError
 
diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
index 24ee9254a99136ac367b706c3ba2b7b82ddb6ea5..675582048840ab5fe4ddf06be6faa886878a8c0d 100644
--- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
+++ b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
@@ -210,3 +210,41 @@ def test_group_repeat_param():
                     {'params': no_conv_params}]
     with pytest.raises(RuntimeError):
         Adam(group_params, learning_rate=default_lr)
+
+
+def test_get_lr_parameter_with_group():
+    net = LeNet5()
+    conv_lr = 0.1
+    default_lr = 0.3
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'lr': conv_lr},
+                    {'params': no_conv_params, 'lr': default_lr}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is True
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == 'lr_' + param.name
+
+    lr_list = opt.get_lr_parameter(conv_params)
+    for lr, param in zip(lr_list, conv_params):
+        assert lr.name == 'lr_' + param.name
+
+
+def test_get_lr_parameter_with_no_group():
+    net = LeNet5()
+    conv_weight_decay = 0.8
+
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
+                    {'params': no_conv_params}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is False
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == opt.learning_rate.name
+
+    params_error = [1, 2, 3]
+    with pytest.raises(TypeError):
+        opt.get_lr_parameter(params_error)
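
Not part of the diff: a minimal usage sketch of the new get_lr_parameter API, mirroring the grouped-parameter setup exercised by the tests above. SimpleNet is a hypothetical stand-in for the LeNet5 test fixture; the rest uses standard mindspore.nn layers and the SGD optimizer.

    import mindspore.nn as nn
    from mindspore.nn.optim import SGD


    class SimpleNet(nn.Cell):
        """Hypothetical stand-in for the LeNet5 test network: one conv and one dense layer."""
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.conv = nn.Conv2d(1, 6, 5, pad_mode='valid')
            self.flatten = nn.Flatten()
            self.fc = nn.Dense(6 * 28 * 28, 10)

        def construct(self, x):
            return self.fc(self.flatten(self.conv(x)))


    net = SimpleNet()
    conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
    other_params = [p for p in net.trainable_params() if 'conv' not in p.name]

    # Two parameter groups with different learning rates, as in test_get_lr_parameter_with_group.
    group_params = [{'params': conv_params, 'lr': 0.1},
                    {'params': other_params, 'lr': 0.3}]
    opt = SGD(group_params)

    # A single Parameter returns the learning-rate Parameter bound to it;
    # a list of Parameters returns the matching list. Any other input type raises TypeError.
    lr_first_conv = opt.get_lr_parameter(conv_params[0])
    lr_all_conv = opt.get_lr_parameter(conv_params)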