提交 9b9dcde7 编写于 作者: J jrzaurin

added a warning related to the model name (needs to be an error)

上级 00eb5fbc
import torch
import warnings
from torch import nn
from fnmatch import fnmatch
from .wdtypes import *
class Initializer(object):
def __call__(self, model:TorchModel):
......@@ -23,12 +25,13 @@ class MultipleInitializers(object):
try:
child.apply(self._initializers[model_name])
except KeyError:
raise Exception('Model name has to be one of: {}'.format(str([child.__class__.__name__.lower() for child in children])))
warnings.warn(
'Model name has to be one of: {}'.format(str([child.__class__.__name__.lower()
for child in children])), ValueError)
class Normal(Initializer):
def __init__(self, mean=0.0, std=0.02, bias=False, pattern='*', pattern_in=True):
def __init__(self, mean=0.0, std=0.02, bias=False, pattern='*'):
self.mean = mean
self.std = std
self.bias = bias
......@@ -37,10 +40,10 @@ class Normal(Initializer):
def __call__(self, submodel: TorchModel):
    """Initialize ``submodel``'s parameters in place from N(mean, std).

    Only parameters whose names match ``self.pattern`` (fnmatch-style glob)
    are touched. Bias parameters are initialized only when ``self.bias`` is
    True; otherwise they are skipped. Non-bias parameters are initialized
    only if they require gradients.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if self.bias and ('bias' in name):
                nn.init.normal_(param, mean=self.mean, std=self.std)
            elif 'bias' in name:
                # biases are left untouched unless explicitly requested
                continue
            elif param.requires_grad:
                nn.init.normal_(param, mean=self.mean, std=self.std)
......@@ -48,7 +51,7 @@ class Normal(Initializer):
class Uniform(Initializer):
def __init__(self, a=0, b=1, bias=False, pattern='*', pattern_in=True):
def __init__(self, a=0, b=1, bias=False, pattern='*'):
self.a = a
self.b = b
self.bias = bias
......@@ -57,10 +60,10 @@ class Uniform(Initializer):
def __call__(self, submodel: TorchModel):
    """Initialize ``submodel``'s parameters in place from U(a, b).

    Only parameters whose names match ``self.pattern`` (fnmatch-style glob)
    are touched. Bias parameters are initialized only when ``self.bias`` is
    True; otherwise they are skipped. Non-bias parameters are initialized
    only if they require gradients.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if self.bias and ('bias' in name):
                nn.init.uniform_(param, a=self.a, b=self.b)
            elif 'bias' in name:
                # biases are left untouched unless explicitly requested
                continue
            elif param.requires_grad:
                nn.init.uniform_(param, a=self.a, b=self.b)
......@@ -76,12 +79,12 @@ class ConstantInitializer(Initializer):
def __call__(self, submodel: TorchModel):
    """Initialize ``submodel``'s parameters in place to ``self.value``.

    Only parameters whose names match ``self.pattern`` (fnmatch-style glob)
    are touched. Bias parameters are initialized only when ``self.bias`` is
    True; otherwise they are skipped. Non-bias parameters are initialized
    only if they require gradients.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if self.bias and ('bias' in name):
                nn.init.constant_(param, val=self.value)
            elif 'bias' in name:
                # biases are left untouched unless explicitly requested
                continue
            elif param.requires_grad:
                nn.init.constant_(param, val=self.value)
......@@ -94,7 +97,7 @@ class XavierUniform(Initializer):
def __call__(self, submodel: TorchModel):
    """Apply Xavier-uniform initialization in place to ``submodel``.

    Parameters whose names match ``self.pattern`` are processed: biases are
    zeroed via ``constant_``; all other parameters that require gradients
    get ``xavier_uniform_`` with ``self.gain``.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if 'bias' in name:
                nn.init.constant_(param, val=0)
            elif param.requires_grad:
                nn.init.xavier_uniform_(param, gain=self.gain)
......@@ -108,7 +111,7 @@ class XavierNormal(Initializer):
def __call__(self, submodel: TorchModel):
    """Apply Xavier-normal initialization in place to ``submodel``.

    Parameters whose names match ``self.pattern`` are processed: biases are
    zeroed via ``constant_``; all other parameters that require gradients
    get ``xavier_normal_`` with ``self.gain``.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if 'bias' in name:
                nn.init.constant_(param, val=0)
            elif param.requires_grad:
                nn.init.xavier_normal_(param, gain=self.gain)
......@@ -123,7 +126,7 @@ class KaimingUniform(Initializer):
def __call__(self, submodel: TorchModel):
    """Apply Kaiming-uniform initialization in place to ``submodel``.

    Parameters whose names match ``self.pattern`` are processed: biases are
    zeroed via ``constant_``; all other parameters that require gradients
    get ``kaiming_uniform_`` with slope ``self.a`` and mode ``self.mode``.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if 'bias' in name:
                nn.init.constant_(param, val=0)
            elif param.requires_grad:
                nn.init.kaiming_uniform_(param, a=self.a, mode=self.mode)
......@@ -138,7 +141,7 @@ class KaimingNormal(Initializer):
def __call__(self, submodel: TorchModel):
    """Apply Kaiming-normal initialization in place to ``submodel``.

    Parameters whose names match ``self.pattern`` are processed: biases are
    zeroed via ``constant_``; all other parameters that require gradients
    get ``kaiming_normal_`` with slope ``self.a`` and mode ``self.mode``.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if 'bias' in name:
                nn.init.constant_(param, val=0)
            elif param.requires_grad:
                nn.init.kaiming_normal_(param, a=self.a, mode=self.mode)
......@@ -152,6 +155,6 @@ class Orthogonal(Initializer):
def __call__(self, submodel: TorchModel):
    """Apply orthogonal initialization in place to ``submodel``.

    Parameters whose names match ``self.pattern`` are processed: biases are
    zeroed; all other parameters that require gradients get ``orthogonal_``
    with ``self.gain``.
    """
    for name, param in submodel.named_parameters():
        if fnmatch(name, self.pattern):
            if 'bias' in name:
                # BUG FIX: original called nn.init.normal_(p, val=0), but
                # normal_ has no `val` parameter (it takes mean/std) and
                # would raise TypeError. Use constant_ like every sibling
                # initializer in this module.
                nn.init.constant_(param, val=0)
            elif param.requires_grad:
                nn.init.orthogonal_(param, gain=self.gain)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册