提交 c9411445 编写于 作者: W weishengyu

update gallery2fc

上级 8d22f025
import os
import paddle
from ppcls.arch import build_model
from ppcls.arch.gears.identity_head import IdentityHead
from ppcls.utils.config import parse_config, parse_args
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.logger import init_logger
from ppcls.data import transform, create_operators
def load_feature_extractor(configs):
arch = build_model(configs["Arch"])
load_dygraph_pretrain(arch, configs["Global"]["pretrained_model"])
def build_gallery_feature(configs, feature_extractor):
transform_configs = configs["Infer"]["transforms"]
preprocess_ops = create_operators(transform_configs)
embedding_size = configs["Arch"]["Head"]["embedding_size"]
batch_size = configs["IndexProcess"]["batch_size"]
image_shape = configs["Global"]["image_shape"]
image_shape.insert(0, batch_size)
input_tensor = paddle.zeros(image_shape)
image_root = configs["IndexProcess"]["image_root"]
data_file = configs["IndexProcess"]["data_file"]
delimiter = configs["IndexProcess"]["delimiter"]
gallery_images = []
gallery_docs = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
line = ori_line.strip().split(delimiter)
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
def build_gallery_feature(feature_extractor):
pass
def save_fuse_model(fuse_model):
......@@ -21,11 +47,14 @@ def save_fuse_model(fuse_model):
class FuseModel(paddle.nn.Layer):
def __init__(self, configs):
super().__init__()
self.feature_extractor = load_feature_extractor(configs)
self.gallery_layer = build_gallery_feature(self.feature_extractor)
self.feature_extractor = build_model(configs)
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
self.feature_extractor.head = IdentityHead()
self.gallery_layer = build_gallery_feature(configs, self.feature_extractor)
def forward(self, x):
x = self.feature_model(x)
x = self.feature_model(x)["features"]
x = paddle.norm(x)
x = self.gallery_layer(x)
return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册