diff --git a/ppdet/modeling/vl/embedder/__init__.py b/ppdet/modeling/vl/embedder/__init__.py
index 9e28baadf500a8490731a153fec09f393e21d02e..864638732f68854dece0b79f909ffcdb78146ae1 100644
--- a/ppdet/modeling/vl/embedder/__init__.py
+++ b/ppdet/modeling/vl/embedder/__init__.py
@@ -21,6 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 
 from ppdet.core.workspace import register
+from .clip import *
 
 __all__ = ['ClipImageTextEmbedder']
 
diff --git a/ppdet/modeling/vl/embedder/clip/__init__.py b/ppdet/modeling/vl/embedder/clip/__init__.py
index 4cf4e7bf758a12d59e294a821dfba4c95e77c61a..185d771fe508cc0a994eb75c9bd5198e52c132fd 100644
--- a/ppdet/modeling/vl/embedder/clip/__init__.py
+++ b/ppdet/modeling/vl/embedder/clip/__init__.py
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .models import ModifiedResNet, TextEncoder, VisionTransformer
+from .models import ModifiedResNet, TextEncoder, ViT
 from .layers import LayerNorm, QuickGELU, AttentionPool2D
 from .clip import CLIP
diff --git a/ppdet/modeling/vl/embedder/clip/clip.py b/ppdet/modeling/vl/embedder/clip/clip.py
index 8d6d01808556a0fead4936711b87884fb2d29405..6e15f142e32ca656802e01b55b2fa8d2d09970b1 100644
--- a/ppdet/modeling/vl/embedder/clip/clip.py
+++ b/ppdet/modeling/vl/embedder/clip/clip.py
@@ -31,7 +31,7 @@ from ppdet.modeling.layers import MultiHeadAttention
 from ppdet.modeling.initializer import zeros_, normal_
 from ppdet.core.workspace import register
 
-from .models import ModifiedResNet, VisionTransformer, TextEncoder
+from .models import ModifiedResNet, ViT, TextEncoder
 
 
 @register
diff --git a/ppdet/modeling/vl/embedder/clip/layers.py b/ppdet/modeling/vl/embedder/clip/layers.py
index fca8c8815ff81e7ef738dab255bd5614236e9e65..eee7eb50dd01d6a8137d7c45b87328e1b9e627ee 100644
--- a/ppdet/modeling/vl/embedder/clip/layers.py
+++ b/ppdet/modeling/vl/embedder/clip/layers.py
@@ -84,7 +84,7 @@ class Bottleneck(nn.Layer):
         return out
 
 
-class AttentionPool2D(nn.Module):
+class AttentionPool2D(nn.Layer):
     def __init__(self, spacial_dim, embed_dim, num_heads, output_dim):
         super().__init__()
         # TODO: need check whether it is consistent with torch or not
@@ -151,10 +151,9 @@ class ResidualAttentionBlock(nn.Layer):
         self.attn = MultiHeadAttention(d_model, n_head)
         self.ln_1 = LayerNorm(d_model)
-        self.mlp = nn.Sequential(
-            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), (
-                "gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)
-            )]))
+        self.mlp = nn.Sequential(("c_fc", nn.Linear(d_model, d_model * 4)),
+                                 ("gelu", QuickGELU()),
+                                 ("c_proj", nn.Linear(d_model * 4, d_model)))
         self.ln_2 = LayerNorm(d_model)
         self.attn_mask = attn_mask
         self.droplayer_p = droplayer_p
@@ -192,6 +191,7 @@ class Transformer(nn.Layer):
         super().__init__()
         self.width = width
         self.layers = layers
+        self.stochastic_droplayer_rate = stochastic_droplayer_rate
         blocks = []
         for i in range(self.layers):
             droplayer_p = (i / max(self.layers - 1,
diff --git a/ppdet/modeling/vl/embedder/clip/models.py b/ppdet/modeling/vl/embedder/clip/models.py
index 49ee8d007b8641a132f048e8257745705e4bf931..d4af77feeafbdf552e8ecadd956b816d3353c274 100644
--- a/ppdet/modeling/vl/embedder/clip/models.py
+++ b/ppdet/modeling/vl/embedder/clip/models.py
@@ -32,7 +32,7 @@ from ppdet.core.workspace import register
 
 from .layers import *
 
-__all__ = ['ModifiedResNet', 'VisionTransformer', 'TextEncoder']
+__all__ = ['ModifiedResNet', 'ViT', 'TextEncoder']
 
 
 @register
@@ -105,7 +105,7 @@ class ModifiedResNet(nn.Layer):
 
 
 @register
-class VisionTransformer(nn.Layer):
+class ViT(nn.Layer):
     def __init__(self,
                  input_resolution,
                  patch_size,
@@ -115,6 +115,7 @@ class VisionTransformer(nn.Layer):
                  output_dim=None,
                  stochastic_droplayer_rate=0.0):
         super().__init__()
+        self.width = width
         self.input_resolution = input_resolution
         self.output_dim = output_dim
         self.conv1 = nn.Conv2D(
@@ -122,7 +123,7 @@ class VisionTransformer(nn.Layer):
             out_channels=width,
             kernel_size=patch_size,
             stride=patch_size,
-            bias=False)
+            bias_attr=False)
         scale = width**-0.5
         self.class_embedding = self.create_parameter(
             shape=[width], attr=ParamAttr(initializer=Normal(std=scale)))
@@ -157,9 +158,14 @@ class VisionTransformer(nn.Layer):
 
 @register
 class TextEncoder(nn.Layer):
-    def __init__(self, context_length, vocab_size, transformer_width,
-                 transformer_heads, transformer_layers,
-                 stochastic_droplayer_rate):
+    def __init__(self,
+                 embed_dim,
+                 context_length,
+                 vocab_size,
+                 transformer_width,
+                 transformer_heads,
+                 transformer_layers,
+                 stochastic_droplayer_rate=0.0):
         super().__init__()
 
         self.context_length = context_length
@@ -178,8 +184,6 @@ class TextEncoder(nn.Layer):
         self.ln_final = LayerNorm(transformer_width)
         self.text_projection = nn.Linear(
             transformer_width, embed_dim, bias_attr=False)
-        self.logit_scale = self.create_parameter(
-            shape=[], attr=ParamAttr(initializer=Constant(np.log(1. / 0.07))))
 
     def build_attention_mask(self):
         # lazily create causal attention mask, with full attention between the vision tokens
diff --git a/ppdet/modeling/vl/head/__init__.py b/ppdet/modeling/vl/head/__init__.py
index 97043fd7ba6885aac81cad5a49924c23c67d4d47..2de8a9ab61147b42358120ac71142e77a7584187 100644
--- a/ppdet/modeling/vl/head/__init__.py
+++ b/ppdet/modeling/vl/head/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .owl_vit_head import *
\ No newline at end of file
diff --git a/ppdet/modeling/vl/head/owl_vit_head.py b/ppdet/modeling/vl/head/owl_vit_head.py
index 5607443298080d4199984a46692f2146e8605c8b..fafa20f7d427f49c089e1ca7cd2568a8a9a935e9 100644
--- a/ppdet/modeling/vl/head/owl_vit_head.py
+++ b/ppdet/modeling/vl/head/owl_vit_head.py
@@ -22,6 +22,7 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.modeling.ops import get_act_fn
+from ppdet.core.workspace import register
 
 from ..utils import compute_box_bias
 
@@ -46,12 +47,13 @@ class PredictorMLP(nn.Layer):
                  in_dim,
                  out_dim,
                  num_layers,
-                 mlp_dim,
-                 hidden_activation,
+                 mlp_dim=None,
+                 hidden_activation='gelu',
                  out_activation=None):
         super().__init__()
 
         layers = []
+        mlp_dim = in_dim if mlp_dim is None else mlp_dim
         for _ in range(num_layers - 1):
             layers.append(nn.Linear(in_dim, mlp_dim))
             in_dim = mlp_dim
@@ -138,7 +140,6 @@ class OWLViTHead(nn.Layer):
         self.class_head = class_head
         self.bbox_head = bbox_head
         self.box_bias = box_bias
-        self.matcher = matcher
         self.loss = loss
 
     def box_predictor(self, image_features, feature_map):
diff --git a/ppdet/modeling/vl/loss/__init__.py b/ppdet/modeling/vl/loss/__init__.py
index 97043fd7ba6885aac81cad5a49924c23c67d4d47..1b6789619223407179292691b9c4bcae1707c100 100644
--- a/ppdet/modeling/vl/loss/__init__.py
+++ b/ppdet/modeling/vl/loss/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .owl_vit_loss import *
\ No newline at end of file
diff --git a/ppdet/modeling/vl/loss/owl_vit_loss.py b/ppdet/modeling/vl/loss/owl_vit_loss.py
index b5fdfd92fae2fb5b6d63c6063dfcfbc5980fc145..0d11c44f2843311e170cc5f87626ddacf49ab01b 100644
--- a/ppdet/modeling/vl/loss/owl_vit_loss.py
+++ b/ppdet/modeling/vl/loss/owl_vit_loss.py
@@ -32,7 +32,7 @@ class OWLViTLoss(nn.Layer):
     __inject__ = ['HungarianMatcher']
 
     def __init__(self,
-                 num_classes,
+                 num_classes=80,
                  matcher='HungarianMatcher',
                  normalization='per_example',
                  loss_coeff=None,
diff --git a/ppdet/modeling/vl/tokenizer/simple_tokenizer.py b/ppdet/modeling/vl/tokenizer/simple_tokenizer.py
index 723da452d10a61b129ae2b391cf8a49fc5b416a6..f82b7b31527aa8d2e8370a3cda6fb7de220ea6df 100644
--- a/ppdet/modeling/vl/tokenizer/simple_tokenizer.py
+++ b/ppdet/modeling/vl/tokenizer/simple_tokenizer.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import gzip
 import html
 import os
+import functools
 from functools import lru_cache
 
 import ftfy
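
Note: the embedder hunks rename VisionTransformer to ViT, add a leading
embed_dim argument to TextEncoder, and make stochastic_droplayer_rate an
optional argument. A minimal construction sketch against the new signatures
follows; it assumes the arguments between patch_size and output_dim
(width, layers, heads) keep their order from the original CLIP port, and
all dimension values are illustrative (ViT-B/32-like), not taken from any
shipped config.

    from ppdet.modeling.vl.embedder.clip import ViT, TextEncoder

    # Image tower: renamed class, constructor otherwise unchanged.
    image_encoder = ViT(
        input_resolution=224,        # assumed input size
        patch_size=32,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
        stochastic_droplayer_rate=0.1)

    # Text tower: embed_dim is the new leading argument and should match
    # the image tower's output_dim so both embed into the same space;
    # stochastic_droplayer_rate now defaults to 0.0.
    text_encoder = TextEncoder(
        embed_dim=512,
        context_length=77,           # CLIP's standard context length
        vocab_size=49408,            # CLIP BPE vocabulary size
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12)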
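
The owl_vit_head.py hunks relax PredictorMLP so that mlp_dim and
hidden_activation become optional: a None mlp_dim falls back to in_dim, and
the hidden activation defaults to 'gelu'. A hedged usage sketch, with the
768/4 sizes as assumptions rather than values from this repository:

    from ppdet.modeling.vl.head.owl_vit_head import PredictorMLP

    # After this patch: hidden width defaults to in_dim, activation to 'gelu'.
    bbox_head = PredictorMLP(in_dim=768, out_dim=4, num_layers=3)

    # Equivalent explicit call, matching the pre-patch required arguments.
    bbox_head_explicit = PredictorMLP(
        in_dim=768,
        out_dim=4,
        num_layers=3,
        mlp_dim=768,
        hidden_activation='gelu')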