__init__.py 2.2 KB
Newer Older
W
wangxinxin08 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register

__all__ = ['ClipImageTextEmbedder']


@register
class ClipImageTextEmbedder(nn.Layer):
    # This code is based on: https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit
    def __init__(self, base_model, embed_dim, merge_class_token='drop'):
        super().__init__()
        self.clip = base_model
        self.merge_class_token = merge_class_token
        if self.merge_class_token == 'mul-ln':
            self.merged_class_token = nn.LayerNorm(embed_dim)

    def forward(self, images, texts):
        if texts is not None:
            texts_shape = texts.shape
            if len(texts_shape) > 2:
                texts = texts.reshape(-1, texts_shape[-1])

        if images is not None:
            images = normalize_image(images)

        img_emb, txt_emb = self.clip(images, texts, normalize=False)

        if img_emb is not None:
            if self.merge_class_token == 'drop':
                img_emb = img_emb[:, 1:, :]
            elif self.merge_class_token == 'mul-ln':
                img_emb = img_emb[:, :1, :] * img_emb[:, 1:, :]
                img_emb = self.merged_class_token(img_emb)
            else:
                raise ValueError(
                    f'Unknown merge_class_token: {self.merge_class_token}')

        if txt_emb is not None and len(texts_shape) > 2:
            txt_emb = txt_emb.reshape(texts_shape[:-1] + [-1, ])
        return img_emb, txt_emb