import torch
import network
from lyco_helpers import factorization
from einops import rearrange
from modules import devices


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "oft_blocks" in weights.w or "oft_diag" in weights.w:
            return NetworkModuleOFT(net, weights)

        return None

# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
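# Background (informal sketch of the math implemented below): OFT learns skew-symmetric
# blocks Q and maps each one to an orthogonal block through the Cayley transform
# R = (I + Q)(I - Q)^-1; the frozen weight is then rotated along its output axis by the
# block-diagonal R. Splitting out_dim into num_blocks blocks of size block_size keeps the
# extra parameter count at num_blocks * block_size^2 instead of out_dim^2.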
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]

        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
            self.alpha = weights.w["alpha"]
            self.dim = self.oft_blocks.shape[0] # lora dim
            #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...')
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size)

            # alpha is rank if alpha is 0 or None (alpha is unused in this branch)
            self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
        else:
            raise ValueError("oft_blocks or oft_diag must be in weights dict")

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        else:
            raise ValueError("sd_module must be Linear or Conv")

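        # lyco_helpers.factorization splits out_dim into two factors whose product is
        # out_dim; for example (assumed behavior) factorization(320, 4) -> (4, 80).
        # Note that the two branches below unpack the returned pair in opposite orders,
        # which only coincides when the two factors are equal.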
        if self.is_kohya:
            #self.num_blocks = self.dim
            #self.block_size = self.out_dim // self.num_blocks
            #self.block_size = self.dim
            #self.num_blocks = self.out_dim // self.block_size
            self.constraint = self.alpha * self.out_dim
            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

        if is_other_linear:
            self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)


    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module


    def merge_weight(self, R_weight, org_weight):
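        # Rotate the output axis of the original weight. For a 2D weight the einsum is
        # equivalent to R_weight.T @ org_weight: it sums over the original output index
        # 'o' and emits the rotated index 'p'; the 4D conv case does the same per (h, w).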
        R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
        if org_weight.dim() == 4:
            weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight)
        else:
            weight = torch.einsum("oi, op -> pi", org_weight, R_weight)
        return weight

    def get_weight(self, oft_blocks, multiplier=1.0):
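        # Build the dense rotation matrix from the learned blocks:
        #   1. skew-symmetrize: Q = B - B^T
        #   2. optionally clamp ||Q|| to self.constraint (constrained-OFT style limit)
        #   3. Cayley transform: R_block = (I + Q)(I - Q)^-1, which is orthogonal for
        #      any skew-symmetric Q
        #   4. blend toward identity by `multiplier` and assemble a block-diagonal R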
        if self.constraint is not None:
            constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)

        block_Q = oft_blocks - oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(block_Q.flatten())
        if self.constraint is not None:
            new_norm_Q = torch.clamp(norm_Q, max=constraint)
        else:
            new_norm_Q = norm_Q
        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
        # the identity must match block_Q's shape: (num_blocks, block_size, block_size)
        m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
        block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse())

        block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
        R = torch.block_diag(*block_R_weighted)
        return R

    def calc_updown_kohya(self, orig_weight, multiplier):
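        # kohya-ss path: materialize the full block-diagonal R, rotate the original
        # weight with it, and return the delta against the original as the updown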
        R = self.get_weight(self.oft_blocks, multiplier)
        merged_weight = self.merge_weight(R, orig_weight)

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)

    def calc_updown_kb(self, orig_weight, multiplier):
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]

        if not is_other_linear:
            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
            #    orig_weight=orig_weight.permute(1, 0)

            oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)

            # without this line the results are significantly worse / less accurate
            oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)

            R = oft_blocks  # already moved to the right device/dtype above
            R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)

            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
            merged_weight = torch.einsum(
                'k n m, k n ... -> k m ...',
                R,
                merged_weight
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
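            # the rearrange/einsum pair above applies the block-diagonal rotation
            # without materializing a full (out_dim, out_dim) matrix: the output axis
            # is split into (num_blocks, block_size) groups, each group is rotated by
            # its own R[k], and the axis is flattened back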

            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
            #    orig_weight=orig_weight.permute(1, 0)

            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
            output_shape = orig_weight.shape
        else:
            # FIXME: skip MultiheadAttention for now
            #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
            output_shape = (orig_weight.shape[1], orig_weight.shape[1])
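            # a zero delta leaves the MultiheadAttention weights unchanged, making OFT
            # a no-op for these layers until proper support is implemented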

        return self.finalize_updown(updown, orig_weight, output_shape)

    def calc_updown(self, orig_weight):
        multiplier = self.multiplier() * self.calc_scale()
        #if self.is_kohya:
        #    return self.calc_updown_kohya(orig_weight, multiplier)
        #else:
        return self.calc_updown_kb(orig_weight, multiplier)

    # override to remove the multiplier/scale factor; it's already multiplied in get_weight
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)

        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        return updown, ex_bias
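

# Minimal usage sketch (hypothetical; assumes the module_types registry in the
# surrounding Lora extension's networks.py):
#   module_types = [..., ModuleTypeOFT()]
#   m = ModuleTypeOFT().create_module(net, weights)  # NetworkModuleOFT or None
#   updown, ex_bias = m.calc_updown(orig_weight)     # delta merged into the layer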