提交 03a1e288 编写于 作者: A AUTOMATIC

turns out LayerNorm also has weight and bias and needs to be pre-multiplied...

turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets
上级 e4877722
......@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
if type(layer) == torch.nn.Linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
......@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
if type(layer) == torch.nn.Linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
return layer_structure
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册