提交 18d99b01 编写于 作者: W weishengyu

move gallery layer into extractor

上级 25edd1c0
...@@ -8,22 +8,22 @@ from ppcls.utils.config import parse_config, parse_args ...@@ -8,22 +8,22 @@ from ppcls.utils.config import parse_config, parse_args
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.logger import init_logger from ppcls.utils.logger import init_logger
from ppcls.data import transform, create_operators from ppcls.data import transform, create_operators
from ppcls.arch.slim import quantize_model
def build_gallery_layer(configs, feature_extractor): class GalleryLayer(paddle.nn.Layer):
transform_configs = configs["IndexProcess"]["transform_ops"] def __init__(self, configs):
preprocess_ops = create_operators(transform_configs) super().__init__()
self.configs = configs
embedding_size = configs["Arch"]["Head"]["embedding_size"] embedding_size = self.configs["Arch"]["Head"]["embedding_size"]
batch_size = configs["IndexProcess"]["batch_size"] self.batch_size = self.configs["IndexProcess"]["batch_size"]
image_shape = configs["Global"]["image_shape"].copy() self.image_shape = self.configs["Global"]["image_shape"].copy()
image_shape.insert(0, batch_size) self.image_shape.insert(0, self.batch_size)
input_tensor = paddle.zeros(image_shape)
image_root = self.configs["IndexProcess"]["image_root"]
image_root = configs["IndexProcess"]["image_root"] data_file = self.configs["IndexProcess"]["data_file"]
data_file = configs["IndexProcess"]["data_file"] delimiter = self.configs["IndexProcess"]["delimiter"]
delimiter = configs["IndexProcess"]["delimiter"] self.gallery_images = []
gallery_images = []
gallery_docs = [] gallery_docs = []
gallery_labels = [] gallery_labels = []
...@@ -35,44 +35,48 @@ def build_gallery_layer(configs, feature_extractor): ...@@ -35,44 +35,48 @@ def build_gallery_layer(configs, feature_extractor):
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0]) image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file) self.gallery_images.append(image_file)
gallery_docs.append(ori_line.strip()) gallery_docs.append(ori_line.strip())
gallery_labels.append(line[1].strip()) gallery_labels.append(line[1].strip())
self.gallery_layer = paddle.nn.Linear(embedding_size, len(self.gallery_images), bias_attr=False)
def forward(self, x):
    """Score input features against every gallery embedding.

    The input is L2-normalized first; because each gallery embedding is
    normalized before being stored as a weight column (see
    build_gallery_layer), the linear projection yields cosine-similarity
    scores against the gallery.
    """
    unit_features = paddle.nn.functional.normalize(x)
    scores = self.gallery_layer(unit_features)
    return scores
def build_gallery_layer(self, feature_extractor):
    """Fill the gallery linear layer with features of all gallery images.

    Reads each image in ``self.gallery_images``, applies the
    ``IndexProcess.transform_ops`` preprocessing pipeline, extracts
    features in mini-batches of ``self.batch_size``, L2-normalizes each
    feature vector, and installs the stacked features (transposed) as the
    weight of ``self.gallery_layer`` so that ``forward`` produces cosine
    similarities against the gallery.

    Args:
        feature_extractor: callable whose output is a dict containing a
            "features" entry of shape (batch_size, embedding_size) —
            assumed from the indexing below; confirm against the
            backbone's output contract.
    """
    transform_configs = self.configs["IndexProcess"]["transform_ops"]
    preprocess_ops = create_operators(transform_configs)
    embedding_size = self.configs["Arch"]["Head"]["embedding_size"]
    batch_index = 0
    input_tensor = paddle.zeros(self.image_shape)
    gallery_feature = paddle.zeros((len(self.gallery_images), embedding_size))
    for i, image_path in enumerate(self.gallery_images):
        # NOTE(review): cv2.imread returns BGR — presumably the transform
        # ops expect that layout; confirm against the IndexProcess config.
        image = cv2.imread(image_path)
        for op in preprocess_ops:
            image = op(image)
        input_tensor[batch_index] = image
        batch_index += 1
        # Flush a full batch, or the final (possibly partial) batch.
        if batch_index == self.batch_size or i == len(self.gallery_images) - 1:
            batch_feature = feature_extractor(input_tensor)["features"]
            for j in range(batch_index):
                feature = batch_feature[j]
                norm_feature = paddle.nn.functional.normalize(feature, axis=0)
                # This batch covers image indices (i - batch_index + 1) .. i.
                # Bug fix: the original wrote to "i - batch_index + j", which
                # is off by one (index -1 for the first item of the first
                # full batch).
                gallery_feature[i - batch_index + 1 + j] = norm_feature
            # Bug fix: reset so the next batch refills input_tensor from
            # slot 0. Without this, input_tensor[batch_index] overruns on
            # the very next image and full batches are never flushed again.
            batch_index = 0
    self.gallery_layer.set_state_dict({"weight": gallery_feature.T})
gallery_layer.set_state_dict({"weight": gallery_feature.T})
return gallery_layer
class GalleryLayer(paddle.nn.Layer):
def __init__(self, configs, feature_extractor):
super().__init__()
self.gallery_layer = build_gallery_layer(configs, feature_extractor)
def forward(self, x):
x = paddle.nn.functional.normalize(x)
x = self.gallery_layer(x)
return x
def export_fuse_model(configs): def export_fuse_model(configs):
slim_config = configs["Slim"].copy()
configs["Slim"] = None
fuse_model = build_model(configs) fuse_model = build_model(configs)
fuse_model.head = GalleryLayer(configs)
configs["slim"] = slim_config
quantize_model(configs, fuse_model)
load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"]) load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"])
fuse_model.eval() fuse_model.eval()
fuse_model.head = GalleryLayer(configs, fuse_model) fuse_model.head.build_gallery_layer(fuse_model)
save_path = configs["Global"]["save_inference_dir"] save_path = configs["Global"]["save_inference_dir"]
fuse_model.quanter.save_quantized_model( fuse_model.quanter.save_quantized_model(
fuse_model, fuse_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册