Commit e142c6a3 authored by wangxinxin08

fix some problems

Parent: 27395ac8
@@ -82,17 +82,24 @@ class CLIP(nn.Layer):
     def dtype(self):
         return self.visual.conv1.weight.dtype
 
-    def encode_image(self, image):
-        return self.visual(image.cast(self.dtype))
+    def encode_image(self, image, normalize):
+        image_features = self.visual(image.cast(self.dtype))
+        if normalize:
+            image_features /= image_features.norm(axis=1, keepdim=True)
+        return image_features
 
-    def encode_text(self, text):
-        return self.text(text.cast(self.dtype))
+    def encode_text(self, text, normalize):
+        text_features = self.text(text.cast(self.dtype))
+        if normalize:
+            text_features /= text_features.norm(axis=1, keepdim=True)
+        return text_features
 
     def forward(self, image, text, normalize=True):
-        image_features = self.encode_image(image)
-        text_features = self.encode_text(text)
-        if normalize:
-            image_features /= image_features.norm(axis=1, keepdim=True)
-            text_features /= image_features.norm(axis=1, keepdim=True)
+        image_features = text_features = None
+        if image is not None:
+            image_features = self.encode_image(image, normalize)
+        if text is not None:
+            text_features = self.encode_text(text, normalize)
         return image_features, text_features
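With this change, CLIP.forward takes image and text independently (either may be None) and passes the normalize flag down to encode_image / encode_text, which L2-normalize the features along axis 1. Below is a minimal usage sketch of the new interface; it is not part of the commit, and the model instance, input shapes, and the vocabulary/context sizes are illustrative assumptions.

    import paddle

    # `model` is assumed to be an already-built CLIP instance (hypothetical setup).
    image = paddle.randn([4, 3, 224, 224])          # batch of images
    text = paddle.randint(0, 49408, shape=[4, 77])  # batch of token ids (assumed vocab/context size)

    # Either input may be None; the corresponding output is then None.
    image_feats, text_feats = model(image, text, normalize=True)

    # With normalize=True each row is L2-normalized along axis 1, so cosine
    # similarity reduces to a plain matrix product (logit scaling omitted here).
    logits = paddle.matmul(image_feats, text_feats, transpose_y=True)

    # Text-only call, e.g. to pre-compute class text embeddings once:
    _, class_embeds = model(image=None, text=text, normalize=True)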
@@ -27,7 +27,7 @@ from ..tokenizer import tokenize
 
 @register
-class OWLViT(BaseArch):
+class OWLViT(nn.Layer):
     __category__ = 'architecture'
 
     def __init__(self, embedder, head):
......