Commit ce0c3e3d authored by J jrzaurin

adapted the code to include the possibility of using different learning rates for different groups

Parent 195bc274
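For background on the "different learning rates for different groups" idea: every torch.optim optimizer accepts either a plain iterable of tensors or a list of parameter-group dicts, and each dict may carry its own lr that overrides the optimizer-wide default. A minimal sketch of that mechanism, which the changes below build on (the toy model, names and learning rates are illustrative, not taken from this repository):

import torch

# Toy model with two named children, standing in for the submodules that
# MultipleOptimizers iterates over (names and sizes are made up).
model = torch.nn.ModuleDict({
    "wide": torch.nn.Linear(10, 1),
    "deep": torch.nn.Linear(10, 1),
})

# One parameter-group dict per child; each group sets its own learning rate.
param_groups = [
    {"params": model["wide"].parameters(), "lr": 0.1},
    {"params": model["deep"].parameters(), "lr": 0.001},
]

opt = torch.optim.Adam(param_groups, lr=0.01)   # 0.01 is only the fallback lr
print([g["lr"] for g in opt.param_groups])      # [0.1, 0.001]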
@@ -17,15 +17,20 @@ class MultipleOptimizers(object):
             else: instantiated_optimizers[model_name] = optimizer
         self._optimizers = instantiated_optimizers
-    def apply(self, model:TorchModel):
+    def apply(self, model:TorchModel, param_group=None):
         children = list(model.children())
         children_names = [child.__class__.__name__.lower() for child in children]
         if not all([cn in children_names for cn in self._optimizers.keys()]):
             raise ValueError('Model name has to be one of: {}'.format(children_names))
         for child, name in zip(children, children_names):
-            try:
+            if name in self._optimizers and name in param_group:
+                self._optimizers[name] = self._optimizers[name](child, param_group[name])
+            elif name in self._optimizers:
                 self._optimizers[name] = self._optimizers[name](child)
-            except:
+            else:
                 warnings.warn(
                     "No optimizer found for {}. Adam optimizer with default "
                     "settings will be used".format(name))
@@ -51,8 +56,10 @@ class Adam:
         self.weight_decay=weight_decay
         self.amsgrad=amsgrad
-    def __call__(self, submodel:TorchModel) -> Optimizer:
-        self.opt = torch.optim.Adam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
+    def __call__(self, submodel:TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
+        self.opt = torch.optim.Adam(params, lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay, amsgrad=self.amsgrad)
         return self.opt
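In the new Adam.__call__, a supplied param_group is handed straight to torch.optim.Adam, so self.lr only acts as the fallback for groups that do not set their own lr (torch fills missing group keys from the optimizer defaults). A small standalone check of that fallback behaviour, not repository code:

import torch

layer = torch.nn.Linear(4, 2)
groups = [
    {"params": [layer.weight], "lr": 0.1},   # explicit per-group learning rate
    {"params": [layer.bias]},                # no lr: inherits the default below
]
opt = torch.optim.Adam(groups, lr=0.001)
assert [g["lr"] for g in opt.param_groups] == [0.1, 0.001]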
@@ -66,7 +73,9 @@ class RAdam:
         self.eps=eps
         self.weight_decay=weight_decay
-    def __call__(self, submodel:TorchModel) -> Optimizer:
+    def __call__(self, submodel:TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = orgRAdam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay)
         return self.opt
@@ -82,7 +91,9 @@ class SGD:
         self.weight_decay=weight_decay
         self.nesterov=nesterov
-    def __call__(self, submodel:TorchModel) -> Optimizer:
+    def __call__(self, submodel:TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = torch.optim.SGD(submodel.parameters(), lr=self.lr, momentum=self.momentum,
             dampening=self.dampening, weight_decay=self.weight_decay, nesterov=self.nesterov)
         return self.opt
@@ -99,8 +110,9 @@ class RMSprop:
         self.momentum = momentum
         self.centered = centered
-    def __call__(self, submodel:TorchModel) -> Optimizer:
+    def __call__(self, submodel:TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = torch.optim.RMSprop(submodel.parameters(), lr = self.lr, alpha = self.alpha,
             eps = self.eps, weight_decay = self.weight_decay, momentum = self.momentum,
             centered = self.centered)
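In the RAdam, SGD and RMSprop hunks above, params is computed from param_group but the visible constructor calls still receive submodel.parameters(), so on the lines shown the new argument only takes effect for the Adam wrapper. A standalone sketch of the same wrapper pattern with params actually passed through, assuming that was the intent; this mirrors, but is not, the repository's SGD wrapper:

import torch

class SGDWrapper:
    """Sketch of the wrapper pattern above with param_group actually used."""
    def __init__(self, lr=0.01, momentum=0.0):
        self.lr = lr
        self.momentum = momentum

    def __call__(self, submodel, param_group=None):
        # Prefer the caller-supplied parameter groups over the child's parameters().
        params = param_group if param_group is not None else submodel.parameters()
        self.opt = torch.optim.SGD(params, lr=self.lr, momentum=self.momentum)
        return self.opt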