Commit 160f25a9 authored by Varuna Jayasiri

fix

Parent 37b30cfc
......@@ -603,9 +603,9 @@
 174    v = self.to_v(cond)
 175
 176    if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177        self.flash_attention(q, k, v)
+177        return self.flash_attention(q, k, v)
 178    else:
-179        self.normal_attention(q, k, v)
+179        return self.normal_attention(q, k, v)
......
......@@ -174,9 +174,9 @@ class CrossAttention(nn.Module):
         v = self.to_v(cond)
 
         if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-            self.flash_attention(q, k, v)
+            return self.flash_attention(q, k, v)
         else:
-            self.normal_attention(q, k, v)
+            return self.normal_attention(q, k, v)
 
     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         """
......
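The change is just the two `return` statements: before this commit, `forward` computed attention in both branches but implicitly returned `None`. Below is a minimal, self-contained sketch of the fixed control flow. The constructor, the projection layers, and a plain scaled-dot-product `normal_attention` are filled in here as assumptions so the sketch runs (the commit itself only touches the two lines above), and the hunk's literal `cond is None` test is expressed through a `has_cond` flag because this sketch reassigns `cond` for self-attention.

# A minimal sketch of the fixed CrossAttention.forward control flow, assuming
# standard linear projections and plain scaled-dot-product attention for the
# pieces not shown in the hunk; flash_attention is stubbed out.
from typing import Optional

import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    use_flash_attention: bool = False

    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        d_attn = n_heads * d_head
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)
        self.to_out = nn.Linear(d_attn, d_model)
        self.flash = None  # a flash-attention module would be assigned here if available

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        has_cond = cond is not None
        if not has_cond:
            cond = x  # self-attention when no conditioning is given

        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

        # The fix: both branches must *return* the attention output. Before
        # this commit the calls ran but forward() implicitly returned None.
        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
            return self.flash_attention(q, k, v)
        else:
            return self.normal_attention(q, k, v)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Plain multi-head scaled dot-product attention.
        b, s, _ = q.shape
        q = q.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(b, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(b, -1, self.n_heads, self.d_head).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.d_head ** -0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, s, -1)
        return self.to_out(out)

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        raise NotImplementedError  # requires an external flash-attention kernel

Without the returns, `CrossAttention(64, 64, 4, 16)(torch.randn(2, 10, 64))` evaluates to `None` and the first downstream tensor operation fails; with this fix it yields the expected `(2, 10, 64)` output.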