提交 3e7dd971 编写于 作者: J jrzaurin

renaming parameters for consistency

上级 aff6008d
...@@ -18,7 +18,7 @@ def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, maxpool:bool=True, ...@@ -18,7 +18,7 @@ def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, maxpool:bool=True,
class DeepImage(nn.Module): class DeepImage(nn.Module):
def __init__(self, output_dim:int=1, pretrained:bool=True, resnet=18, def __init__(self, output_dim, pretrained:bool=True, resnet=18,
freeze:Union[str,int]=6): freeze:Union[str,int]=6):
super(DeepImage, self).__init__() super(DeepImage, self).__init__()
""" """
...@@ -53,7 +53,7 @@ class DeepImage(nn.Module): ...@@ -53,7 +53,7 @@ class DeepImage(nn.Module):
frozen_layers.append(layer) frozen_layers.append(layer)
self.backbone = nn.Sequential(*frozen_layers) self.backbone = nn.Sequential(*frozen_layers)
if isinstance(freeze, int): if isinstance(freeze, int):
assert freeze < 8 assert freeze < 8, 'freeze must be less than 8 when using resnet architectures'
frozen_layers = [] frozen_layers = []
trainable_layers = backbone_layers[freeze:] trainable_layers = backbone_layers[freeze:]
for layer in backbone_layers[:freeze]: for layer in backbone_layers[:freeze]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册