提交 e142c6a3 编写于 作者: W wangxinxin08

fix some problem

上级 27395ac8
@@ -82,17 +82,24 @@ class CLIP(nn.Layer):
     def dtype(self):
         return self.visual.conv1.weight.dtype
-    def encode_image(self, image):
-        return self.visual(image.cast(self.dtype))
+    def encode_image(self, image, normalize):
+        image_features = self.visual(image.cast(self.dtype))
+        if normalize:
+            image_features /= image_features.norm(axis=1, keepdim=True)
+        return image_features
-    def encode_text(self, text):
-        return self.text(text.cast(self.dtype))
+    def encode_text(self, text, normalize):
+        text_features = self.text(text.cast(self.dtype))
+        if normalize:
+            text_features /= text_features.norm(axis=1, keepdim=True)
+        return text_features
     def forward(self, image, text, normalize=True):
-        image_features = self.encode_image(image)
-        text_features = self.encode_text(text)
-        if normalize:
-            image_features /= image_features.norm(axis=1, keepdim=True)
-            text_features /= image_features.norm(axis=1, keepdim=True)
+        image_features = text_features = None
+        if image is not None:
+            image_features = self.encode_image(image, normalize)
+
+        if text is not None:
+            text_features = self.encode_text(text, normalize)
         return image_fetaures, text_features
@@ -27,7 +27,7 @@ from ..tokenizer import tokenize
 @register
-class OWLViT(BaseArch):
+class OWLViT(nn.Layer):
     __category__ = 'architecture'
     def __init__(self, embedder, head):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册