diff --git a/ppcls/utils/gallery2fc.py b/ppcls/utils/gallery2fc.py index 224a10786b6fcd56cabb2d0d35b4b923b4d79083..3b8d281482a10ce092f7293f5ef75ab3f1d8c2d7 100644 --- a/ppcls/utils/gallery2fc.py +++ b/ppcls/utils/gallery2fc.py @@ -57,18 +57,6 @@ def build_gallery_layer(configs, feature_extractor): return gallery_layer -def export_fuse_model(model, config): - model.eval() - model.quanter.save_quantized_model( - model.base_model, - save_path, - input_spec=[ - paddle.static.InputSpec( - shape=[None] + config["Global"]["image_shape"], - dtype='float32') - ]) - - class FuseModel(paddle.nn.Layer): def __init__(self, configs): super().__init__() @@ -85,12 +73,25 @@ class FuseModel(paddle.nn.Layer): return x +def export_fuse_model(configs): + fuse_model = FuseModel(configs) + fuse_model.eval() + save_path = configs["Global"]["save_inference_dir"] + fuse_model.quanter.save_quantized_model( + fuse_model, + save_path, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + configs["Global"]["image_shape"], + dtype='float32') + ]) + + def main(): args = parse_args() configs = parse_config(args.config) init_logger(name='gallery2fc') - fuse_model = FuseModel(configs) - # save_fuse_model(fuse_model) + export_fuse_model(configs) if __name__ == '__main__':