gallery2fc.py 2.2 KB
Newer Older
W
weishengyu 已提交
1
import os
W
weishengyu 已提交
2
import paddle
W
weishengyu 已提交
3

W
weishengyu 已提交
4
from ppcls.arch import build_model
W
weishengyu 已提交
5
from ppcls.arch.gears.identity_head import IdentityHead
W
weishengyu 已提交
6
from ppcls.utils.config import parse_config, parse_args
W
dbg  
weishengyu 已提交
7
from ppcls.utils.save_load import load_dygraph_pretrain
W
weishengyu 已提交
8
from ppcls.utils.logger import init_logger
W
weishengyu 已提交
9
from ppcls.data import transform, create_operators
W
dbg  
weishengyu 已提交
10 11


W
weishengyu 已提交
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
def build_gallery_feature(configs, feature_extractor):
    """Collect the gallery image list intended to back the gallery FC layer.

    Reads the gallery list file named by ``configs["IndexProcess"]["data_file"]``.
    Each non-empty line is stripped and split on
    ``configs["IndexProcess"]["delimiter"]``; it must yield at least two fields,
    and the first field is joined onto ``configs["IndexProcess"]["image_root"]``
    to form the image path.

    NOTE(review): this function is unfinished. The preprocessing ops, the
    zero-filled input batch buffer and ``feature_extractor`` are prepared but
    never used, and there is no return statement, so callers receive ``None``.
    TODO: run ``feature_extractor`` over the gallery images and return a layer
    built from the resulting features.

    Args:
        configs: Parsed config dict; reads the ``Infer``, ``Arch``,
            ``IndexProcess`` and ``Global`` sections. It is not modified.
        feature_extractor: Model intended to embed the gallery images
            (currently unused).

    Returns:
        None (see the note above).

    Raises:
        ValueError: If a non-empty line of the data file splits into fewer
            than two fields.
    """
    transform_configs = configs["Infer"]["transforms"]
    # Not used yet; staged for the unwritten feature-extraction step.
    preprocess_ops = create_operators(transform_configs)

    embedding_size = configs["Arch"]["Head"]["embedding_size"]
    batch_size = configs["IndexProcess"]["batch_size"]
    # Build a fresh list rather than calling ``insert`` on the config's list:
    # that would mutate ``configs`` and prepend batch_size again on every call.
    image_shape = [batch_size] + list(configs["Global"]["image_shape"])
    # Zero-filled batch buffer of shape [batch_size, *image_shape]; not used yet.
    input_tensor = paddle.zeros(image_shape)

    index_cfg = configs["IndexProcess"]
    image_root = index_cfg["image_root"]
    data_file = index_cfg["data_file"]
    delimiter = index_cfg["delimiter"]
    gallery_images = []
    gallery_docs = []

    with open(data_file, 'r', encoding='utf-8') as f:
        for ori_line in f:
            stripped = ori_line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. extra newlines at end of file).
                continue
            line = stripped.split(delimiter)
            text_num = len(line)
            # Explicit check instead of ``assert`` so validation survives ``python -O``.
            if text_num < 2:
                raise ValueError(
                    f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
                )
            gallery_images.append(os.path.join(image_root, line[0]))
            gallery_docs.append(stripped)

W
weishengyu 已提交
39 40 41 42 43 44 45 46 47 48 49




def save_fuse_model(fuse_model):
    """Placeholder for persisting ``fuse_model``.

    TODO: not implemented. The argument is ignored, nothing is written,
    and ``None`` is returned.
    """
    return None


class FuseModel(paddle.nn.Layer):
    """Feature extractor fused with a gallery-matching layer.

    The extractor is built from ``configs`` and loaded from
    ``configs["Global"]["pretrained_model"]``. Its ``head`` is then replaced by
    an ``IdentityHead``. ``forward`` L2-normalises the extracted features and
    passes them to ``self.gallery_layer``.

    NOTE(review): ``gallery_layer`` is whatever ``build_gallery_feature``
    returns. While that helper has no return value it is ``None``, and
    ``forward`` raises NotImplementedError.
    """

    def __init__(self, configs):
        super().__init__()
        self.feature_extractor = build_model(configs)
        load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
        # Swap out the task head; presumably IdentityHead passes features
        # through unchanged — TODO confirm.
        self.feature_extractor.head = IdentityHead()
        self.gallery_layer = build_gallery_feature(configs, self.feature_extractor)

    def forward(self, x):
        """Embed ``x``, L2-normalise each embedding and apply the gallery layer.

        Raises:
            NotImplementedError: If no gallery layer was built.
        """
        if self.gallery_layer is None:
            raise NotImplementedError(
                "gallery_layer is not available: build_gallery_feature returned None"
            )
        # Use the attribute actually set in __init__ (``feature_extractor``).
        x = self.feature_extractor(x)["features"]
        # Normalise per embedding; ``paddle.norm`` would instead reduce the
        # whole tensor to one norm value. Assumes features are
        # (batch, embedding_size) — TODO confirm.
        x = paddle.nn.functional.normalize(x, p=2, axis=1)
        return self.gallery_layer(x)


def main():
    """Script entry point.

    Parse the command-line config, set up logging, build a ``FuseModel`` from
    the config and hand it to ``save_fuse_model``.
    """
    cli_args = parse_args()
    cfg = parse_config(cli_args.config)
    init_logger(name='gallery2fc')

    save_fuse_model(FuseModel(cfg))


# Run the conversion only when executed as a script, not on import.
if __name__ == '__main__':
    main()