Commit a30f3164 authored by: J jrzaurin

adapting the code to the cases where the input classes are not instantiated

Parent 30dc5289
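The commit applies one idiom across the initializer, optimizer and LR-scheduler containers: if a class is passed where an instance is expected, the container instantiates it with its default arguments before using it. The type hints also move from nn.Module to TorchModel (presumably an alias exported by wdtypes). A minimal, library-agnostic sketch of the check being introduced (the helper name is illustrative, not part of the codebase):

def ensure_instance(obj):
    # a class is itself an instance of `type`; calling it with no
    # arguments builds an object configured with its defaults
    return obj() if isinstance(obj, type) else obj

# ensure_instance(SomeWrapper)       -> SomeWrapper() built with default settings
# ensure_instance(SomeWrapper(...))  -> returned unchanged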
@@ -7,7 +7,7 @@ from .wdtypes import *
 class Initializer(object):
-    def __call__(self, model:nn.Module):
+    def __call__(self, model:TorchModel):
         raise NotImplementedError('Initializer must implement this method')
@@ -16,7 +16,7 @@ class MultipleInitializers(object):
     def __init__(self, initializers:Dict[str, Initializer]):
         self._initializers = initializers
-    def apply(self, model:nn.Module):
+    def apply(self, model:TorchModel):
         children = list(model.children())
         for child in children:
             model_name = child.__class__.__name__.lower()
@@ -35,7 +35,7 @@ class Normal(Initializer):
         self.pattern = pattern
         super(Normal, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if self.bias and ('bias' in n):
@@ -55,7 +55,7 @@ class Uniform(Initializer):
         self.pattern = pattern
         super(Uniform, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if self.bias and ('bias' in n):
@@ -74,7 +74,7 @@ class ConstantInitializer(Initializer):
         self.pattern = pattern
         super(ConstantInitializer, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if self.bias and ('bias' in n):
@@ -92,7 +92,7 @@ class XavierUniform(Initializer):
         self.pattern = pattern
         super(XavierUniform, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if 'bias' in n: nn.init.constant_(p, val=0)
@@ -106,7 +106,7 @@ class XavierNormal(Initializer):
         self.pattern = pattern
         super(XavierNormal, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if 'bias' in n: nn.init.constant_(p, val=0)
@@ -121,7 +121,7 @@ class KaimingUniform(Initializer):
         self.pattern = pattern
         super(KaimingUniform, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if 'bias' in n: nn.init.constant_(p, val=0)
@@ -136,7 +136,7 @@ class KaimingNormal(Initializer):
         self.pattern = pattern
         super(KaimingNormal, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if 'bias' in n: nn.init.constant_(p, val=0)
@@ -150,7 +150,7 @@ class Orthogonal(Initializer):
         self.pattern = pattern
         super(Orthogonal, self).__init__()
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel):
         for n,p in submodel.named_parameters():
             if fnmatch(n, self.pattern) and pattern_in:
                 if 'bias' in n: nn.init.constant_(p, val=0)
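For reference, the __call__ methods above all share one mechanism: parameters are selected by name with fnmatch, and biases are treated separately from weights. A hedged, self-contained sketch of that mechanism (not the library's API; the function name, the xavier_normal_ choice and the example pattern are illustrative):

import torch.nn as nn
from fnmatch import fnmatch

def init_by_pattern(submodel: nn.Module, pattern: str = "*") -> None:
    for n, p in submodel.named_parameters():
        if fnmatch(n, pattern):              # e.g. pattern="embed_*" touches only embedding params
            if "bias" in n:
                nn.init.constant_(p, val=0)  # biases start at zero, as in the classes above
            elif p.dim() > 1:
                nn.init.xavier_normal_(p)    # weight matrices get Xavier/Glorot init

init_by_pattern(nn.Linear(4, 2))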
@@ -7,12 +7,16 @@ from .wdtypes import *
 class MultipleLRScheduler(object):
     def __init__(self,schedulers:Dict[str,LRScheduler]):
-        self._schedulers = schedulers
+        instantiated_schedulers = {}
+        for model_name, scheduler in schedulers.items():
+            if isinstance(scheduler, type):
+                instantiated_schedulers[model_name] = scheduler()
+            else: instantiated_schedulers[model_name] = scheduler
+        self._schedulers = instantiated_schedulers
     def apply(self, optimizers:Dict[str, Optimizer]):
         for model_name, optimizer in optimizers.items():
-            if isinstance(self._schedulers[model_name], type):
-                self._schedulers[model_name] = self._schedulers[model_name]()
+            self._schedulers[model_name] = self._schedulers[model_name](optimizer)
     def step(self, loss=None):
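With this change MultipleLRScheduler normalizes its input once in __init__, and apply simply calls each (now guaranteed instantiated) wrapper with the matching optimizer. A hedged usage sketch, assuming MultipleLRScheduler from the hunk above is importable; StepLRWrapper is a hypothetical stand-in for the scheduler wrapper classes defined elsewhere in this module:

import torch
from torch import nn

class StepLRWrapper:
    # hypothetical wrapper, mirroring the Adam/RAdam/SGD wrappers below
    def __init__(self, step_size: int = 3, gamma: float = 0.1):
        self.step_size, self.gamma = step_size, gamma
    def __call__(self, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, self.step_size, self.gamma)

wide_opt = torch.optim.SGD([nn.Parameter(torch.zeros(2))], lr=0.1)
deep_opt = torch.optim.SGD([nn.Parameter(torch.zeros(2))], lr=0.01)

# either form is accepted now: a configured instance or the bare class
schedulers = MultipleLRScheduler({"wide": StepLRWrapper(step_size=5), "deep": StepLRWrapper})
schedulers.apply({"wide": wide_opt, "deep": deep_opt})  # each wrapper receives its optimizer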
@@ -4,20 +4,22 @@ from torch import nn
 from .radam import RAdam as orgRAdam
 from .wdtypes import *
 import pdb
 class MultipleOptimizers(object):
     def __init__(self, optimizers:Dict[str,Optimizer]):
-        self._optimizers = optimizers
-    def apply(self, model:nn.Module):
+        instantiated_optimizers = {}
+        for model_name, optimizer in optimizers.items():
+            if isinstance(optimizer, type):
+                instantiated_optimizers[model_name] = optimizer()
+            else: instantiated_optimizers[model_name] = optimizer
+        self._optimizers = instantiated_optimizers
+    def apply(self, model:TorchModel):
         children = list(model.children())
         for child in children:
             model_name = child.__class__.__name__.lower()
-            if isinstance(self._optimizers[model_name], type):
-                self._optimizers[model_name] = self._optimizers[model_name]()
+            self._optimizers[model_name] = self._optimizers[model_name](child)
     def zero_grad(self):
@@ -39,7 +41,7 @@ class Adam:
         self.weight_decay=weight_decay
         self.amsgrad=amsgrad
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel) -> Optimizer:
         self.opt = torch.optim.Adam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay, amsgrad=self.amsgrad)
         return self.opt
@@ -59,7 +61,7 @@ class RAdam:
         self.eps=eps
         self.weight_decay=weight_decay
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel) -> Optimizer:
         self.opt = orgRAdam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay)
         return self.opt
@@ -81,7 +83,7 @@ class SGD:
         self.weight_decay=weight_decay
         self.nesterov=nesterov
-    def __call__(self, submodel:nn.Module):
+    def __call__(self, submodel:TorchModel) -> Optimizer:
         self.opt = torch.optim.SGD(submodel.parameters(), lr=self.lr, momentum=self.momentum,
             dampening=self.dampening, weight_decay=self.weight_decay, nesterov=self.nesterov)
         return self.opt
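Putting the optimizer side together: MultipleOptimizers.apply walks model.children() and looks up each optimizer by the child's lowercased class name, so the dictionary keys must match the submodule class names. A hedged usage sketch, where Wide and DeepDense are hypothetical submodules and the keyword arguments of the Adam/SGD wrappers are assumed from the attributes shown above:

import torch
from torch import nn

class Wide(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

class DeepDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(16, 1)

class WideDeep(nn.Module):
    def __init__(self):
        super().__init__()
        self.wide, self.deepdense = Wide(), DeepDense()  # lowercased class names: "wide", "deepdense"

model = WideDeep()
# a configured wrapper instance and a bare class can now be mixed freely
opts = MultipleOptimizers({"wide": Adam(lr=0.01), "deepdense": SGD})
opts.apply(model)  # builds a torch optimizer over each child's parameters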