提交 c95215bc 编写于 作者: G guohongzilong

seperate lr groups and weight_decay groups

上级 3c4c0da8
...@@ -243,7 +243,7 @@ class Adam(Optimizer): ...@@ -243,7 +243,7 @@ class Adam(Optimizer):
self.beta1_power = beta1_power self.beta1_power = beta1_power
beta2_power = self.beta2_power * self.beta2 beta2_power = self.beta2_power * self.beta2
self.beta2_power = beta2_power self.beta2_power = beta2_power
if self.is_group: if self.is_group_lr:
success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
self.beta2, self.eps), self.beta2, self.eps),
lr, gradients, params, moment1, moment2) lr, gradients, params, moment1, moment2)
......
...@@ -111,7 +111,7 @@ class Momentum(Optimizer): ...@@ -111,7 +111,7 @@ class Momentum(Optimizer):
gradients = self.decay_weight(gradients) gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()
if self.is_group: if self.is_group_lr:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments) success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
else: else:
success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
......
...@@ -94,6 +94,7 @@ class Optimizer(Cell): ...@@ -94,6 +94,7 @@ class Optimizer(Cell):
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
self.is_group = False self.is_group = False
self.is_group_lr = False
self.loss_scale = loss_scale self.loss_scale = loss_scale
if isinstance(learning_rate, float): if isinstance(learning_rate, float):
self.dynamic_lr = False self.dynamic_lr = False
...@@ -116,14 +117,17 @@ class Optimizer(Cell): ...@@ -116,14 +117,17 @@ class Optimizer(Cell):
self.group_weight_decay = [] self.group_weight_decay = []
self._init_group_params(parameters, learning_rate, weight_decay) self._init_group_params(parameters, learning_rate, weight_decay)
if self.is_group: if self.is_group_lr:
self.learning_rate = ParameterTuple(self.group_lr) self.learning_rate = ParameterTuple(self.group_lr)
else:
self.learning_rate = Parameter(learning_rate, name="learning_rate")
if self.is_group:
self.parameters = ParameterTuple(self.params) self.parameters = ParameterTuple(self.params)
self.weight_decay = tuple(self.group_weight_decay) self.weight_decay = tuple(self.group_weight_decay)
decay_filter = lambda x: x > 0 decay_filter = lambda x: x > 0
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
else: else:
self.learning_rate = Parameter(learning_rate, name="learning_rate")
self.parameters = ParameterTuple(parameters) self.parameters = ParameterTuple(parameters)
self.weight_decay = weight_decay * loss_scale self.weight_decay = weight_decay * loss_scale
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
...@@ -207,6 +211,7 @@ class Optimizer(Cell): ...@@ -207,6 +211,7 @@ class Optimizer(Cell):
for group_param in parameters: for group_param in parameters:
lr_length = dynamic_lr_length lr_length = dynamic_lr_length
if 'lr' in group_param.keys(): if 'lr' in group_param.keys():
self.is_group_lr = True
self._get_single_lr(group_param['lr']) self._get_single_lr(group_param['lr'])
if isinstance(group_param['lr'], Iterable): if isinstance(group_param['lr'], Iterable):
lr_length = len(group_param['lr']) lr_length = len(group_param['lr'])
...@@ -247,6 +252,10 @@ class Optimizer(Cell): ...@@ -247,6 +252,10 @@ class Optimizer(Cell):
else: else:
weight_decay_ = weight_decay * self.loss_scale weight_decay_ = weight_decay * self.loss_scale
for key in group_param.keys():
if key not in ('params', 'lr', 'weight_decay'):
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
for param in group_param['params']: for param in group_param['params']:
if param in params_store: if param in params_store:
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
...@@ -261,7 +270,7 @@ class Optimizer(Cell): ...@@ -261,7 +270,7 @@ class Optimizer(Cell):
Returns: Returns:
float, the learning rate of current step. float, the learning rate of current step.
""" """
if self.is_group: if self.is_group_lr:
lr = self.learning_rate lr = self.learning_rate
if self.dynamic_lr: if self.dynamic_lr:
lr = () lr = ()
......
...@@ -176,7 +176,7 @@ class RMSProp(Optimizer): ...@@ -176,7 +176,7 @@ class RMSProp(Optimizer):
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()
if self.centered: if self.centered:
if self.is_group: if self.is_group_lr:
success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon, success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.mg, self.ms, self.moment, gradients) self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
else: else:
...@@ -184,7 +184,7 @@ class RMSProp(Optimizer): ...@@ -184,7 +184,7 @@ class RMSProp(Optimizer):
self.momentum, lr), params, self.mg, self.ms, self.moment, gradients) self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
else: else:
if self.is_group: if self.is_group_lr:
success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
self.momentum), lr, params, self.ms, self.moment, gradients) self.momentum), lr, params, self.ms, self.moment, gradients)
else: else:
......
...@@ -139,7 +139,7 @@ class SGD(Optimizer): ...@@ -139,7 +139,7 @@ class SGD(Optimizer):
gradients = self.decay_weight(gradients) gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients) gradients = self.scale_grad(gradients)
lr = self.get_lr() lr = self.get_lr()
if self.is_group: if self.is_group_lr:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat) success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
else: else:
success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat) success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
......
...@@ -65,12 +65,13 @@ def test_group_lr(): ...@@ -65,12 +65,13 @@ def test_group_lr():
opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
assert opt.is_group is True assert opt.is_group is True
assert opt.is_group_lr is True
assert opt.dynamic_lr is False assert opt.dynamic_lr is False
for lr, param in zip(opt.learning_rate, opt.parameters): for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params: if param in conv_params:
assert lr.data == Tensor(conv_lr, mstype.float32) assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
else: else:
assert lr.data == Tensor(default_lr, mstype.float32) assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt) train_network = TrainOneStepCell(net_with_loss, opt)
...@@ -96,9 +97,9 @@ def test_group_dynamic_1(): ...@@ -96,9 +97,9 @@ def test_group_dynamic_1():
assert opt.dynamic_lr is True assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters): for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params: if param in conv_params:
assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32)) assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
else: else:
assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32)) assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt) train_network = TrainOneStepCell(net_with_loss, opt)
...@@ -124,9 +125,9 @@ def test_group_dynamic_2(): ...@@ -124,9 +125,9 @@ def test_group_dynamic_2():
assert opt.dynamic_lr is True assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters): for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params: if param in conv_params:
assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)) assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
else: else:
assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)) assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, opt) train_network = TrainOneStepCell(net_with_loss, opt)
...@@ -184,6 +185,7 @@ def test_weight_decay(): ...@@ -184,6 +185,7 @@ def test_weight_decay():
opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay) opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
assert opt.is_group is True assert opt.is_group is True
assert opt.is_group_lr is False
for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters): for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
if param in conv_params: if param in conv_params:
assert weight_decay == conv_weight_decay assert weight_decay == conv_weight_decay
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册