Commit e70b2f54 authored by guohongzilong

add optimizer.get_lr_parameter() method

Parent fd72534a
......@@ -257,6 +257,7 @@ class Optimizer(Cell):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param in params_store:
                    raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
                params_store.append(param)
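The added `validator.check_value_type` call rejects non-`Parameter` entries in a group before the duplicate check runs. A minimal sketch of the two failure modes, assuming `LeNet5` and `Adam` are available as in the test file further below (the duplicate case is what `test_group_repeat_param` already exercises):

```python
import pytest
from mindspore.nn import Adam  # import path assumed

net = LeNet5()  # test network from the test file below
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))

# A non-Parameter entry in 'params' is now rejected up front by the type check.
with pytest.raises(TypeError):
    Adam([{'params': conv_params + ['not_a_parameter']}], learning_rate=0.1)

# Listing the same Parameter in two groups still raises RuntimeError, as before.
with pytest.raises(RuntimeError):
    Adam([{'params': conv_params}, {'params': conv_params}], learning_rate=0.1)
```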
......@@ -286,6 +287,36 @@ class Optimizer(Cell):
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        return lr
    def get_lr_parameter(self, param):
        """
        Get the learning rate of the given parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, single `Parameter` or `list[Parameter]` according to the input type.
        """
        if not isinstance(param, (Parameter, list)):
            raise TypeError("The 'param' only supports 'Parameter' or 'list' type.")

        if isinstance(param, list):
            lr = []
            for p in param:
                validator.check_value_type("parameter", p, [Parameter], self.cls_name)
                if self.is_group_lr:
                    index = self.parameters.index(p)
                    lr.append(self.learning_rate[index])
                else:
                    lr.append(self.learning_rate)
        else:
            if self.is_group_lr:
                index = self.parameters.index(param)
                lr = self.learning_rate[index]
            else:
                lr = self.learning_rate
        return lr

    def construct(self, *hyper_params):
        raise NotImplementedError
......
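For reference, a short usage sketch of the new method, following the same pattern the tests below use (`SGD` and `LeNet5` assumed from the test file): with grouped learning rates each parameter maps to its own `lr_`-prefixed learning-rate `Parameter`, while without groups every query returns the single shared learning rate.

```python
from mindspore.nn import SGD  # import path assumed

net = LeNet5()
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
other_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))

# Two groups with different learning rates; is_group_lr becomes True.
opt = SGD([{'params': conv_params, 'lr': 0.1},
           {'params': other_params, 'lr': 0.3}])

# Single Parameter in -> single learning-rate Parameter out.
lr = opt.get_lr_parameter(conv_params[0])

# list[Parameter] in -> list of learning-rate Parameters out, in the same order.
lrs = opt.get_lr_parameter(conv_params)
```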
......@@ -210,3 +210,41 @@ def test_group_repeat_param():
                    {'params': no_conv_params}]
    with pytest.raises(RuntimeError):
        Adam(group_params, learning_rate=default_lr)


def test_get_lr_parameter_with_group():
    net = LeNet5()
    conv_lr = 0.1
    default_lr = 0.3
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params, 'lr': default_lr}]
    opt = SGD(group_params)
    assert opt.is_group_lr is True
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == 'lr_' + param.name

    lr_list = opt.get_lr_parameter(conv_params)
    for lr, param in zip(lr_list, conv_params):
        assert lr.name == 'lr_' + param.name


def test_get_lr_parameter_with_no_group():
    net = LeNet5()
    conv_weight_decay = 0.8
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    opt = SGD(group_params)
    assert opt.is_group_lr is False
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == opt.learning_rate.name

    params_error = [1, 2, 3]
    with pytest.raises(TypeError):
        opt.get_lr_parameter(params_error)