network_oft.py
import torch
import network
from einops import rearrange
from modules import devices


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "oft_blocks" in weights.w or "oft_diag" in weights.w:
            return NetworkModuleOFT(net, weights)

        return None

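# OFT (Orthogonal Finetuning) multiplies the original weight along its output dimension by a
# learned block-diagonal orthogonal matrix instead of adding a low-rank delta the way LoRA does;
# calc_updown below reports the difference from the original weight so the surrounding network
# code can apply it like any other update.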
# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):

        super().__init__(net, weights)

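        # Two checkpoint layouts are supported: kohya-ss OFT networks store "oft_blocks" together
        # with an "alpha" used for the norm constraint, while "oft_diag" appears to come from
        # LyCORIS-style diag-oft checkpoints and is mapped onto blocks via factorization() below.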
        self.lin_module = None
        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"]
            self.alpha = weights.w["alpha"]
            self.dim = self.oft_blocks.shape[0]
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"]
            # alpha is meant to default to the rank when it is 0 or None (currently not applied)
            self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
        else:
            raise ValueError("oft_blocks or oft_diag must be in weights dict")

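        # OFT rotates the output dimension of the wrapped module, so the relevant size is
        # out_features, embed_dim or out_channels depending on the module type.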
        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
        #if "Linear" in self.sd_module.__class__.__name__ or is_linear:
        if is_linear:
            self.out_dim = self.sd_module.out_features
            #elif hasattr(self.sd_module, "embed_dim"):
            #    self.out_dim = self.sd_module.embed_dim
            #else:
            #    raise ValueError("Linear sd_module must have out_features or embed_dim")
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        else:
            raise ValueError("sd_module must be Linear or Conv")

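        # For kohya checkpoints the number of blocks equals the stored dim and alpha scales the
        # norm constraint (constrained OFT); otherwise the block layout is recovered by
        # factorizing out_dim and no constraint is applied.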
        if self.is_kohya:
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.num_blocks
            self.constraint = self.alpha * self.out_dim
        #elif is_linear or is_conv:
        else:
            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
            self.constraint = None

        self.org_module: list[torch.nn.Module] = [self.sd_module]

        # if is_other_linear:
        #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
        #     module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        #     with torch.no_grad():
        #         if weight.shape != module.weight.shape:
        #             weight = weight.reshape(module.weight.shape)
        #         module.weight.copy_(weight)
        #     module.to(device=devices.cpu, dtype=devices.dtype)
        #     module.weight.requires_grad_(False)
        #     self.lin_module = module
            #return module

    def merge_weight(self, R_weight, org_weight):
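        # Rotate the output-channel axis of the original weight: index "o" is the original output
        # channel and "p" the rotated one, covering both conv (oihw) and linear (oi) weights.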
        R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
        if org_weight.dim() == 4:
            weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight)
        else:
            weight = torch.einsum("oi, op -> pi", org_weight, R_weight)
        #weight = torch.einsum(
        #    "k n m, k n ... -> k m ...", 
        #    self.oft_diag * scale + torch.eye(self.block_size, device=device), 
        #    org_weight
        #)
        return weight

    def get_weight(self, oft_blocks, multiplier=None):
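        # Build the block-diagonal rotation matrix R from the learned blocks; in practice
        # calc_updown always passes the effective multiplier, so the None default is never used.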
        if self.constraint is not None:
            constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)

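        # Skew-symmetrize the learned blocks, optionally clamp their norm to the constraint, and
        # map each block to an orthogonal matrix with the Cayley transform R = (I + Q)(I - Q)^-1.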
        block_Q = oft_blocks - oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(block_Q.flatten())
        if self.constraint is not None:
            new_norm_Q = torch.clamp(norm_Q, max=constraint)
        else:
            new_norm_Q = norm_Q
        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
        m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
        block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse())

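        # Blend each orthogonal block with the identity according to the multiplier, then assemble
        # the full rotation as a single block-diagonal matrix.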
        block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
        R = torch.block_diag(*block_R_weighted)
        return R
        #return self.oft_blocks


    def calc_updown(self, orig_weight):
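        # updown is the difference between the rotated weight and the original weight, so adding it
        # back onto orig_weight reproduces the OFT-adjusted weight.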
        multiplier = self.multiplier() * self.calc_scale()
        R = self.get_weight(self.oft_blocks, multiplier)
        #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
        merged_weight = self.merge_weight(R, orig_weight)

        #if self.lin_module is not None:
        #    R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        #    weight = torch.mul(torch.mul(R, multiplier), orig_weight)
        #else:
        #    orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        #    weight = torch.einsum(
        #        'k n m, k n ... -> k m ...',
        #        R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
        #        orig_weight
        #    )
        #    weight = rearrange(weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape

        return self.finalize_updown(updown, orig_weight, output_shape)

    # override to remove the multiplier/scale factor; it's already multiplied in get_weight
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)

        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        return updown, ex_bias
    
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values: the input dimension decomposed by the number closest to factor.
    The second value is greater than or equal to the first.

    In LoRA with Kronecker product, the first value is used for the weight scale,
    the second value for the weight itself.

    Because of non-commutativity, A⊗B ≠ B⊗A, so the meaning of the two matrices is slightly different.

    examples)
    factor
        -1               2                4               8               16               ...
    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
    '''
    
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m<n:
        new_m = m + 1
        while dimension%new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m>factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
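
# Example results (illustrative):
#   factorization(1280, factor=8) -> (8, 160)
#   factorization(320, factor=-1) -> (16, 20)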