gallery2fc.py 4.7 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
weishengyu 已提交
15
import os
W
weishengyu 已提交
16
import paddle
W
weishengyu 已提交
17
import cv2
W
weishengyu 已提交
18

W
weishengyu 已提交
19
from ppcls.arch import build_model
W
weishengyu 已提交
20
from ppcls.utils.config import parse_config, parse_args
W
dbg  
weishengyu 已提交
21
from ppcls.utils.save_load import load_dygraph_pretrain
W
weishengyu 已提交
22
from ppcls.utils.logger import init_logger
W
weishengyu 已提交
23
from ppcls.data import create_operators
W
weishengyu 已提交
24
from ppcls.arch.slim import quantize_model
W
weishengyu 已提交
25 26


W
weishengyu 已提交
27
class GalleryLayer(paddle.nn.Layer):
    """Bias-free linear head with one output unit per gallery image.

    The constructor reads the gallery list named by
    ``configs["IndexProcess"]["data_file"]``, sizes a ``paddle.nn.Linear``
    layer as ``embedding_size x number_of_gallery_images`` and writes an
    ``"<index> <label>"`` mapping file next to the inference model.
    ``build_gallery_layer`` then fills the layer weights with the normalised
    feature of every gallery image.
    """

    def __init__(self, configs):
        """Parse the gallery list and create the (still untrained) layer.

        Args:
            configs: Parsed configuration. The keys read here are
                ``Arch.Head.embedding_size``, ``IndexProcess.batch_size``,
                ``IndexProcess.image_root``, ``IndexProcess.data_file``,
                ``IndexProcess.delimiter``, ``Global.image_shape`` and
                ``Global.save_inference_dir``.

        Raises:
            AssertionError: If a line of the data file does not split into at
                least two parts (image path and label).
        """
        super().__init__()
        self.configs = configs
        embedding_size = self.configs["Arch"]["Head"]["embedding_size"]
        self.batch_size = self.configs["IndexProcess"]["batch_size"]
        # Copy first so that prepending the batch dimension does not modify
        # the shape stored in the shared config.
        self.image_shape = self.configs["Global"]["image_shape"].copy()
        self.image_shape.insert(0, self.batch_size)

        image_root = self.configs["IndexProcess"]["image_root"]
        data_file = self.configs["IndexProcess"]["data_file"]
        delimiter = self.configs["IndexProcess"]["delimiter"]
        self.gallery_images = []
        gallery_labels = []

        with open(data_file, 'r', encoding='utf-8') as f:
            for ori_line in f:
                line = ori_line.strip().split(delimiter)
                text_num = len(line)
                assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
                self.gallery_images.append(os.path.join(image_root, line[0]))
                gallery_labels.append(line[1].strip())

        self.gallery_layer = paddle.nn.Linear(
            embedding_size, len(self.gallery_images), bias_attr=False)
        # Flag read by the quantisation pass.
        # NOTE(review): presumably keeps this layer in float - confirm.
        self.gallery_layer.skip_quant = True

        # One "<output index> <label>" line per gallery image.
        output_label_str = "".join(
            "{} {}\n".format(i, label_i)
            for i, label_i in enumerate(gallery_labels))
        output_path = configs["Global"]["save_inference_dir"] + "_label.txt"
        # The labels were read as UTF-8; write them back the same way instead
        # of relying on the platform default encoding.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(output_label_str)

    def forward(self, x, label=None):
        """Normalise ``x`` and project it through the gallery layer.

        Args:
            x: Input feature tensor.
            label: Unused; accepted for interface compatibility.

        Returns:
            The output of the gallery linear layer for the normalised input.
        """
        x = paddle.nn.functional.normalize(x)
        x = self.gallery_layer(x)
        return x

    def build_gallery_layer(self, feature_extractor):
        """Compute gallery features and load them as the layer weights.

        Every gallery image is read, preprocessed with
        ``IndexProcess.transform_ops`` and pushed through
        ``feature_extractor`` in batches of ``self.batch_size``. Each
        resulting feature is normalised and stored in its own row; the
        transposed matrix is handed to ``set_state_dict``.

        Args:
            feature_extractor: Callable that takes the batched image tensor
                and returns a mapping containing a ``"features"`` entry.

        Raises:
            OSError: If a gallery image cannot be read by ``cv2.imread``.
        """
        transform_configs = self.configs["IndexProcess"]["transform_ops"]
        preprocess_ops = create_operators(transform_configs)
        embedding_size = self.configs["Arch"]["Head"]["embedding_size"]
        num_images = len(self.gallery_images)
        batch_index = 0
        input_tensor = paddle.zeros(self.image_shape)
        gallery_feature = paddle.zeros((num_images, embedding_size))
        for i, image_path in enumerate(self.gallery_images):
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError(
                    "failed to read gallery image: {}".format(image_path))
            # Reverse the channel axis (OpenCV loads BGR).
            image = image[:, :, ::-1]
            for op in preprocess_ops:
                image = op(image)
            input_tensor[batch_index] = image
            batch_index += 1
            # Flush when the batch is full or the last image has been added.
            if batch_index == self.batch_size or i == num_images - 1:
                batch_feature = feature_extractor(input_tensor)["features"]
                # Only the first `batch_index` slots hold images of this
                # batch; for a final partial batch the rest is stale.
                for j in range(batch_index):
                    feature = batch_feature[j]
                    norm_feature = paddle.nn.functional.normalize(
                        feature, axis=0)
                    gallery_feature[i - batch_index + j + 1] = norm_feature
                # Start the next batch at slot 0; without this reset the next
                # write would land outside the batch dimension.
                batch_index = 0
        self.gallery_layer.set_state_dict({"_layer.weight": gallery_feature.T})
W
weishengyu 已提交
88

W
weishengyu 已提交
89

W
weishengyu 已提交
90
def export_fuse_model(configs):
    """Build the model with a gallery head and save it as a quantised model.

    The model is built while ``configs["Slim"]`` is temporarily set to
    ``None``; the saved slim settings are put back before
    ``quantize_model`` runs. After the pretrained weights are loaded the
    gallery head is filled from the model's own features, and the result is
    saved under ``configs["Global"]["save_inference_dir"]``.

    Args:
        configs: Parsed configuration; ``Slim``, ``Global.pretrained_model``,
            ``Global.save_inference_dir`` and ``Global.image_shape`` are
            read here.
    """
    # Keep a copy of the slim section so it can be restored after the build.
    saved_slim = configs["Slim"].copy()
    configs["Slim"] = None
    model = build_model(configs)
    model.head = GalleryLayer(configs)
    configs["Slim"] = saved_slim

    quantize_model(configs, model)
    load_dygraph_pretrain(model, configs["Global"]["pretrained_model"])
    model.eval()
    # The model itself serves as the feature extractor for its gallery head.
    model.head.build_gallery_layer(model)

    output_prefix = configs["Global"]["save_inference_dir"]
    # Leading None leaves the batch dimension unspecified in the saved model.
    image_spec = paddle.static.InputSpec(
        shape=[None] + configs["Global"]["image_shape"], dtype='float32')
    model.quanter.save_quantized_model(
        model, output_prefix, input_spec=[image_spec])


W
weishengyu 已提交
111 112 113
def main():
    """Parse the command line, load the config and run the export."""
    configs = parse_config(parse_args().config)
    init_logger(name='gallery2fc')
    export_fuse_model(configs)
W
weishengyu 已提交
116 117 118 119


# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    main()