提交 85fcccc1 编写于 作者: A AngelBottomless 提交者: aria1th

Squashed commit of fixing dropout silently

fix dropouts for future hypernetworks

add kwargs for Hypernetwork class

hypernet UI for gradio input

add recommended options

remove as options

revert adding options in ui
上级 b6a8bb12
......@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
......@@ -60,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
if use_dropout and i < len(layer_structure) - 3:
if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
......@@ -126,7 +127,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
......@@ -139,11 +140,14 @@ class Hypernetwork:
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
def weights(self):
......@@ -172,7 +176,8 @@ class Hypernetwork:
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
torch.save(state_dict, filename)
def load(self, filename):
......@@ -193,12 +198,16 @@ class Hypernetwork:
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
......
......@@ -1238,8 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册