Commit 1f3ce2f7 authored by: J jrzaurin

Fixed a number of bugs related to reading the children's parameters. Also replaced the use of .apply with simply calling the initializer directly.
Parent 384c817c
import re
import warnings

import torch
from torch import nn

from .wdtypes import *
class Initializer(object):

    def __call__(self, model:TorchModel):
        raise NotImplementedError
class MultipleInitializers(object):

    def __init__(self, initializers:Dict[str, Initializer], verbose=True):
        self.verbose = verbose
        # accept both Initializer subclasses and instances; classes are
        # instantiated with their default arguments
        instantiated_initializers = {}
        for model_name, initializer in initializers.items():
            if isinstance(initializer, type):
                instantiated_initializers[model_name] = initializer()
            else:
                instantiated_initializers[model_name] = initializer
        self._initializers = instantiated_initializers
    def apply(self, model:TorchModel):
        children = list(model.children())
        children_names = [child.__class__.__name__.lower() for child in children]
        # every key in the initializers dict must name one of the model's children
        if not all([cn in children_names for cn in self._initializers.keys()]):
            raise ValueError('Model name has to be one of: {}'.format(children_names))
        for child, name in zip(children, children_names):
            try:
                # call the initializer directly instead of routing it through .apply
                self._initializers[name](child)
            except KeyError:
                if self.verbose: warnings.warn("No initializer found for {}".format(name))
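# Minimal usage sketch (the child names 'wide' and 'deepdense' below are
# hypothetical, for illustration only). Note that Module.apply(fn) recurses
# over every submodule, whereas calling the initializer directly passes the
# child once and lets it loop over child.named_parameters(), which already
# covers all nested parameters:
#
#   initializers = {'wide': Normal(std=0.1), 'deepdense': KaimingNormal}
#   MultipleInitializers(initializers, verbose=True).apply(model)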
class Normal(Initializer):

    def __init__(self, mean=0.0, std=1.0, bias=False, pattern='.'):
        self.mean = mean
        self.std = std
        self.bias = bias
        self.pattern = pattern
        super(Normal, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            # pattern is a regular expression matched against the parameter name
            if re.search(self.pattern, n):
                if self.bias and ('bias' in n):
                    nn.init.normal_(p, mean=self.mean, std=self.std)
                elif 'bias' in n:
                    # biases are left untouched unless bias=True
                    pass
                elif p.requires_grad:
                    nn.init.normal_(p, mean=self.mean, std=self.std)
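# The pattern argument is now a regular expression used with re.search rather
# than an fnmatch glob, hence the catch-all default changed from '*' to '.'.
# A hypothetical example that initializes only embedding weights:
#
#   emb_init = Normal(mean=0.0, std=0.02, pattern='emb')
#   emb_init(submodel)   # only parameters whose name contains 'emb'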
class Uniform(Initializer):

    def __init__(self, a=0, b=1, bias=False, pattern='.'):
        self.a = a
        self.b = b
        self.bias = bias
        self.pattern = pattern
        super(Uniform, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if self.bias and ('bias' in n):
                    nn.init.uniform_(p, a=self.a, b=self.b)
                elif 'bias' in n:
                    pass
                elif p.requires_grad:
                    nn.init.uniform_(p, a=self.a, b=self.b)
class ConstantInitializer(Initializer):

    def __init__(self, value, bias=False, pattern='.'):
        self.bias = bias
        self.value = value
        self.pattern = pattern
        super(ConstantInitializer, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if self.bias and ('bias' in n):
                    nn.init.constant_(p, val=self.value)
                elif ('bias' in n):
                    pass
                elif p.requires_grad:
                    nn.init.constant_(p, val=self.value)
class XavierUniform(Initializer):

    def __init__(self, gain=1, pattern='.'):
        self.gain = gain
        self.pattern = pattern
        super(XavierUniform, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if 'bias' in n:
                    nn.init.constant_(p, val=0)
                elif p.requires_grad:
                    try:
                        nn.init.xavier_uniform_(p, gain=self.gain)
                    except ValueError:
                        # xavier initialization requires at least a 2-dim
                        # tensor; skip 1-dim parameters
                        pass
class XavierNormal(Initializer):

    def __init__(self, gain=1, pattern='.'):
        self.gain = gain
        self.pattern = pattern
        super(XavierNormal, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if 'bias' in n:
                    nn.init.constant_(p, val=0)
                elif p.requires_grad:
                    try:
                        nn.init.xavier_normal_(p, gain=self.gain)
                    except ValueError:
                        # skip parameters with fewer than 2 dimensions
                        pass
class KaimingUniform(Initializer):

    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu', pattern='.'):
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity
        self.pattern = pattern
        super(KaimingUniform, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if 'bias' in n:
                    nn.init.constant_(p, val=0)
                elif p.requires_grad:
                    try:
                        nn.init.kaiming_uniform_(p, a=self.a, mode=self.mode,
                            nonlinearity=self.nonlinearity)
                    except ValueError:
                        # skip parameters with fewer than 2 dimensions
                        pass
class KaimingNormal(Initializer):

    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu', pattern='.'):
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity
        self.pattern = pattern
        super(KaimingNormal, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if 'bias' in n:
                    nn.init.constant_(p, val=0)
                elif p.requires_grad:
                    try:
                        nn.init.kaiming_normal_(p, a=self.a, mode=self.mode,
                            nonlinearity=self.nonlinearity)
                    except ValueError:
                        # skip parameters with fewer than 2 dimensions
                        pass
class Orthogonal(Initializer):

    def __init__(self, gain=1, pattern='.'):
        self.gain = gain
        self.pattern = pattern
        super(Orthogonal, self).__init__()

    def __call__(self, submodel:TorchModel):
        for n,p in submodel.named_parameters():
            if re.search(self.pattern, n):
                if 'bias' in n:
                    nn.init.constant_(p, val=0)
                elif p.requires_grad:
                    try:
                        nn.init.orthogonal_(p, gain=self.gain)
                    except ValueError:
                        # orthogonal init needs 2 or more dimensions; skip others
                        pass
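For reference, a minimal, self-contained sketch of how these classes are meant to be used, assuming a toy two-branch model (the module and attribute names below are illustrative, not part of the library):

    import torch
    from torch import nn

    class Wide(nn.Module):
        def __init__(self):
            super(Wide, self).__init__()
            self.linear = nn.Linear(10, 1)    # params: 'linear.weight', 'linear.bias'

    class DeepDense(nn.Module):
        def __init__(self):
            super(DeepDense, self).__init__()
            self.dense = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 1))

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.wide = Wide()                # child name -> 'wide'
            self.deepdense = DeepDense()      # child name -> 'deepdense'

    model = Model()
    # keys must match the lower-cased class names of model.children();
    # values may be Initializer instances or classes
    initializers = {'wide': Normal(std=0.1), 'deepdense': KaimingNormal}
    MultipleInitializers(initializers).apply(model)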