Commit 79760198
Authored May 20, 2021 by Alexander Soare
extend positional embedding resizing functionality to tnt
Parent: 8086943b
Showing 1 changed file with 23 additions and 8 deletions (+23 -8)
timm/models/tnt.py
@@ -14,7 +14,9 @@ from functools import partial
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import load_pretrained
 from timm.models.layers import Mlp, DropPath, trunc_normal_
+from timm.models.layers.helpers import to_2tuple
 from timm.models.registry import register_model
+from timm.models.vision_transformer import resize_pos_embed

 def _cfg(url='', **kwargs):

@@ -118,11 +120,15 @@ class PixelEmbed(nn.Module):
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
         super().__init__()
-        num_patches = (img_size // patch_size) ** 2
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        # grid_size property necessary for resizing positional embedding
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        num_patches = (self.grid_size[0]) * (self.grid_size[1])
         self.img_size = img_size
         self.num_patches = num_patches
         self.in_dim = in_dim
-        new_patch_size = math.ceil(patch_size / stride)
+        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
         self.new_patch_size = new_patch_size
         self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
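For orientation (not part of the diff): a small standalone sketch of the arithmetic the rewritten __init__ now performs, evaluated for the stock tnt_s_patch16_224 hyper-parameters (img_size=224, patch_size=16, stride=4) and, hypothetically, for a 384-pixel input. The helper name pixel_embed_dims is made up for illustration and does not exist in timm; the 384 case only shows why grid_size has to be tracked per axis for positional-embedding resizing.

import math

def pixel_embed_dims(img_size, patch_size, stride):
    # Hypothetical helper mirroring the updated PixelEmbed.__init__ arithmetic.
    img_size = (img_size, img_size)          # to_2tuple
    patch_size = (patch_size, patch_size)    # to_2tuple
    grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
    num_patches = grid_size[0] * grid_size[1]
    new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
    num_pixel = new_patch_size[0] * new_patch_size[1]
    return grid_size, num_patches, new_patch_size, num_pixel

print(pixel_embed_dims(224, 16, 4))  # ((14, 14), 196, [4, 4], 16)
print(pixel_embed_dims(384, 16, 4))  # ((24, 24), 576, [4, 4], 16)

Note that new_patch_size (and hence num_pixel) depends only on patch_size and stride, so only the patch-level grid changes with input resolution.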
@@ -130,11 +136,11 @@ class PixelEmbed(nn.Module):
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size and W == self.img_size, \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = self.proj(x)
         x = self.unfold(x)
-        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
+        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         x = x + pixel_pos
         x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
         return x
@@ -155,7 +161,7 @@ class TNT(nn.Module):
         num_patches = self.pixel_embed.num_patches
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
-        num_pixel = new_patch_size ** 2
+        num_pixel = new_patch_size[0] * new_patch_size[1]
         self.norm1_proj = norm_layer(num_pixel * in_dim)
         self.proj = nn.Linear(num_pixel * in_dim, embed_dim)

@@ -163,7 +169,7 @@ class TNT(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size))
+        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
         self.pos_drop = nn.Dropout(p=drop_rate)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@@ -224,6 +230,14 @@ class TNT(nn.Module):
         return x

+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    if state_dict['patch_pos'].shape != model.patch_pos.shape:
+        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
+            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+    return state_dict
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs):
     model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,

@@ -231,7 +245,8 @@ def tnt_s_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['tnt_s_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
+            filter_fn=checkpoint_filter_fn)
     return model
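Taken together, the new checkpoint_filter_fn lets a checkpoint trained at 224x224 be loaded into a TNT model built at another resolution: patch_pos (the per-patch position embedding, which carries one leading class-token slot, hence the num_tokens default of 1) is interpolated to the new grid by resize_pos_embed from vision_transformer.py, while pixel_pos is untouched because new_patch_size does not depend on image size. Below is a minimal sketch of that resize using dummy zero tensors with tnt_s shapes (embed_dim=384); it assumes a timm build that includes this commit and is not the exact call path load_pretrained takes.

import torch
from timm.models.vision_transformer import resize_pos_embed

embed_dim = 384
old_patch_pos = torch.zeros(1, 14 * 14 + 1, embed_dim)  # checkpoint trained at 224: 14x14 grid + cls slot
new_patch_pos = torch.zeros(1, 24 * 24 + 1, embed_dim)  # model built at 384: 24x24 grid + cls slot

# Same arguments checkpoint_filter_fn passes: num_tokens=1, gs_new=model.pixel_embed.grid_size
resized = resize_pos_embed(old_patch_pos, new_patch_pos, 1, (24, 24))
print(resized.shape)  # expected: torch.Size([1, 577, 384])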