Commit 27395ac8 authored by: W wangxinxin08

add owl-vit code

Parent 45ba40b4
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import OWLViT
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
# Note: normalize_image lives in the package-level utils module; the relative
# import below assumes the same layout as the other owl_vit submodules.
from ..utils import normalize_image
__all__ = ['ClipImageTextEmbedder']
@register
class ClipImageTextEmbedder(nn.Layer):
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
def __init__(self, base_model, embed_dim, merge_class_token='drop'):
super().__init__()
self.clip = base_model
self.merge_class_token = merge_class_token
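        # merge_class_token controls how the CLS token is combined with the
        # patch tokens: 'drop' discards it, while 'mul-ln' multiplies it into
        # every patch token and applies a LayerNorm.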
if self.merge_class_token == 'mul-ln':
self.merged_class_token = nn.LayerNorm(embed_dim)
def forward(self, images, texts):
if texts is not None:
texts_shape = texts.shape
if len(texts_shape) > 2:
                texts = texts.reshape([-1, texts_shape[-1]])
if images is not None:
images = normalize_image(images)
img_emb, txt_emb = self.clip(images, texts, normalize=False)
if img_emb is not None:
if self.merge_class_token == 'drop':
img_emb = img_emb[:, 1:, :]
elif self.merge_class_token == 'mul-ln':
img_emb = img_emb[:, :1, :] * img_emb[:, 1:, :]
img_emb = self.merged_class_token(img_emb)
else:
raise ValueError(
f'Unknown merge_class_token: {self.merge_class_token}')
if txt_emb is not None and len(texts_shape) > 2:
txt_emb = txt_emb.reshape(texts_shape[:-1] + [-1, ])
return img_emb, txt_emb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import ModifiedResNet, TextEncoder, VisionTransformer
from .layers import LayerNorm, QuickGELU, AttentionPool2D
from .clip import CLIP
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.layers import MultiHeadAttention
from ppdet.modeling.initializer import zeros_, normal_
from ppdet.core.workspace import register
from .models import ModifiedResNet, VisionTransformer, TextEncoder
@register
class CLIP(nn.Layer):
__inject__ = ['image_encoder', 'text_encoder']
def __init__(self, image_encoder, text_encoder):
super().__init__()
self.visual = image_encoder
self.text = text_encoder
self.initialize_parameters()
def initialize_parameters(self):
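        # Follow CLIP's initialization scheme: scaled normal init for the
        # attention-pooling and transformer projections, and zero-init for the
        # last BatchNorm weight in each residual block.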
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.weight.shape[0]**-0.5
normal_(self.visual.attnpool.q_proj.weight, std=std)
normal_(self.visual.attnpool.k_proj.weight, std=std)
normal_(self.visual.attnpool.v_proj.weight, std=std)
normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
zeros_(param)
normal_(self.text.token_embedding.weight, std=0.02)
normal_(self.text.positional_embedding, std=0.01)
proj_std = (self.text.transformer.width**-0.5) * (
(2 * self.text.transformer.layers)**-0.5)
attn_std = self.text.transformer.width**-0.5
fc_std = (2 * self.text.transformer.width)**-0.5
for block in self.text.transformer.resblocks:
normal_(block.attn.in_proj_weight, std=attn_std)
normal_(block.attn.out_proj.weight, std=proj_std)
normal_(block.mlp.c_fc.weight, std=fc_std)
normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text.text_projection is not None:
normal_(
self.text.text_projection.weight,
std=self.text.transformer.width**-0.5)
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.cast(self.dtype))
def encode_text(self, text):
return self.text(text.cast(self.dtype))
    def forward(self, image, text, normalize=True):
        image_features = self.encode_image(image) if image is not None else None
        text_features = self.encode_text(text) if text is not None else None
        if normalize:
            if image_features is not None:
                image_features /= image_features.norm(axis=1, keepdim=True)
            if text_features is not None:
                text_features /= text_features.norm(axis=1, keepdim=True)
        return image_features, text_features
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.layers import MultiHeadAttention
from ppdet.modeling.initializer import zeros_, normal_
# ResNet
class Bottleneck(nn.Layer):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.relu2 = nn.ReLU()
self.avgpool = nn.AvgPool2D(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2D(
planes, planes * self.expansion, 1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(planes * self.expansion)
self.relu3 = nn.ReLU()
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            # paddle.nn.Sequential takes (name, layer) tuples rather than an
            # OrderedDict.
            self.downsample = nn.Sequential(
                ("-1", nn.AvgPool2D(stride)),
                ("0", nn.Conv2D(
                    inplanes,
                    planes * self.expansion,
                    1,
                    stride=1,
                    bias_attr=False)),
                ("1", nn.BatchNorm2D(planes * self.expansion)))
def forward(self, x):
        identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2D(nn.Layer):
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim):
super().__init__()
# TODO: need check whether it is consistent with torch or not
self.positional_embedding = self.create_parameter(
shape=[spacial_dim**2 + 1, embed_dim],
attr=ParamAttr(initializer=Normal(std=1. / embed_dim**0.5)))
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
def forward(self, x):
# [N, C, H, W] -> [N, C, HW] -> [N, HW, C]
x = x.flatten(start_axis=2).transpose([0, 2, 1])
# [N, 1, C] + [N, HW, C] = [N, HW+1, C]
x = paddle.concat([x.mean(axis=1, keepdim=True), x], axis=1)
# [N, HW+1, C]
x = x + self.positional_embedding.unsqueeze(0)
# compute q, k, v
q = self.q_proj(x[:, :1, :])
k = self.k_proj(x)
v = self.v_proj(x)
# [N, 1, C] -> [N, 1, num_heads, head_dim] -> [N, num_heads, 1, head_dim]
q = q.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
[0, 2, 1, 3])
# [N, HW+1, C] -> [N, HW+1, num_heads, head_dim] -> [N, num_heads, HW+1, head_dim]
k = k.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
[0, 2, 1, 3])
v = v.reshape([0, 0, self.num_heads, self.head_dim]).transpose(
[0, 2, 1, 3])
# [N, num_heads, 1, HW+1]
product = paddle.matmul(x=q, y=k, transpose_y=True)
scaling = float(self.head_dim)**-0.5
product = product * scaling
weights = F.softmax(product)
# [N, num_heads, 1, head_dim]
out = paddle.matmul(weights, v)
# [N, num_heads, 1, head_dim] -> [N, 1, num_heads, head_dim] -> [N, embed_dim]
        out = out.transpose([0, 2, 1, 3]).reshape([0, self.embed_dim])
        # Project pooled features to the output dimension.
        out = self.c_proj(out)
        return out
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x):
orig_type = x.dtype
ret = super().forward(x.cast(paddle.float32))
return ret.cast(orig_type)
class QuickGELU(nn.Layer):
def forward(self, x):
return x * F.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Layer):
def __init__(self, d_model, n_head, droplayer_p=0.0, attn_mask=None):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.droplayer_p = droplayer_p
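        # Stochastic depth: during training each residual branch is dropped for
        # a whole sample with probability droplayer_p (see get_drop_pattern).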
def get_drop_pattern(self, x):
if self.training and self.droplayer_p:
shape = (x.shape[0], ) + (1, ) * (len(x.shape) - 1)
p = self.droplayer_p * paddle.ones(shape)
return paddle.bernoulli(p)
else:
return 0.0
def attention(self, x):
self.attn_mask = self.attn_mask.cast(
dtype=x.dtype) if self.attn_mask is not None else None
return self.attn(x, x, x, attn_mask=self.attn_mask)
def forward(self, x):
y = self.attention(self.ln_1(x))
drop_pattern = self.get_drop_pattern(y)
x = x + y * (1.0 - drop_pattern)
y = self.mlp(self.ln_2(x))
drop_pattern = self.get_drop_pattern(y)
x = x + y * (1.0 - drop_pattern)
return x
class Transformer(nn.Layer):
def __init__(self,
width,
layers,
heads,
stochastic_droplayer_rate=0.0,
attn_mask=None):
super().__init__()
self.width = width
        self.layers = layers
        self.stochastic_droplayer_rate = stochastic_droplayer_rate
        blocks = []
        for i in range(self.layers):
            # The drop probability increases linearly with depth.
            droplayer_p = (i / max(self.layers - 1,
                                   1)) * self.stochastic_droplayer_rate
blocks.append(
ResidualAttentionBlock(width, heads, droplayer_p, attn_mask))
self.resblocks = nn.Sequential(*blocks)
def forward(self, x):
return self.resblocks(x)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.initializer import zeros_, normal_
from ppdet.core.workspace import register
from .layers import *
__all__ = ['ModifiedResNet', 'VisionTransformer', 'TextEncoder']
@register
class ModifiedResNet(nn.Layer):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2D(
3, width // 2, kernel_size=3, stride=2, padding=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(width // 2)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(width // 2)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2D(
width // 2, width, kernel_size=3, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(width)
self.relu3 = nn.ReLU()
self.avgpool = nn.AvgPool2D(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2D(input_resolution // 32, embed_dim,
heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = x.cast(self.conv1.weight.dtype)
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
@register
class VisionTransformer(nn.Layer):
def __init__(self,
input_resolution,
patch_size,
width,
layers,
heads,
output_dim=None,
stochastic_droplayer_rate=0.0):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2D(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
            bias_attr=False)
scale = width**-0.5
self.class_embedding = self.create_parameter(
shape=[width], attr=ParamAttr(initializer=Normal(std=scale)))
self.positional_embedding = self.create_parameter(
shape=[(input_resolution // patch_size)**2 + 1, width],
attr=ParamAttr(initializer=Normal(std=scale)))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads,
stochastic_droplayer_rate)
self.ln_post = LayerNorm(width)
if output_dim is not None:
            self.proj = nn.Linear(width, output_dim, bias_attr=False)
def forward(self, x):
x = self.conv1(x)
x = x.reshape([x.shape[0], x.shape[1], -1])
x = x.transpose([0, 2, 1])
class_embedding = self.class_embedding.cast(x.dtype) + paddle.zeros(
            [x.shape[0], 1, x.shape[-1]], dtype=x.dtype)
x = paddle.concat([class_embedding, x], axis=1)
x = x + self.positional_embedding.cast(x.dtype)
x = self.ln_pre(x)
x = feature = self.transformer(x)
if self.output_dim is not None:
x = self.ln_post(x[:, 0, :])
x = self.proj(x)
else:
x = self.ln_post(x)
return x, feature
@register
class TextEncoder(nn.Layer):
    def __init__(self,
                 context_length,
                 vocab_size,
                 transformer_width,
                 transformer_heads,
                 transformer_layers,
                 stochastic_droplayer_rate,
                 embed_dim=512):
        super().__init__()
        # embed_dim is the size of the joint image/text embedding space
        # (512 for the CLIP ViT-B models; set it to match the vision tower).
        self.context_length = context_length
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            stochastic_droplayer_rate=stochastic_droplayer_rate,
            attn_mask=self.build_attention_mask())
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = self.create_parameter(
            shape=[context_length, transformer_width],
            attr=ParamAttr(initializer=Constant(0.0)))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Linear(
            transformer_width, embed_dim, bias_attr=False)
        self.logit_scale = self.create_parameter(
            shape=[], attr=ParamAttr(initializer=Constant(np.log(1. / 0.07))))
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = paddle.full((self.context_length, self.context_length),
float("-inf"))
        # Keep -inf strictly above the diagonal so each token can attend to
        # itself and to earlier tokens.
        mask = paddle.triu(mask, diagonal=1)
return mask
def forward(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.cast(x.dtype)
x = self.transformer(x)
x = self.ln_final(x).cast(x.dtype)
# x.shape = [batch_size, text_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_idx = paddle.arange(x.shape[0])
        seq_idx = text.argmax(axis=-1)
gather_idx = paddle.stack([batch_idx, seq_idx], axis=1)
x = paddle.gather_nd(x, gather_idx)
x = self.text_projection(x)
return x
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.modeling.ops import get_act_fn
from ppdet.core.workspace import register
from ..utils import compute_box_bias
__all__ = ['PredictorMLP', 'ClassPredictor', 'OWLViTHead']
@register
class PredictorMLP(nn.Layer):
"""FFN block for predicting continuous outputs, e.g. bounding box coordinates.
Attributes:
out_dim: Size of output of this mlp.
num_layers: Number of layers.
mlp_dim: Size of hidden dimension of dense layers.
hidden_activation: Activation function of hidden layers.
out_activation: Activation of the output.
dtype: Data type, e.g. jnp.float32.
"""
def __init__(self,
in_dim,
out_dim,
num_layers,
mlp_dim,
hidden_activation,
out_activation=None):
super().__init__()
layers = []
for _ in range(num_layers - 1):
layers.append(nn.Linear(in_dim, mlp_dim))
in_dim = mlp_dim
layers.append(nn.Linear(in_dim, out_dim))
self.mlp = nn.LayerList(layers)
self.num_layers = num_layers
self.hidden_activation = get_act_fn(hidden_activation)
self.out_activation = get_act_fn(out_activation)
def forward(self, inputs):
x = inputs
        for i in range(self.num_layers - 1):
            x = self.mlp[i](x)
x = self.hidden_activation(x)
x = self.mlp[-1](x)
x = self.out_activation(x)
return x
@register
class ClassPredictor(nn.Layer):
"""Open-vocabulary instance class predictor."""
def __init__(self, in_dim, out_dim, normalize):
super().__init__()
self.normalize = normalize
self.out_dim = out_dim
self.proj = nn.Linear(in_dim, out_dim)
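        # Per-patch shift and scale heads used to calibrate the image/query
        # similarity logits (applied in forward).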
self.logit_shift = nn.Linear(in_dim, 1)
self.logit_scale = nn.Linear(in_dim, 1)
def forward(self, x, query_embeddings=None, query_mask=None):
"""Computes class prediction logits.
Query embeddings from a text encoder define the classification label space.
Args:
x: Image features [batch_size, num_patches, emb_dim].
query_embeddings: The embeddings to classify against of shape [batch_size,
num_queries, out_dim]. If not specified, only the image class embeddings
will be returned.
query_mask: Mask indicating whether query is real (1) or padding (0), of
shape [batch_size, num_queries].
        Returns:
            A tuple (pred_logits, image_class_emb). If query_embeddings is None,
            a dict with the single key 'class_embeddings' is returned instead.
"""
image_class_emb = self.proj(x)
if query_embeddings is None:
return {"class_embeddings": image_class_emb}
if self.normalize:
image_class_emb /= image_class_emb.norm(
axis=-1, keepdims=True) + 1e-6
query_embeddings /= query_embeddings.norm(
axis=-1, keepdims=True) + 1e-6
pred_logits = paddle.matmul(
x=image_class_emb, y=query_embeddings, transpose_y=True)
logit_shift = self.logit_shift(x)
logit_scale = F.elu(self.logit_scale(x)) + 1
pred_logits = (logit_shift + pred_logits) * logit_scale
if query_mask is not None:
if len(query_mask.shape) > 1:
query_mask = query_mask.unsqueeze(-2)
            pred_logits = paddle.where(
                query_mask == 0,
                paddle.full_like(pred_logits, -1e6), pred_logits)
return pred_logits, image_class_emb
@register
class OWLViTHead(nn.Layer):
    __inject__ = ['class_head', 'bbox_head', 'loss']
def __init__(self, class_head, bbox_head, loss, box_bias='both'):
super().__init__()
self.class_head = class_head
self.bbox_head = bbox_head
        self.box_bias = box_bias
        self.loss = loss
def box_predictor(self, image_features, feature_map):
"""Predicts bounding boxes from image features.
Args:
image_features: Feature tokens extracted from the image, returned by the
`embedder` function.
feature_map: A spatial re-arrangement of image_features, also returned by
the `embedder` function.
        Returns:
            Predicted boxes (cxcywh, normalized to [0, 1]) of shape
            [batch_size, num_patches, 4].
"""
# Bounding box detection head [b, num_patches, 4].
        pred_boxes = self.bbox_head(image_features)
# We compute the location of each token on the grid and use it to compute
# a bias for the bbox prediction, i.e., each token is biased towards
# predicting its location on the grid as the center.
        pred_boxes += compute_box_bias(
            feature_map, padding_mask=None, kind=self.box_bias)
        pred_boxes = F.sigmoid(pred_boxes)
return pred_boxes
def class_predictor(self,
image_features,
query_embeddings=None,
query_mask=None):
"""Applies the class head to the image features.
Args:
image_features: Feature tokens extracted by the image embedder.
query_embeddings: Optional list of text (or image) embeddings. If no
embeddings are provided, no logits will be computed and only the class
embeddings for the image will be returned.
query_mask: Must be provided with query_embeddings. A mask indicating
which query embeddings are valid.
Returns:
A dictionary containing the class_embeddings and the pred_logits if
query_embeddings and query_mask are provided.
"""
return self.class_head(image_features, query_embeddings, query_mask)
    def forward(self, feature_map, query_embeddings, text_queries=None,
                targets=None):
        b, c, h, w = feature_map.shape
        # [b, c, h, w] -> [b, num_patches, c]
        image_features = paddle.reshape(feature_map,
                                        (b, c, h * w)).transpose([0, 2, 1])
        pred_boxes = self.box_predictor(image_features, feature_map)
        # Queries whose first token is 0 are padding.
        if text_queries is not None:
            query_mask = (text_queries[..., 0] > 0).cast(paddle.float32)
        else:
            query_mask = None
        pred_logits, image_class_emb = self.class_predictor(
            image_features, query_embeddings, query_mask)
if self.training:
return self.get_loss([pred_boxes, pred_logits], targets)
else:
return self.get_pred(pred_boxes, pred_logits)
def get_loss(self, head_outs, gt_meta):
return self.loss(head_outs, gt_meta)
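    def get_pred(self, pred_boxes, pred_logits):
        # Minimal sketch (assumption): return the raw head outputs and leave
        # score thresholding / top-k selection to downstream post-processing.
        return {'pred_boxes': pred_boxes, 'pred_logits': pred_logits}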
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.losses.iou_loss import GIoULoss
from ppdet.modeling.transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
__all__ = ['OWLViTLoss']
@register
class OWLViTLoss(nn.Layer):
__shared__ = ['num_classes']
    __inject__ = ['matcher']
def __init__(self,
num_classes,
matcher='HungarianMatcher',
normalization='per_example',
loss_coeff=None,
use_focal_loss=None,
alpha=None,
gamma=None):
super().__init__()
self.giou_loss = GIoULoss()
self.num_classes = num_classes
self.matcher = matcher
self.loss_coeff = matcher.matcher_coeff if loss_coeff is None else loss_coeff
self.use_focal_loss = matcher.use_focal_loss if use_focal_loss is None else use_focal_loss
self.alpha = matcher.alpha if alpha is None else alpha
self.gamma = matcher.gamma if gamma is None else gamma
assert normalization in [
'per_example', 'global'
        ], f'{normalization} should be in [per_example, global]'
self.normalization = normalization
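    def _get_index_updates(self, num_query_objects, target, match_indices):
        # Assumed helper mirroring ppdet's DETRLoss._get_index_updates: builds
        # flattened scatter indices for the matched queries and gathers the
        # corresponding ground-truth labels as updates.
        batch_idx = paddle.concat([
            paddle.full_like(src, i)
            for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        src_idx += (batch_idx * num_query_objects)
        target_assign = paddle.concat([
            paddle.gather(
                t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign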
def _get_loss_class(self, logits, gt_class, match_indices):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
target_label = paddle.full(
logits.shape[:2], self.num_classes, dtype='int64')
bs, num_query_objects = target_label.shape
if sum(len(a) for a in gt_class) > 0:
index, updates = self._get_index_updates(num_query_objects,
gt_class, match_indices)
target_label = paddle.scatter(
target_label.reshape([-1, 1]), index, updates.astype('int64'))
target_label = target_label.reshape([bs, num_query_objects])
if self.use_focal_loss:
target_label = F.one_hot(target_label,
self.num_classes + 1)[..., :-1]
if self.use_focal_loss:
loss_cls = F.sigmoid_focal_loss(
logits,
target_label,
alpha=self.alpha,
gamma=self.gamma,
reduction='none')
else:
loss_cls = F.cross_entropy(logits, target_label, reduction='none')
return loss_cls.sum(axis=[1, 2])
    def _get_loss_bbox(self, boxes, gt_bbox, match_indices):
        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        # L1 loss is computed on cxcywh boxes, GIoU loss on xyxy boxes.
        loss_bbox = F.l1_loss(src_bbox, target_bbox, reduction='none')
        loss_giou = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        return loss_bbox.sum(axis=1), loss_giou.sum(axis=1)
def _get_src_target_assign(self, src, target, match_indices):
src_assign = paddle.concat([
paddle.gather(
t, I, axis=0) if len(I) > 0 else paddle.zeros([0, t.shape[-1]])
for t, (I, _) in zip(src, match_indices)
])
target_assign = paddle.concat([
paddle.gather(
t, J, axis=0) if len(J) > 0 else paddle.zeros([0, t.shape[-1]])
for t, (_, J) in zip(target, match_indices)
])
return src_assign, target_assign
def forward(self, head_outs, gt_meta):
        # head_outs is [pred_boxes, pred_logits], see OWLViTHead.get_loss.
        boxes, logits = head_outs
gt_class, gt_bbox = gt_meta['gt_class'], gt_meta['gt_bbox']
match_indices = self.matcher(boxes.detach(),
logits.detach(), gt_bbox, gt_class)
loss_cls = self._get_loss_class(logits, gt_class, match_indices)
loss_bbox, loss_giou = self._get_loss_bbox(boxes, gt_bbox,
match_indices)
        num_gts = paddle.to_tensor(
            [len(a) for a in gt_class], dtype='float32')
if self.normalization == 'per_example':
num_gts = paddle.clip(num_gts, min=1)
loss_cls = (loss_cls / num_gts).mean()
loss_bbox = (loss_bbox / num_gts).mean()
loss_giou = (loss_giou / num_gts).mean()
# normalize_fn = lambda x : (x / num_gts).mean()
else:
            # all_reduce modifies the tensor in place and returns None.
            paddle.distributed.all_reduce(num_gts)
num_gts = paddle.clip(
num_gts / paddle.distributed.get_world_size(), min=1)
loss_cls = loss_cls.sum() / num_gts
loss_bbox = loss_bbox.sum() / num_gts
loss_giou = loss_giou.sum() / num_gts
# normalize_fn = lambda x: x.sum() / num_gts
# loss_cls, loss_box, loss_giou = [normalize_fn(l) for l in [loss_cls, loss_box, loss_giou]]
loss = self.loss_coeff['class'] * loss_cls + \
self.loss_coeff['bbox'] * loss_bbox + \
self.loss_coeff['giou'] * loss_giou
return {
'loss': loss,
'loss_cls': loss_cls,
'loss_bbox': loss_bbox,
'loss_giou': loss_giou
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppdet.modeling.transformers.matchers import HungarianMatcher
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .owl_vit import OWLViT
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.architectures import BaseArch
from ..utils import seq2img
from ..tokenizer import tokenize
@register
class OWLViT(BaseArch):
    __category__ = 'architecture'
    __inject__ = ['embedder', 'head']
def __init__(self, embedder, head):
super().__init__()
self.backbone = embedder
self.head = head
def tokenize(self, text, max_token_len):
return tokenize(text, max_token_len)
def image_embedder(self, images):
"""Embeds images into feature maps.
Args:
images: images of shape (batch, input_size, input_size, 3), scaled to the
input range defined in the config. Padding should be at the bottom right
of the image.
Returns:
A 2D map of image features.
"""
        image_features, _ = self.backbone(images=images, texts=None)
return seq2img(images, image_features)
def text_embedder(self, text_queries):
"""Embeds text into features.
Args:
text_queries: int32 tokenized text queries of shape [..., num_tokens].
Returns:
An array of the same shape as text_queries, except for the last dimension,
which is num_dimensions instead of num_tokens.
"""
        _, text_features = self.backbone(images=None, texts=text_queries)
return text_features
def forward(self, inputs, text_queries):
"""Applies TextZeroShotDetectionModule on the input.
Args:
inputs: Images [batch_size, height, width, 3].
text_queries: Queries to score boxes on. Queries starting with 0 stand for
padding [batch_size=b, num_queries=q, max_query_length=l].
Returns:
Outputs dict with items:
pred_logits: Class logits [b, num_patches, num_queries].
pred_boxes: Predicted bounding boxes [b, num_patches, 4].
feature_map: Image embeddings 2d feature map [b, sp, sp, img_emb_dim].
"""
# Embed images:
feature_map = self.image_embedder(inputs)
# Embed queries:
query_embeddings = self.text_embedder(text_queries)
        outputs = self.head(feature_map, query_embeddings, text_queries)
return outputs
from .simple_tokenizer import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
__all__ = ['SimpleTokenizer', 'tokenize']
@lru_cache()
def default_bpe():
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
return os.path.join(parent_path, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(
range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path=default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i +
1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors="replace").replace('</w>', ' ')
return text
def tokenize(text, max_token_len):
tokenizer = build_tokenizer()
sot_token = tokenizer.encoder['<|startoftext|>']
eot_token = tokenizer.encoder['<|endoftext|>']
tokens = [sot_token] + tokenizer.encode(text) + [eot_token]
output = [0] * max_token_len
output[:min(max_token_len, len(tokens))] = tokens[:max_token_len]
return output
@lru_cache(maxsize=1)
def build_tokenizer(bpe_path=default_bpe()):
    return SimpleTokenizer(bpe_path)
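if __name__ == '__main__':
    # Minimal usage sketch (assumes the default BPE vocab file
    # bpe_simple_vocab_16e6.txt.gz is present at the path returned by
    # default_bpe()): tokenize() pads/truncates the token ids to max_token_len.
    print(tokenize('a photo of a cat', max_token_len=16))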
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn.functional as F
# CLIP normalization constants, reshaped so they broadcast over NCHW images.
IMAGE_MEAN = paddle.to_tensor([0.48145466, 0.4578275, 0.40821073]).reshape(
    [1, 3, 1, 1])
IMAGE_STD = paddle.to_tensor([0.26862954, 0.26130258, 0.27577711]).reshape(
    [1, 3, 1, 1])
def normalize_image(img):
return (img - IMAGE_MEAN) / IMAGE_STD
def unnormalize_image(x):
return x * IMAGE_STD + IMAGE_MEAN
def resize_posemb(posemb, target_size):
"""Resizes position embeddings to new resolution."""
if target_size == posemb.shape[1]:
return posemb
gs_old = int(np.sqrt(posemb.shape[1]))
gs_new = int(np.sqrt(target_size))
posemb_tok = None
if gs_old**2 == posemb.shape[1]:
posemb_grid = posemb
elif gs_old**2 == posemb.shape[1] - 1:
posemb_tok, posemb_grid = posemb[:, :1], posemb[:, 1:]
else:
raise ValueError(
'Posemb shape must be a perfect square (maybe with CLS token), but '
f'got posemb of shape {posemb.shape}.')
    posemb_grid = posemb_grid.reshape([1, gs_old, gs_old, -1]).transpose(
        [0, 3, 1, 2])
    posemb_grid = F.interpolate(
        posemb_grid,
        size=[gs_new, gs_new],
        mode='bilinear',
        align_corners=False)
    posemb_grid = posemb_grid.transpose([0, 2, 3, 1]).reshape(
        [1, gs_new * gs_new, -1])
    if posemb_tok is not None:
        return paddle.concat([posemb_tok, posemb_grid], axis=1)
    return posemb_grid
def seq2img(original_img, features):
"""Reshapes 1D sequence to 2D image features."""
if original_img.shape[2] == original_img.shape[3]:
h = w = int(np.sqrt(features.shape[2]))
else:
stride = np.ceil(
np.sqrt(original_img.shape[2] * original_img.shape[3] /
features.shape[2]))
h = np.ceil(original_img.shape[2] / stride)
w = np.ceil(original_img.shape[3] / stride)
return features.reshape([features.shape[0], -1, int(h), int(w)])
def normalized_grid_corner_coordinates(feature_map, padding_mask):
"""Computes normalized xy corner coords from feature_map or padding_mask."""
# Note 1: it computes not the centers of grid patches, but the patch corner
# coordinates (for a grid patch from 0 to 0.1, it returns 0.1 not 0.05).
# Note 2: behavior is quite different for feature_map and padding_mask inputs.
if padding_mask is None:
assert len(feature_map.shape) == 4 # [B, C, H, W]
_, _, h, w = paddle.shape(feature_map)
shift_x = paddle.arange(1, w + 1)
shift_y = paddle.arange(1, h + 1)
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
# [H, W, 2]
xy = paddle.cast(
paddle.stack(
[shift_x, shift_y], axis=-1), dtype='float32')
xy = xy / paddle.concat([w, h])
else:
assert len(padding_mask.shape) == 3 # [B, H, W]
padding_mask = padding_mask.cast(paddle.float32)
y = paddle.cumsum(padding_mask, axis=1)
x = paddle.cumsum(padding_mask, axis=2)
# [B, H, W, 2]
xy = paddle.stack(
[x / (x[:, :, -1:] + 1e-6), y / (y[:, -1:] + 1e-6)], axis=-1)
return xy.reshape(xy.shape[:-3] + [-1, 2])
def compute_box_bias(feature_map, padding_mask, kind='both'):
"""Computes spatial bias for grid."""
# The box center is biased to its position on the feature grid:
xy = normalized_grid_corner_coordinates(feature_map, padding_mask)
xy = paddle.clip(xy, 0.0, 1.0)
if kind in ['both', 'location']:
# Unnormalize xy (i.e., apply logit function/sigmoid^-1).
xy_bias = logit(xy)
else:
xy_bias = paddle.zeros_like(xy)
if kind in ['both', 'size']:
# The box size is biased to the patch size:
wh_bias = logit(paddle.full_like(xy_bias, 1.0 / feature_map.shape[-1]))
else:
wh_bias = paddle.zeros_like(xy_bias)
return paddle.concat([xy_bias, wh_bias], axis=-1)
def logit(x, eps=1e-4):
"""Logit (inverse sigmoid) function (https://en.wikipedia.org/wiki/Logit)."""
return paddle.log(x + eps) - paddle.log1p(-x + eps)