PaddlePaddle / PaddleDetection
Commit 27395ac8, authored on Dec 08, 2022 by wangxinxin08

    add owl-vit code

Parent: 45ba40b4
Showing 17 changed files with 1413 additions and 0 deletions (+1413, -0)
ppdet/modeling/vl/__init__.py                        +15   -0
ppdet/modeling/vl/embedder/__init__.py               +61   -0
ppdet/modeling/vl/embedder/clip/__init__.py          +17   -0
ppdet/modeling/vl/embedder/clip/clip.py              +98   -0
ppdet/modeling/vl/embedder/clip/layers.py            +204  -0
ppdet/modeling/vl/embedder/clip/models.py            +207  -0
ppdet/modeling/vl/head/__init__.py                   +13   -0
ppdet/modeling/vl/head/owl_vit_head.py               +201  -0
ppdet/modeling/vl/loss/__init__.py                   +13   -0
ppdet/modeling/vl/loss/owl_vit_loss.py               +139  -0
ppdet/modeling/vl/matcher/__init__.py                +15   -0
ppdet/modeling/vl/models/__init__.py                 +15   -0
ppdet/modeling/vl/models/owl_vit.py                  +87   -0
ppdet/modeling/vl/tokenizer/__init__.py              +1    -0
ppdet/modeling/vl/tokenizer/simple_tokenizer.py      +180  -0
ppdet/modeling/vl/utils/__init__.py                  +15   -0
ppdet/modeling/vl/utils/utils.py                     +132  -0
ppdet/modeling/vl/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import OWLViT
ppdet/modeling/vl/embedder/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ..utils import normalize_image

__all__ = ['ClipImageTextEmbedder']


@register
class ClipImageTextEmbedder(nn.Layer):
    # This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
    def __init__(self, base_model, embed_dim, merge_class_token='drop'):
        super().__init__()
        self.clip = base_model
        self.merge_class_token = merge_class_token
        if self.merge_class_token == 'mul-ln':
            self.merged_class_token = nn.LayerNorm(embed_dim)

    def forward(self, images=None, texts=None):
        if texts is not None:
            texts_shape = texts.shape
            if len(texts_shape) > 2:
                texts = texts.reshape([-1, texts_shape[-1]])
        if images is not None:
            images = normalize_image(images)

        img_emb, txt_emb = self.clip(images, texts, normalize=False)

        if img_emb is not None:
            if self.merge_class_token == 'drop':
                img_emb = img_emb[:, 1:, :]
            elif self.merge_class_token == 'mul-ln':
                img_emb = img_emb[:, :1, :] * img_emb[:, 1:, :]
                img_emb = self.merged_class_token(img_emb)
            else:
                raise ValueError(
                    f'Unknown merge_class_token: {self.merge_class_token}')

        if txt_emb is not None and len(texts_shape) > 2:
            txt_emb = txt_emb.reshape(texts_shape[:-1] + [-1])
        return img_emb, txt_emb
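
A minimal usage sketch (not part of the commit; FakeCLIP is a hypothetical stand-in for the real CLIP tower) showing how merge_class_token='mul-ln' folds the class token into the patch tokens and how batched text queries are flattened and restored:

import paddle
import paddle.nn as nn
from ppdet.modeling.vl.embedder import ClipImageTextEmbedder

class FakeCLIP(nn.Layer):
    """Hypothetical stand-in: token-level image embeddings (class token
    first) and one embedding per flattened text query."""
    def forward(self, images, texts, normalize=False):
        img_emb = None if images is None else paddle.rand([2, 1 + 49, 768])
        txt_emb = None if texts is None else paddle.rand([texts.shape[0], 512])
        return img_emb, txt_emb

embedder = ClipImageTextEmbedder(
    FakeCLIP(), embed_dim=768, merge_class_token='mul-ln')
img_emb, txt_emb = embedder(
    paddle.rand([2, 3, 224, 224]), paddle.randint(0, 100, [2, 5, 16]))
print(img_emb.shape)  # [2, 49, 768]: class token merged into the patch tokens
print(txt_emb.shape)  # [2, 5, 512]: restored to [batch, num_queries, dim]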
ppdet/modeling/vl/embedder/clip/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import ModifiedResNet, TextEncoder, VisionTransformer
from .layers import LayerNorm, QuickGELU, AttentionPool2D
from .clip import CLIP
ppdet/modeling/vl/embedder/clip/clip.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.layers import MultiHeadAttention
from ppdet.modeling.initializer import zeros_, normal_
from ppdet.core.workspace import register

from .models import ModifiedResNet, VisionTransformer, TextEncoder


@register
class CLIP(nn.Layer):
    __inject__ = ['image_encoder', 'text_encoder']

    def __init__(self, image_encoder, text_encoder):
        super().__init__()
        self.visual = image_encoder
        self.text = text_encoder
        self.initialize_parameters()

    def initialize_parameters(self):
        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.weight.shape[0]**-0.5
                normal_(self.visual.attnpool.q_proj.weight, std=std)
                normal_(self.visual.attnpool.k_proj.weight, std=std)
                normal_(self.visual.attnpool.v_proj.weight, std=std)
                normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [
                    self.visual.layer1, self.visual.layer2,
                    self.visual.layer3, self.visual.layer4
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        zeros_(param)

        normal_(self.text.token_embedding.weight, std=0.02)
        normal_(self.text.positional_embedding, std=0.01)

        proj_std = (self.text.transformer.width**-0.5) * (
            (2 * self.text.transformer.layers)**-0.5)
        attn_std = self.text.transformer.width**-0.5
        fc_std = (2 * self.text.transformer.width)**-0.5
        for block in self.text.transformer.resblocks:
            normal_(block.attn.in_proj_weight, std=attn_std)
            normal_(block.attn.out_proj.weight, std=proj_std)
            normal_(block.mlp.c_fc.weight, std=fc_std)
            normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text.text_projection is not None:
            normal_(
                self.text.text_projection.weight,
                std=self.text.transformer.width**-0.5)

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.cast(self.dtype))

    def encode_text(self, text):
        return self.text(text.cast(self.dtype))

    def forward(self, image, text, normalize=True):
        # either input may be None, in which case the corresponding
        # embedding is returned as None (see ClipImageTextEmbedder)
        image_features = self.encode_image(
            image) if image is not None else None
        text_features = self.encode_text(text) if text is not None else None
        if normalize:
            if image_features is not None:
                image_features /= image_features.norm(axis=1, keepdim=True)
            if text_features is not None:
                text_features /= text_features.norm(axis=1, keepdim=True)
        return image_features, text_features
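
An illustrative shape check (sizes are examples, not from the commit): with output_dim set, the vision tower returns both the projected class embedding and the raw token sequence, and the text tower returns one embedding per tokenized query.

import paddle
from ppdet.modeling.vl.embedder.clip import CLIP, VisionTransformer, TextEncoder

image_encoder = VisionTransformer(
    input_resolution=224, patch_size=32, width=768, layers=12, heads=12,
    output_dim=512)
text_encoder = TextEncoder(
    context_length=16, vocab_size=49408, transformer_width=512,
    transformer_heads=8, transformer_layers=12, embed_dim=512)
model = CLIP(image_encoder, text_encoder)

pooled, tokens = model.encode_image(paddle.rand([2, 3, 224, 224]))
print(pooled.shape, tokens.shape)  # [2, 512] [2, 50, 768]
print(model.encode_text(paddle.randint(0, 49408, [2, 16])).shape)  # [2, 512]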
ppdet/modeling/vl/embedder/clip/layers.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.layers import MultiHeadAttention
from ppdet.modeling.initializer import zeros_, normal_


# ResNet
class Bottleneck(nn.Layer):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False)
        self.bn1 = nn.BatchNorm2D(planes)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(planes)
        self.relu2 = nn.ReLU()

        self.avgpool = nn.AvgPool2D(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2D(
            planes, planes * self.expansion, 1, bias_attr=False)
        self.bn3 = nn.BatchNorm2D(planes * self.expansion)
        self.relu3 = nn.ReLU()

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict([("-1", nn.AvgPool2D(stride)), ("0", nn.Conv2D(
                    inplanes,
                    planes * self.expansion,
                    1,
                    stride=1,
                    bias_attr=False)), ("1", nn.BatchNorm2D(
                        planes * self.expansion))]))

    def forward(self, x):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2D(nn.Layer):
    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim=None):
        super().__init__()
        # TODO: need check whether it is consistent with torch or not
        self.positional_embedding = self.create_parameter(
            shape=[spacial_dim**2 + 1, embed_dim],
            attr=ParamAttr(initializer=Normal(std=1. / embed_dim**0.5)))
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

    def forward(self, x):
        # [N, C, H, W] -> [N, C, HW] -> [N, HW, C]
        x = x.flatten(start_axis=2).transpose([0, 2, 1])
        # [N, 1, C] + [N, HW, C] = [N, HW+1, C]
        x = paddle.concat([x.mean(axis=1, keepdim=True), x], axis=1)
        # [N, HW+1, C]
        x = x + self.positional_embedding.unsqueeze(0)
        # compute q, k, v; the query is the pooled (mean) token only
        q = self.q_proj(x[:, :1, :])
        k = self.k_proj(x)
        v = self.v_proj(x)
        # [N, 1, C] -> [N, 1, num_heads, head_dim] -> [N, num_heads, 1, head_dim]
        q = q.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
            [0, 2, 1, 3])
        # [N, HW+1, C] -> [N, HW+1, num_heads, head_dim] -> [N, num_heads, HW+1, head_dim]
        k = k.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
            [0, 2, 1, 3])
        v = v.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
            [0, 2, 1, 3])
        # [N, num_heads, 1, HW+1]
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling
        weights = F.softmax(product)
        # [N, num_heads, 1, head_dim]
        out = paddle.matmul(weights, v)
        # [N, num_heads, 1, head_dim] -> [N, 1, num_heads, head_dim] -> [N, embed_dim]
        out = out.transpose([0, 2, 1, 3]).reshape([0, self.embed_dim])
        # project to the output dimension
        out = self.c_proj(out)
        return out


class LayerNorm(nn.LayerNorm):
    """Subclass of nn.LayerNorm that handles fp16 inputs."""

    def forward(self, x):
        orig_type = x.dtype
        ret = super().forward(x.cast(paddle.float32))
        return ret.cast(orig_type)


class QuickGELU(nn.Layer):
    def forward(self, x):
        return x * F.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Layer):
    def __init__(self, d_model, n_head, droplayer_p=0.0, attn_mask=None):
        super().__init__()

        self.attn = MultiHeadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
                         ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.droplayer_p = droplayer_p

    def get_drop_pattern(self, x):
        if self.training and self.droplayer_p:
            shape = (x.shape[0], ) + (1, ) * (len(x.shape) - 1)
            p = self.droplayer_p * paddle.ones(shape)
            return paddle.bernoulli(p)
        else:
            return 0.0

    def attention(self, x):
        self.attn_mask = self.attn_mask.cast(
            dtype=x.dtype) if self.attn_mask is not None else None
        return self.attn(x, x, x, attn_mask=self.attn_mask)

    def forward(self, x):
        y = self.attention(self.ln_1(x))
        drop_pattern = self.get_drop_pattern(y)
        x = x + y * (1.0 - drop_pattern)
        y = self.mlp(self.ln_2(x))
        drop_pattern = self.get_drop_pattern(y)
        x = x + y * (1.0 - drop_pattern)
        return x


class Transformer(nn.Layer):
    def __init__(self,
                 width,
                 layers,
                 heads,
                 stochastic_droplayer_rate=0.0,
                 attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.stochastic_droplayer_rate = stochastic_droplayer_rate
        blocks = []
        for i in range(self.layers):
            droplayer_p = (i / max(self.layers - 1, 1)
                           ) * self.stochastic_droplayer_rate
            blocks.append(
                ResidualAttentionBlock(width, heads, droplayer_p, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.resblocks(x)
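
The only non-obvious piece above is the stochastic depth schedule in Transformer: the per-block drop probability ramps linearly from 0 for the first block to stochastic_droplayer_rate for the last. A quick illustration:

layers, stochastic_droplayer_rate = 12, 0.2
for i in range(layers):
    droplayer_p = (i / max(layers - 1, 1)) * stochastic_droplayer_rate
    # block 0 -> 0.000, block 6 -> 0.109, block 11 -> 0.200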
ppdet/modeling/vl/embedder/clip/models.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.initializer import zeros_, normal_
from ppdet.core.workspace import register

from .layers import *

__all__ = ['ModifiedResNet', 'VisionTransformer', 'TextEncoder']


@register
class ModifiedResNet(nn.Layer):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1.
    - The final pooling layer is a QKV attention instead of an average pool.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224,
                 width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2D(
            3, width // 2, kernel_size=3, stride=2, padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(width // 2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2D(
            width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(width // 2)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2D(
            width // 2, width, kernel_size=3, padding=1, bias_attr=False)
        self.bn3 = nn.BatchNorm2D(width)
        self.relu3 = nn.ReLU()
        self.avgpool = nn.AvgPool2D(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2D(input_resolution // 32, embed_dim,
                                        heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.cast(self.conv1.weight.dtype)
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x


@register
class VisionTransformer(nn.Layer):
    def __init__(self,
                 input_resolution,
                 patch_size,
                 width,
                 layers,
                 heads,
                 output_dim=None,
                 stochastic_droplayer_rate=0.0):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2D(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias_attr=False)

        scale = width**-0.5
        self.class_embedding = self.create_parameter(
            shape=[width], attr=ParamAttr(initializer=Normal(std=scale)))
        self.positional_embedding = self.create_parameter(
            shape=[(input_resolution // patch_size)**2 + 1, width],
            attr=ParamAttr(initializer=Normal(std=scale)))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads,
                                       stochastic_droplayer_rate)
        self.ln_post = LayerNorm(width)
        if output_dim is not None:
            self.proj = nn.Linear(width, output_dim, bias_attr=False)

    def forward(self, x):
        x = self.conv1(x)
        x = x.reshape([x.shape[0], x.shape[1], -1])
        x = x.transpose([0, 2, 1])
        class_embedding = self.class_embedding.cast(x.dtype) + paddle.zeros(
            [x.shape[0], 1, x.shape[-1]], dtype=x.dtype)
        x = paddle.concat([class_embedding, x], axis=1)
        x = x + self.positional_embedding.cast(x.dtype)
        x = self.ln_pre(x)
        x = feature = self.transformer(x)

        if self.output_dim is not None:
            x = self.ln_post(x[:, 0, :])
            x = self.proj(x)
        else:
            x = self.ln_post(x)
        return x, feature


@register
class TextEncoder(nn.Layer):
    def __init__(self,
                 context_length,
                 vocab_size,
                 transformer_width,
                 transformer_heads,
                 transformer_layers,
                 stochastic_droplayer_rate=0.0,
                 embed_dim=None):
        super().__init__()
        self.context_length = context_length
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            stochastic_droplayer_rate=stochastic_droplayer_rate,
            attn_mask=self.build_attention_mask())

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = self.create_parameter(
            shape=[self.context_length, transformer_width],
            attr=ParamAttr(initializer=Constant(0.0)))
        self.ln_final = LayerNorm(transformer_width)
        # joint embedding dimension shared with the image tower
        embed_dim = embed_dim if embed_dim is not None else transformer_width
        self.text_projection = nn.Linear(
            transformer_width, embed_dim, bias_attr=False)
        self.logit_scale = self.create_parameter(
            shape=[],
            attr=ParamAttr(initializer=Constant(np.log(1. / 0.07))))

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf above the diagonal
        mask = paddle.full(
            (self.context_length, self.context_length), float("-inf"))
        mask = paddle.triu(mask, diagonal=1)
        return mask

    def forward(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.cast(x.dtype)
        x = self.transformer(x)
        x = self.ln_final(x).cast(x.dtype)

        # x.shape = [batch_size, text_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_idx = paddle.arange(x.shape[0])
        seq_idx = text.argmax(axis=-1)
        gather_idx = paddle.stack([batch_idx, seq_idx], axis=1)
        x = paddle.gather_nd(x, gather_idx)
        x = self.text_projection(x)
        return x
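
A quick check of the additive causal mask built above (illustrative sizes; the TextEncoder defaults assumed are those defined in this file):

import paddle
from ppdet.modeling.vl.embedder.clip import TextEncoder

enc = TextEncoder(context_length=4, vocab_size=100, transformer_width=32,
                  transformer_heads=4, transformer_layers=1)
print(enc.build_attention_mask())
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]  each token attends only to itself and the past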
ppdet/modeling/vl/head/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ppdet/modeling/vl/head/owl_vit_head.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.ops import get_act_fn

from ..utils import compute_box_bias

__all__ = ['PredictorMLP', 'ClassPredictor', 'OWLViTHead']


@register
class PredictorMLP(nn.Layer):
    """FFN block for predicting continuous outputs, e.g. bounding box coordinates.

    Attributes:
        in_dim: Size of the input features.
        out_dim: Size of the output of this mlp.
        num_layers: Number of layers.
        mlp_dim: Size of hidden dimension of dense layers.
        hidden_activation: Activation function of hidden layers.
        out_activation: Activation of the output.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 num_layers,
                 mlp_dim,
                 hidden_activation,
                 out_activation=None):
        super().__init__()
        layers = []
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, mlp_dim))
            in_dim = mlp_dim
        layers.append(nn.Linear(in_dim, out_dim))

        self.mlp = nn.LayerList(layers)
        self.num_layers = num_layers
        self.hidden_activation = get_act_fn(hidden_activation)
        self.out_activation = get_act_fn(out_activation)

    def forward(self, inputs):
        x = inputs
        for i in range(self.num_layers - 1):
            x = self.mlp[i](x)
            x = self.hidden_activation(x)
        x = self.mlp[-1](x)
        x = self.out_activation(x)
        return x


@register
class ClassPredictor(nn.Layer):
    """Open-vocabulary instance class predictor."""

    def __init__(self, in_dim, out_dim, normalize):
        super().__init__()
        self.normalize = normalize
        self.out_dim = out_dim
        self.proj = nn.Linear(in_dim, out_dim)
        self.logit_shift = nn.Linear(in_dim, 1)
        self.logit_scale = nn.Linear(in_dim, 1)

    def forward(self, x, query_embeddings=None, query_mask=None):
        """Computes class prediction logits.

        Query embeddings from a text encoder define the classification label space.

        Args:
            x: Image features [batch_size, num_patches, emb_dim].
            query_embeddings: The embeddings to classify against of shape
                [batch_size, num_queries, out_dim]. If not specified, only the
                image class embeddings will be returned.
            query_mask: Mask indicating whether query is real (1) or padding
                (0), of shape [batch_size, num_queries].
        Returns:
            The class embeddings and, if query embeddings were provided, the
            prediction logits.
        """
        image_class_emb = self.proj(x)
        if query_embeddings is None:
            return {"class_embeddings": image_class_emb}

        if self.normalize:
            image_class_emb /= image_class_emb.norm(
                axis=-1, keepdim=True) + 1e-6
            query_embeddings /= query_embeddings.norm(
                axis=-1, keepdim=True) + 1e-6

        pred_logits = paddle.matmul(
            x=image_class_emb, y=query_embeddings, transpose_y=True)

        logit_shift = self.logit_shift(x)
        logit_scale = F.elu(self.logit_scale(x)) + 1
        pred_logits = (logit_shift + pred_logits) * logit_scale

        if query_mask is not None:
            if len(query_mask.shape) > 1:
                query_mask = query_mask.unsqueeze(-2)
            pred_logits = paddle.where(query_mask == 0,
                                       paddle.full_like(pred_logits, -1e6),
                                       pred_logits)

        return pred_logits, image_class_emb


@register
class OWLViTHead(nn.Layer):
    __inject__ = ['class_head', 'bbox_head', 'loss']

    def __init__(self, class_head, bbox_head, loss, box_bias='both'):
        super().__init__()

        self.class_head = class_head
        self.bbox_head = bbox_head
        self.box_bias = box_bias
        self.loss = loss

    def box_predictor(self, image_features, feature_map):
        """Predicts bounding boxes from image features.

        Args:
            image_features: Feature tokens extracted from the image, returned
                by the `embedder` function.
            feature_map: A spatial re-arrangement of image_features, also
                returned by the `embedder` function.
        Returns:
            Predicted boxes (cxcywh, normalized to [0, 1]).
        """
        # Bounding box detection head [b, num_patches, 4].
        pred_boxes = self.bbox_head(image_features)
        # We compute the location of each token on the grid and use it to compute
        # a bias for the bbox prediction, i.e., each token is biased towards
        # predicting its location on the grid as the center.
        pred_boxes += compute_box_bias(feature_map, kind=self.box_bias)
        pred_boxes = F.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(self, image_features, query_embeddings=None,
                        query_mask=None):
        """Applies the class head to the image features.

        Args:
            image_features: Feature tokens extracted by the image embedder.
            query_embeddings: Optional list of text (or image) embeddings. If
                no embeddings are provided, no logits will be computed and
                only the class embeddings for the image will be returned.
            query_mask: Must be provided with query_embeddings. A mask
                indicating which query embeddings are valid.
        Returns:
            The class embeddings and the pred_logits if query_embeddings and
            query_mask are provided.
        """
        return self.class_head(image_features, query_embeddings, query_mask)

    def forward(self, feature_map, query_embeddings, targets=None):
        b, c, h, w = feature_map.shape
        # [b, c, h, w] -> [b, h * w, c]: one feature token per grid position
        image_features = feature_map.reshape([b, c, h * w]).transpose(
            [0, 2, 1])
        pred_boxes = self.box_predictor(image_features, feature_map)

        # padding queries embed to all-zero vectors
        query_mask = (query_embeddings[..., 0] != 0).cast(paddle.float32)
        pred_logits, image_class_emb = self.class_predictor(
            image_features, query_embeddings, query_mask)

        if self.training:
            return self.get_loss([pred_logits, pred_boxes], targets)
        else:
            return self.get_pred(pred_boxes, pred_logits)

    def get_loss(self, head_outs, gt_meta):
        return self.loss(head_outs, gt_meta)

    def get_pred(self, pred_boxes, pred_logits):
        return {'pred_boxes': pred_boxes, 'pred_logits': pred_logits}
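
A shape sketch for the class head (illustrative sizes): the logits are one score per (patch, query) pair, shifted and scaled by per-patch predictions.

import paddle
from ppdet.modeling.vl.head.owl_vit_head import ClassPredictor

clf = ClassPredictor(in_dim=768, out_dim=512, normalize=True)
x = paddle.rand([2, 49, 768])   # image patch features
q = paddle.rand([2, 10, 512])   # text query embeddings
mask = paddle.ones([2, 10])     # all queries valid
logits, emb = clf(x, q, mask)
print(logits.shape, emb.shape)  # [2, 49, 10] [2, 49, 512]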
ppdet/modeling/vl/loss/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ppdet/modeling/vl/loss/owl_vit_loss.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.losses.iou_loss import GIoULoss
from ppdet.modeling.transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss

__all__ = ['OWLViTLoss']


@register
class OWLViTLoss(nn.Layer):
    __shared__ = ['num_classes']
    __inject__ = ['matcher']

    def __init__(self,
                 num_classes,
                 matcher='HungarianMatcher',
                 normalization='per_example',
                 loss_coeff=None,
                 use_focal_loss=None,
                 alpha=None,
                 gamma=None):
        super().__init__()
        self.giou_loss = GIoULoss()
        self.num_classes = num_classes
        self.matcher = matcher
        # fall back to the matcher's settings when not given explicitly
        self.loss_coeff = matcher.matcher_coeff if loss_coeff is None else loss_coeff
        self.use_focal_loss = matcher.use_focal_loss if use_focal_loss is None else use_focal_loss
        self.alpha = matcher.alpha if alpha is None else alpha
        self.gamma = matcher.gamma if gamma is None else gamma
        assert normalization in [
            'per_example', 'global'
        ], f'{normalization} should be in [per_example, global]'
        self.normalization = normalization

    def _get_loss_class(self, logits, gt_class, match_indices):
        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
        target_label = paddle.full(
            logits.shape[:2], self.num_classes, dtype='int64')
        bs, num_query_objects = target_label.shape
        if sum(len(a) for a in gt_class) > 0:
            index, updates = self._get_index_updates(num_query_objects,
                                                     gt_class, match_indices)
            target_label = paddle.scatter(
                target_label.reshape([-1, 1]), index,
                updates.astype('int64'))
            target_label = target_label.reshape([bs, num_query_objects])
        if self.use_focal_loss:
            target_label = F.one_hot(target_label,
                                     self.num_classes + 1)[..., :-1]
            loss_cls = F.sigmoid_focal_loss(
                logits,
                target_label,
                alpha=self.alpha,
                gamma=self.gamma,
                reduction='none')
        else:
            loss_cls = F.cross_entropy(logits, target_label, reduction='none')
        return loss_cls.sum(axis=[1, 2])

    def _get_loss_bbox(self, boxes, gt_bbox, match_indices):
        # boxes and gt_bbox are cxcywh; L1 is computed on cxcywh,
        # GIoU on the xyxy conversion
        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        loss_bbox = F.l1_loss(src_bbox, target_bbox, reduction='none')
        loss_giou = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        return loss_bbox.sum(axis=1), loss_giou.sum(axis=1)

    def _get_index_updates(self, num_query_objects, target, match_indices):
        # flattens matched query indices across the batch and gathers their
        # matched targets, following the same helper in ppdet's DETRLoss
        batch_idx = paddle.concat([
            paddle.full_like(src, i)
            for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        src_idx += (batch_idx * num_query_objects)
        target_assign = paddle.concat([
            paddle.gather(
                t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign

    def _get_src_target_assign(self, src, target, match_indices):
        src_assign = paddle.concat([
            paddle.gather(
                t, I, axis=0) if len(I) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (I, _) in zip(src, match_indices)
        ])
        target_assign = paddle.concat([
            paddle.gather(
                t, J, axis=0) if len(J) > 0 else paddle.zeros([0, t.shape[-1]])
            for t, (_, J) in zip(target, match_indices)
        ])
        return src_assign, target_assign

    def forward(self, head_outs, gt_meta):
        logits, boxes = head_outs
        gt_class, gt_bbox = gt_meta['gt_class'], gt_meta['gt_bbox']

        match_indices = self.matcher(boxes.detach(),
                                     logits.detach(), gt_bbox, gt_class)
        loss_cls = self._get_loss_class(logits, gt_class, match_indices)
        loss_bbox, loss_giou = self._get_loss_bbox(boxes, gt_bbox,
                                                   match_indices)

        num_gts = paddle.to_tensor([len(a) for a in gt_class])
        if self.normalization == 'per_example':
            num_gts = paddle.clip(num_gts, min=1)
            loss_cls = (loss_cls / num_gts).mean()
            loss_bbox = (loss_bbox / num_gts).mean()
            loss_giou = (loss_giou / num_gts).mean()
        else:
            # 'global': normalize by the average number of GTs across workers
            paddle.distributed.all_reduce(num_gts)
            num_gts = paddle.clip(
                num_gts / paddle.distributed.get_world_size(), min=1)
            loss_cls = loss_cls.sum() / num_gts
            loss_bbox = loss_bbox.sum() / num_gts
            loss_giou = loss_giou.sum() / num_gts

        loss = self.loss_coeff['class'] * loss_cls + \
               self.loss_coeff['bbox'] * loss_bbox + \
               self.loss_coeff['giou'] * loss_giou

        return {
            'loss': loss,
            'loss_cls': loss_cls,
            'loss_bbox': loss_bbox,
            'loss_giou': loss_giou
        }
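
A worked example of the 'per_example' normalization (illustrative numbers): each image's loss is divided by its own GT count, clipped to at least 1 so images without boxes still contribute, then averaged over the batch.

import paddle

loss_cls = paddle.to_tensor([4.0, 2.0])  # per-image class losses
num_gts = paddle.to_tensor([2.0, 0.0])   # GT boxes per image
per_example = (loss_cls / paddle.clip(num_gts, min=1)).mean()
print(per_example)  # (4/2 + 2/1) / 2 = 2.0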
ppdet/modeling/vl/matcher/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppdet.modeling.transformers.matchers import HungarianMatcher
ppdet/modeling/vl/models/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .owl_vit import OWLViT
ppdet/modeling/vl/models/owl_vit.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.architectures import BaseArch

from ..utils import seq2img
from ..tokenizer import tokenize

__all__ = ['OWLViT']


@register
class OWLViT(BaseArch):
    __category__ = 'architecture'

    def __init__(self, embedder, head):
        super().__init__()
        self.backbone = embedder
        self.head = head

    def tokenize(self, text, max_token_len):
        return tokenize(text, max_token_len)

    def image_embedder(self, images):
        """Embeds images into feature maps.

        Args:
            images: images of shape (batch, input_size, input_size, 3), scaled
                to the input range defined in the config. Padding should be at
                the bottom right of the image.
        Returns:
            A 2D map of image features.
        """
        image_features, _ = self.backbone(images=images)
        return seq2img(images, image_features)

    def text_embedder(self, text_queries):
        """Embeds text into features.

        Args:
            text_queries: int32 tokenized text queries of shape
                [..., num_tokens].
        Returns:
            An array of the same shape as text_queries, except for the last
            dimension, which is num_dimensions instead of num_tokens.
        """
        _, text_features = self.backbone(texts=text_queries)
        return text_features

    def forward(self, inputs, text_queries):
        """Applies TextZeroShotDetectionModule on the input.

        Args:
            inputs: Images [batch_size, height, width, 3].
            text_queries: Queries to score boxes on. Queries starting with 0
                stand for padding [batch_size=b, num_queries=q, max_query_length=l].
        Returns:
            Outputs dict with items:
                pred_logits: Class logits [b, num_patches, num_queries].
                pred_boxes: Predicted bounding boxes [b, num_patches, 4].
                feature_map: Image embeddings 2d feature map [b, sp, sp, img_emb_dim].
        """
        # Embed images:
        feature_map = self.image_embedder(inputs)
        # Embed queries:
        query_embeddings = self.text_embedder(text_queries)
        outputs = self.head(feature_map, query_embeddings)
        return outputs
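
A hypothetical end-to-end call (the embedder and head are normally assembled from the config system, which this snippet does not show):

import paddle

# assuming `model` is an OWLViT instance built from a config
queries = ['a cat', 'a dog']
tokens = paddle.to_tensor(
    [[model.tokenize(q, max_token_len=16) for q in queries]])  # [1, 2, 16]
outputs = model(paddle.rand([1, 3, 768, 768]), tokens)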
ppdet/modeling/vl/tokenizer/__init__.py (new file, mode 100644)
from .simple_tokenizer import *
ppdet/modeling/vl/tokenizer/simple_tokenizer.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re

__all__ = ['SimpleTokenizer', 'tokenize']


@lru_cache()
def default_bpe():
    parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
    return os.path.join(parent_path, "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(
        range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path=default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # merge the lowest-ranked (most frequent) bigram first
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors="replace").replace('</w>', ' ')
        return text


@lru_cache(maxsize=1)
def build_tokenizer(bpe_path=default_bpe()):
    return SimpleTokenizer(bpe_path)


def tokenize(text, max_token_len):
    tokenizer = build_tokenizer()
    sot_token = tokenizer.encoder['<|startoftext|>']
    eot_token = tokenizer.encoder['<|endoftext|>']
    tokens = [sot_token] + tokenizer.encode(text) + [eot_token]
    output = [0] * max_token_len
    output[:min(max_token_len, len(tokens))] = tokens[:max_token_len]
    return output
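
Usage sketch (assumes the BPE vocab file bpe_simple_vocab_16e6.txt.gz is present at the path returned by default_bpe()):

from ppdet.modeling.vl.tokenizer import tokenize

ids = tokenize('a photo of a cat', max_token_len=16)
# [sot, BPE ids..., eot, 0, 0, ...]: start/end tokens plus zero padding
print(len(ids))  # 16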
ppdet/modeling/vl/utils/__init__.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import *
ppdet/modeling/vl/utils/utils.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import paddle
import paddle.nn.functional as F

# CLIP image statistics, reshaped for broadcasting over [N, C, H, W] inputs
IMAGE_MEAN = paddle.to_tensor(
    [0.48145466, 0.4578275, 0.40821073]).reshape([1, 3, 1, 1])
IMAGE_STD = paddle.to_tensor(
    [0.26862954, 0.26130258, 0.27577711]).reshape([1, 3, 1, 1])


def normalize_image(img):
    return (img - IMAGE_MEAN) / IMAGE_STD


def unnormalize_image(x):
    return x * IMAGE_STD + IMAGE_MEAN


def resize_posemb(posemb, target_size):
    """Resizes position embeddings to new resolution."""
    if target_size == posemb.shape[1]:
        return posemb

    gs_old = int(np.sqrt(posemb.shape[1]))
    gs_new = int(np.sqrt(target_size))

    posemb_tok = None
    if gs_old**2 == posemb.shape[1]:
        posemb_grid = posemb
    elif gs_old**2 == posemb.shape[1] - 1:
        posemb_tok, posemb_grid = posemb[:, :1], posemb[:, 1:]
    else:
        raise ValueError(
            'Posemb shape must be a perfect square (maybe with CLS token), but '
            f'got posemb of shape {posemb.shape}.')

    posemb_grid = posemb_grid.reshape([1, gs_old, gs_old, -1]).transpose(
        [0, 3, 1, 2])
    posemb_grid = F.interpolate(
        posemb_grid,
        size=[gs_new, gs_new],
        mode='bilinear',
        align_corners=False)
    posemb_grid = posemb_grid.transpose([0, 2, 3, 1]).reshape(
        [1, gs_new * gs_new, -1])

    if posemb_tok is not None:
        posemb_grid = paddle.concat([posemb_tok, posemb_grid], axis=1)
    return posemb_grid


def seq2img(original_img, features):
    """Reshapes 1D sequence to 2D image features."""
    if original_img.shape[2] == original_img.shape[3]:
        h = w = int(np.sqrt(features.shape[2]))
    else:
        stride = np.ceil(
            np.sqrt(original_img.shape[2] * original_img.shape[3] /
                    features.shape[2]))
        h = np.ceil(original_img.shape[2] / stride)
        w = np.ceil(original_img.shape[3] / stride)
    return features.reshape([features.shape[0], -1, int(h), int(w)])


def normalized_grid_corner_coordinates(feature_map, padding_mask):
    """Computes normalized xy corner coords from feature_map or padding_mask."""
    # Note 1: it computes not the centers of grid patches, but the patch corner
    # coordinates (for a grid patch from 0 to 0.1, it returns 0.1 not 0.05).
    # Note 2: behavior is quite different for feature_map and padding_mask inputs.
    if padding_mask is None:
        assert len(feature_map.shape) == 4  # [B, C, H, W]
        _, _, h, w = feature_map.shape
        shift_x = paddle.arange(1, w + 1)
        shift_y = paddle.arange(1, h + 1)
        shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
        # [H, W, 2]
        xy = paddle.cast(
            paddle.stack(
                [shift_x, shift_y], axis=-1), dtype='float32')
        xy = xy / paddle.to_tensor([w, h], dtype='float32')
    else:
        assert len(padding_mask.shape) == 3  # [B, H, W]
        padding_mask = padding_mask.cast(paddle.float32)
        y = paddle.cumsum(padding_mask, axis=1)
        x = paddle.cumsum(padding_mask, axis=2)
        # [B, H, W, 2]
        xy = paddle.stack(
            [x / (x[:, :, -1:] + 1e-6), y / (y[:, -1:] + 1e-6)], axis=-1)

    return xy.reshape(xy.shape[:-3] + [-1, 2])


def compute_box_bias(feature_map, padding_mask=None, kind='both'):
    """Computes spatial bias for grid."""
    # The box center is biased to its position on the feature grid:
    xy = normalized_grid_corner_coordinates(feature_map, padding_mask)
    xy = paddle.clip(xy, 0.0, 1.0)

    if kind in ['both', 'location']:
        # Unnormalize xy (i.e., apply logit function/sigmoid^-1).
        xy_bias = logit(xy)
    else:
        xy_bias = paddle.zeros_like(xy)

    if kind in ['both', 'size']:
        # The box size is biased to the patch size:
        wh_bias = logit(paddle.full_like(xy_bias, 1.0 / feature_map.shape[-1]))
    else:
        wh_bias = paddle.zeros_like(xy_bias)

    return paddle.concat([xy_bias, wh_bias], axis=-1)


def logit(x, eps=1e-4):
    """Logit (inverse sigmoid) function (https://en.wikipedia.org/wiki/Logit)."""
    return paddle.log(x + eps) - paddle.log1p(-x + eps)
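
Two quick sanity checks (illustrative sizes): logit inverts the sigmoid up to eps, and compute_box_bias yields one (cx, cy, w, h) bias per grid position.

import paddle
import paddle.nn.functional as F
from ppdet.modeling.vl.utils import logit, compute_box_bias

x = paddle.to_tensor([0.25, 0.5, 0.75])
print(F.sigmoid(logit(x)))  # ~[0.25, 0.5, 0.75]

feature_map = paddle.rand([1, 768, 7, 7])
print(compute_box_bias(feature_map).shape)  # [49, 4]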