Commit a049eb83 authored by J jrzaurin

modified the code to default to Adam when multiple optimizers are used but some of them are not specified
Parent 4407074d
 import torch
+import warnings
 from torch import nn
 from .radam import RAdam as orgRAdam
@@ -18,14 +19,17 @@ class MultipleOptimizers(object):
     def apply(self, model:TorchModel):
         children = list(model.children())
-        for child in children:
-            model_name = child.__class__.__name__.lower()
+        children_names = [child.__class__.__name__.lower() for child in children]
+        if not all([cn in children_names for cn in self._optimizers.keys()]):
+            raise ValueError('Model name has to be one of: {}'.format(children_names))
+        for child, name in zip(children, children_names):
             try:
-                self._optimizers[model_name] = self._optimizers[model_name](child)
-            except KeyError:
-                raise ValueError(
-                    'Model name has to be one of: {}'.format(str([child.__class__.__name__.lower()
-                    for child in children])))
+                self._optimizers[name] = self._optimizers[name](child)
+            except:
+                warnings.warn(
+                    "No optimizer found for {}. Adam optimizer with default "
+                    "settings will be used".format(name))
+                self._optimizers[name] = Adam()(child)

     def zero_grad(self):
         for _, opt in self._optimizers.items():
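To make the new behaviour concrete, below is a minimal sketch (not part of the commit) of apply() degrading gracefully when only some children have an optimizer. The MultipleOptimizers constructor signature and the RMSprop wrapper keywords are assumptions read off the diff, not confirmed API; the Wide/DeepDense/WideDeep classes are hypothetical stand-ins for the submodels.

import torch
from torch import nn

class Wide(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 1)

class DeepDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 1)

class WideDeep(nn.Module):
    def __init__(self):
        super().__init__()
        self.wide = Wide()          # child class name lowercased -> 'wide'
        self.deepdense = DeepDense()  # -> 'deepdense'

model = WideDeep()

# Only 'wide' gets an explicit optimizer. Before this commit apply() raised
# a ValueError; now 'deepdense' falls back to Adam and a UserWarning is
# emitted: "No optimizer found for deepdense. Adam optimizer with default
# settings will be used".
opts = MultipleOptimizers({'wide': RMSprop(lr=0.01)})  # constructor assumed
opts.apply(model)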
@@ -97,7 +101,7 @@ class RMSprop:
     def __call__(self, submodel:TorchModel) -> Optimizer:
-        self.opt = torch.optim.SGD(submodel.parameters(), lr = self.lr, alpha = self.alpha,
+        self.opt = torch.optim.RMSprop(submodel.parameters(), lr = self.lr, alpha = self.alpha,
             eps = self.eps, weight_decay = self.weight_decay, momentum = self.momentum,
             centered = self.centered)
         return self.opt
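For context on this hunk: torch.optim.SGD does not accept the RMSprop keyword set (alpha, centered), so the old line raised a TypeError as soon as __call__ ran. A quick check against the stock PyTorch optimizers:

import torch
from torch import nn

params = nn.Linear(4, 1).parameters()
try:
    torch.optim.SGD(params, lr=0.01, alpha=0.99)
except TypeError as e:
    print(e)  # unexpected keyword argument 'alpha'

# torch.optim.RMSprop accepts the full keyword set used in the wrapper above
opt = torch.optim.RMSprop(nn.Linear(4, 1).parameters(), lr=0.01, alpha=0.99,
                          eps=1e-8, weight_decay=0.0, momentum=0.0,
                          centered=False)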