From e142c6a3ab950ac5f8749ad204e6775d92f25728 Mon Sep 17 00:00:00 2001 From: wangxinxin08 Date: Fri, 9 Dec 2022 03:16:56 +0000 Subject: [PATCH] fix some problem --- ppdet/modeling/vl/embedder/clip/clip.py | 25 ++++++++++++++++--------- ppdet/modeling/vl/models/owl_vit.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ppdet/modeling/vl/embedder/clip/clip.py b/ppdet/modeling/vl/embedder/clip/clip.py index 64fadc2d2..8d6d01808 100644 --- a/ppdet/modeling/vl/embedder/clip/clip.py +++ b/ppdet/modeling/vl/embedder/clip/clip.py @@ -82,17 +82,24 @@ class CLIP(nn.Layer): def dtype(self): return self.visual.conv1.weight.dtype - def encode_image(self, image): - return self.visual(image.cast(self.dtype)) + def encode_image(self, image, normalize): + image_features = self.visual(image.cast(self.dtype)) + if normalize: + image_features /= image_features.norm(axis=1, keepdim=True) + return image_features - def encode_text(self, text): - return self.text(text.cast(self.dtype)) + def encode_text(self, text, normalize): + text_features = self.text(text.cast(self.dtype)) + if normalize: + text_features /= text_features.norm(axis=1, keepdim=True) + return text_features def forward(self, image, text, normalize=True): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - if normalize: - image_features /= image_features.norm(axis=1, keepdim=True) - text_features /= image_features.norm(axis=1, keepdim=True) + image_features = text_features = None + if image is not None: + image_features = self.encode_image(image, normalize) + + if text is not None: + text_features = self.encode_text(text, normalize) return image_fetaures, text_features diff --git a/ppdet/modeling/vl/models/owl_vit.py b/ppdet/modeling/vl/models/owl_vit.py index 339394eb5..d388fefaf 100644 --- a/ppdet/modeling/vl/models/owl_vit.py +++ b/ppdet/modeling/vl/models/owl_vit.py @@ -27,7 +27,7 @@ from ..tokenizer import tokenize @register -class OWLViT(BaseArch): +class OWLViT(nn.Layer): __category__ = 'architecture' def __init__(self, embedder, head): -- GitLab