提交 25edd1c0 编写于 作者: W weishengyu

refactor code

上级 d5637367
...@@ -16,7 +16,7 @@ def build_gallery_layer(configs, feature_extractor): ...@@ -16,7 +16,7 @@ def build_gallery_layer(configs, feature_extractor):
embedding_size = configs["Arch"]["Head"]["embedding_size"] embedding_size = configs["Arch"]["Head"]["embedding_size"]
batch_size = configs["IndexProcess"]["batch_size"] batch_size = configs["IndexProcess"]["batch_size"]
image_shape = configs["Global"]["image_shape"] image_shape = configs["Global"]["image_shape"].copy()
image_shape.insert(0, batch_size) image_shape.insert(0, batch_size)
input_tensor = paddle.zeros(image_shape) input_tensor = paddle.zeros(image_shape)
...@@ -57,25 +57,22 @@ def build_gallery_layer(configs, feature_extractor): ...@@ -57,25 +57,22 @@ def build_gallery_layer(configs, feature_extractor):
return gallery_layer return gallery_layer
class FuseModel(paddle.nn.Layer): class GalleryLayer(paddle.nn.Layer):
def __init__(self, configs): def __init__(self, configs, feature_extractor):
super().__init__() super().__init__()
self.feature_extractor = build_model(configs) self.gallery_layer = build_gallery_layer(configs, feature_extractor)
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
self.feature_extractor.eval()
self.feature_extractor.head = IdentityHead()
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
def forward(self, x): def forward(self, x):
x = self.feature_extractor(x)["features"]
x = paddle.nn.functional.normalize(x) x = paddle.nn.functional.normalize(x)
x = self.gallery_layer(x) x = self.gallery_layer(x)
return x return x
def export_fuse_model(configs): def export_fuse_model(configs):
fuse_model = FuseModel(configs) fuse_model = build_model(configs)
load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"])
fuse_model.eval() fuse_model.eval()
fuse_model.head = GalleryLayer(configs, fuse_model)
save_path = configs["Global"]["save_inference_dir"] save_path = configs["Global"]["save_inference_dir"]
fuse_model.quanter.save_quantized_model( fuse_model.quanter.save_quantized_model(
fuse_model, fuse_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册