gallery2fc.py 3.5 KB
Newer Older
W
weishengyu 已提交
1
import os
W
weishengyu 已提交
2
import paddle
W
weishengyu 已提交
3
import cv2
W
weishengyu 已提交
4

W
weishengyu 已提交
5
from ppcls.arch import build_model
W
weishengyu 已提交
6
from ppcls.arch.gears.identity_head import IdentityHead
W
weishengyu 已提交
7
from ppcls.utils.config import parse_config, parse_args
W
dbg  
weishengyu 已提交
8
from ppcls.utils.save_load import load_dygraph_pretrain
W
weishengyu 已提交
9
from ppcls.utils.logger import init_logger
W
weishengyu 已提交
10
from ppcls.data import transform, create_operators
W
dbg  
weishengyu 已提交
11 12


W
weishengyu 已提交
13 14
def build_gallery_layer(configs, feature_extractor):
    """Build a bias-free linear layer whose weight holds the gallery features.

    Every image listed in ``configs["IndexProcess"]["data_file"]`` is read
    with OpenCV, run through the configured ``transform_ops``, fed to
    ``feature_extractor`` in batches, and normalized along its feature axis.
    The normalized features become the weight of a ``paddle.nn.Linear`` layer
    with one output per gallery image.

    Args:
        configs: Parsed configuration. Reads ``IndexProcess`` (``transform_ops``,
            ``batch_size``, ``image_root``, ``data_file``, ``delimiter``),
            ``Arch.Head.embedding_size`` and ``Global.image_shape``.
        feature_extractor: Callable taking a batched image tensor and
            returning a mapping with a ``"features"`` entry.

    Returns:
        A ``paddle.nn.Linear(embedding_size, num_gallery_images)`` layer
        without bias whose weight is the transposed gallery feature matrix.

    Raises:
        AssertionError: If a line of the data file splits into fewer than
            two fields.
        ValueError: If a gallery image cannot be read.
    """
    index_cfg = configs["IndexProcess"]
    preprocess_ops = create_operators(index_cfg["transform_ops"])

    embedding_size = configs["Arch"]["Head"]["embedding_size"]
    batch_size = index_cfg["batch_size"]
    # Build a new list instead of inserting in place: the config list is
    # shared with other readers and must keep its original (unbatched) shape.
    image_shape = [batch_size] + list(configs["Global"]["image_shape"])
    input_tensor = paddle.zeros(image_shape)

    image_root = index_cfg["image_root"]
    data_file = index_cfg["data_file"]
    delimiter = index_cfg["delimiter"]

    gallery_images = []
    with open(data_file, 'r', encoding='utf-8') as f:
        for ori_line in f:
            line = ori_line.strip().split(delimiter)
            text_num = len(line)
            assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
            gallery_images.append(os.path.join(image_root, line[0]))

    num_images = len(gallery_images)
    gallery_feature = paddle.zeros((num_images, embedding_size))
    batch_index = 0
    # Only forward passes are needed here; skip autograd bookkeeping.
    with paddle.no_grad():
        for i, image_path in enumerate(gallery_images):
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise ValueError(f"failed to read gallery image: {image_path}")
            for op in preprocess_ops:
                image = op(image)
            input_tensor[batch_index] = image
            batch_index += 1

            # Flush when the batch is full or the gallery is exhausted.
            if batch_index == batch_size or i == num_images - 1:
                batch_feature = feature_extractor(input_tensor)["features"]
                # The batch covers gallery rows [i + 1 - batch_index, i].
                batch_start = i + 1 - batch_index
                for j in range(batch_index):
                    norm_feature = paddle.nn.functional.normalize(
                        batch_feature[j], axis=0)
                    gallery_feature[batch_start + j] = norm_feature
                # Start filling the next batch from the first slot again.
                batch_index = 0

    gallery_layer = paddle.nn.Linear(
        embedding_size, num_images, bias_attr=False)
    gallery_layer.set_state_dict({"weight": gallery_feature.T})
    return gallery_layer


W
weishengyu 已提交
60 61 62
class FuseModel(paddle.nn.Layer):
    """Feature extractor fused with a gallery lookup layer.

    The extractor is built from ``configs``, loaded with the weights at
    ``configs["Global"]["pretrained_model"]``, switched to eval mode and given
    an ``IdentityHead``. The gallery layer is produced by
    ``build_gallery_layer`` using that extractor.
    """

    def __init__(self, configs):
        super().__init__()
        backbone = build_model(configs)
        pretrained_path = configs["Global"]["pretrained_model"]
        # Weights are loaded before the head is swapped out.
        load_dygraph_pretrain(backbone, pretrained_path)
        backbone.eval()
        backbone.head = IdentityHead()

        self.feature_extractor = backbone
        self.gallery_layer = build_gallery_layer(configs, backbone)

    def forward(self, x):
        """Return the gallery layer's output for the normalized features of ``x``."""
        features = self.feature_extractor(x)["features"]
        return self.gallery_layer(paddle.nn.functional.normalize(features))


W
weishengyu 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89
def export_fuse_model(configs):
    """Build the fused model and export it for inference.

    The model is saved to ``configs["Global"]["save_inference_dir"]`` with a
    float32 input spec of shape ``[None] + configs["Global"]["image_shape"]``.
    When the model carries a ``quanter`` attribute, its
    ``save_quantized_model`` is used; otherwise the model is exported with
    ``paddle.jit.save``.

    Args:
        configs: Parsed configuration; reads ``Global.image_shape`` and
            ``Global.save_inference_dir`` and is passed on to ``FuseModel``.
    """
    # Snapshot the shape before the model is built so that any in-place change
    # to the config list made during construction cannot leak into the
    # exported input spec.
    image_shape = list(configs["Global"]["image_shape"])

    fuse_model = FuseModel(configs)
    fuse_model.eval()
    save_path = configs["Global"]["save_inference_dir"]
    input_spec = [
        paddle.static.InputSpec(
            shape=[None] + image_shape,
            dtype='float32')
    ]

    # FuseModel does not define a quanter itself, so an unconditional
    # ``fuse_model.quanter`` lookup raises AttributeError. Use the quanter
    # only when one has been attached.
    quanter = getattr(fuse_model, "quanter", None)
    if quanter is not None:
        quanter.save_quantized_model(
            fuse_model,
            save_path,
            input_spec=input_spec)
    else:
        paddle.jit.save(fuse_model, save_path, input_spec=input_spec)


W
weishengyu 已提交
90 91 92
def main():
    """Parse the command line, load the config, and export the fused model."""
    configs = parse_config(parse_args().config)
    init_logger(name='gallery2fc')
    export_fuse_model(configs)
W
weishengyu 已提交
95 96 97 98


# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    main()