From d5637367a90961f2563b6eb1f0398a6786177405 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 23 Dec 2021 20:35:10 +0800 Subject: [PATCH] add export method --- ppcls/utils/gallery2fc.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/ppcls/utils/gallery2fc.py b/ppcls/utils/gallery2fc.py index 224a1078..3b8d2814 100644 --- a/ppcls/utils/gallery2fc.py +++ b/ppcls/utils/gallery2fc.py @@ -57,18 +57,6 @@ def build_gallery_layer(configs, feature_extractor): return gallery_layer -def export_fuse_model(model, config): - model.eval() - model.quanter.save_quantized_model( - model.base_model, - save_path, - input_spec=[ - paddle.static.InputSpec( - shape=[None] + config["Global"]["image_shape"], - dtype='float32') - ]) - - class FuseModel(paddle.nn.Layer): def __init__(self, configs): super().__init__() @@ -85,12 +73,25 @@ class FuseModel(paddle.nn.Layer): return x +def export_fuse_model(configs): + fuse_model = FuseModel(configs) + fuse_model.eval() + save_path = configs["Global"]["save_inference_dir"] + fuse_model.quanter.save_quantized_model( + fuse_model, + save_path, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + configs["Global"]["image_shape"], + dtype='float32') + ]) + + def main(): args = parse_args() configs = parse_config(args.config) init_logger(name='gallery2fc') - fuse_model = FuseModel(configs) - # save_fuse_model(fuse_model) + export_fuse_model(configs) if __name__ == '__main__': -- GitLab