diff --git a/ppdet/modeling/vl/embedder/clip/clip.py b/ppdet/modeling/vl/embedder/clip/clip.py
index 64fadc2d26303034bab17f5e2c0fb8e971f86a7f..8d6d01808556a0fead4936711b87884fb2d29405 100644
--- a/ppdet/modeling/vl/embedder/clip/clip.py
+++ b/ppdet/modeling/vl/embedder/clip/clip.py
@@ -82,17 +82,27 @@ class CLIP(nn.Layer):
     def dtype(self):
         return self.visual.conv1.weight.dtype
 
-    def encode_image(self, image):
-        return self.visual(image.cast(self.dtype))
+    def encode_image(self, image, normalize=False):
+        """Encode `image`; optionally divide features by their axis-1 norm."""
+        image_features = self.visual(image.cast(self.dtype))
+        if normalize:
+            image_features /= image_features.norm(axis=1, keepdim=True)
+        return image_features
 
-    def encode_text(self, text):
-        return self.text(text.cast(self.dtype))
+    def encode_text(self, text, normalize=False):
+        """Encode `text`; optionally divide features by their axis-1 norm."""
+        text_features = self.text(text.cast(self.dtype))
+        if normalize:
+            text_features /= text_features.norm(axis=1, keepdim=True)
+        return text_features
 
     def forward(self, image, text, normalize=True):
-        image_features = self.encode_image(image)
-        text_features = self.encode_text(text)
-        if normalize:
-            image_features /= image_features.norm(axis=1, keepdim=True)
-            text_features /= image_features.norm(axis=1, keepdim=True)
+        """Encode each non-None input; a None input yields None features."""
+        image_features = text_features = None
+        if image is not None:
+            image_features = self.encode_image(image, normalize)
+
+        if text is not None:
+            text_features = self.encode_text(text, normalize)
 
-        return image_fetaures, text_features
+        return image_features, text_features
diff --git a/ppdet/modeling/vl/models/owl_vit.py b/ppdet/modeling/vl/models/owl_vit.py
index 339394eb5c44e638d81bdcbb2e08debf8035789b..d388fefaf9a37c7e0d781637db1abd4ff0207081 100644
--- a/ppdet/modeling/vl/models/owl_vit.py
+++ b/ppdet/modeling/vl/models/owl_vit.py
@@ -27,7 +27,7 @@ from ..tokenizer import tokenize
 
 
 @register
-class OWLViT(BaseArch):
+class OWLViT(nn.Layer):
     __category__ = 'architecture'
 
     def __init__(self, embedder, head):