From 823074d8940ba7d28c455680fe097c685761166c Mon Sep 17 00:00:00 2001 From: weishengyu Date: Wed, 8 Dec 2021 19:56:30 +0800 Subject: [PATCH] dbg --- ppcls/utils/gallery2fc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ppcls/utils/gallery2fc.py b/ppcls/utils/gallery2fc.py index e02a7e4f..660297f6 100644 --- a/ppcls/utils/gallery2fc.py +++ b/ppcls/utils/gallery2fc.py @@ -1,6 +1,12 @@ import paddle from ppcls.arch import build_model from ppcls.utils.config import parse_config, parse_args +from ppcls.utils.save_load import load_dygraph_pretrain + + +def load_feature_extractor(configs): + arch = build_model(configs["Arch"]) + load_dygraph_pretrain(arch, configs["Global"]["pretrained_model"]) def build_gallery_feature(feature_extractor): @@ -14,7 +20,7 @@ def save_fuse_model(fuse_model): class FuseModel(paddle.nn.Layer): def __init__(self, configs): super().__init__() - self.feature_extractor = build_model(configs) + self.feature_extractor = load_feature_extractor(configs) self.gallery_layer = build_gallery_feature(self.feature_extractor) def forward(self, x): -- GitLab