stoneliu1981 / pytorch-image-models
Commit 11ae795e
Authored May 25, 2021 by Ross Wightman
Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel
Parent: d400f1db

Showing 1 changed file with 37 additions and 12 deletions.
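The change replaces the previous eval-time bias cache, a single Optional-Tensor attribute (`self.ab = None`, set to a concrete tensor inside `train(mode=False)`), with an empty-by-default `Dict[str, torch.Tensor]` that is filled lazily per device during inference. A hedged reading of why this matters: TorchScript cannot type an attribute that flips between `None` and `Tensor` without an explicit annotation, and a tensor cached once at `eval()` time lives on whatever device the module occupied at that moment, so `nn.DataParallel` replicas (or a later `.cuda()` move, which relocates parameters and buffers but not plain attributes) could end up adding a stale, wrong-device bias tensor. Keying the cache by `str(device)` lets each replica resolve the biases on its own device.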
timm/models/levit.py  +37 −12

@@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict
 
 import torch
 import torch.nn as nn
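The new `from typing import Dict` import backs the `ab: Dict[str, torch.Tensor]` class annotations added below; TorchScript reads that annotation to give the cache a concrete dict type (string keys, since `torch.device` is not a supported TorchScript dict key type).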
@@ -255,6 +256,8 @@ class Subsample(nn.Module):
 
 
 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ class Attention(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
 
             attn = attn.softmax(dim=-1)
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
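A quick sketch of the caching contract the new `get_attention_biases` establishes. This is not part of the commit; the constructor arguments are illustrative, and `act_layer` is passed explicitly because `__init__` instantiates it:

import torch
import torch.nn as nn
from timm.models.levit import Attention

# illustrative hyperparameters; resolution defaults to 14, so N = 14 * 14 = 196
m = Attention(dim=128, key_dim=16, num_heads=4, act_layer=nn.Hardswish).eval()

b = m.get_attention_biases(torch.device('cpu'))
assert 'cpu' in m.ab             # eval mode: result cached under str(device)
assert b.shape == (4, 196, 196)  # (num_heads, N, N)

m.train()                        # mode=True with a non-empty cache clears it
assert len(m.ab) == 0            # training always indexes the live Parameter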
@@ -308,15 +322,18 @@ class Attention(nn.Module):
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
 
             attn = attn.softmax(dim=-1)
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
 
 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):
         if self.use_conv:
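`AttentionSubsample` repeats the exact same pattern. The payoff of the `str(device)` key shows up under `nn.DataParallel`, where the per-forward replicas on different GPUs each resolve and cache the biases on their own device instead of reusing a tensor captured on a single device. A sketch (requires at least two GPUs; the model name is one of timm's LeViT configs):

import torch
import timm

if torch.cuda.device_count() >= 2:
    model = timm.create_model('levit_128s').cuda().eval()
    dp = torch.nn.DataParallel(model, device_ids=[0, 1])
    with torch.no_grad():
        # each replica computes attention_biases[:, attention_bias_idxs] on
        # its own GPU and caches it under 'cuda:0' / 'cuda:1'
        out = dp(torch.randn(8, 3, 224, 224).cuda())
    print(out.shape)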
@@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
 
             attn = attn.softmax(dim=-1)
             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module):
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
 
             attn = attn.softmax(dim=-1)
             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
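With this in place, an eval-mode LeViT should both script and run under data parallelism, which is what the commit title promises. A minimal usage sketch (assumes timm at or after this commit):

import torch
import timm

model = timm.create_model('levit_128s').eval()
scripted = torch.jit.script(model)           # Dict[str, Tensor] cache is scriptable
out = scripted(torch.randn(1, 3, 224, 224))  # first call fills the 'cpu' entry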