提交 613b0d95 编写于 作者: V v0xie

doc: add boft comment

上级 325eaeb5
......@@ -29,13 +29,14 @@ class NetworkModuleOFT(network.NetworkModule):
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w["alpha"] # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
# LyCORIS
# LyCORIS OFT
elif "oft_diag" in weights.w.keys():
self.is_kohya = False
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
# LyCORIS BOFT
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
......@@ -89,6 +90,7 @@ class NetworkModuleOFT(network.NetworkModule):
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
else:
# TODO: determine correct value for scale
scale = 1.0
m = self.boft_m
b = self.boft_b
......@@ -99,8 +101,6 @@ class NetworkModuleOFT(network.NetworkModule):
if i == 0:
# Apply multiplier/scale and rescale into first weight
bi = bi * scale + (1 - scale) * eye
#if self.rescaled:
# bi = bi * self.rescale
inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册