From f9b0aa8dc9f7d18d1a125476d2361d65963e0eb4 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 14 Sep 2022 10:26:32 +0800 Subject: [PATCH] fix (#7586) --- deploy/slim/quantization/export_model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index fd1c3e5e..bd132b62 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -151,17 +151,24 @@ def main(): arch_config = config["Architecture"] - arch_config = config["Architecture"] + if arch_config["algorithm"] == "SVTR" and arch_config["Head"][ + "name"] != 'MultiHead': + input_shape = config["Eval"]["dataset"]["transforms"][-2][ + 'SVTRRecResizeImg']['image_shape'] + else: + input_shape = None if arch_config["algorithm"] in ["Distillation", ]: # distillation model archs = list(arch_config["Models"].values()) for idx, name in enumerate(model.model_name_list): sub_model_save_path = os.path.join(save_path, name, "inference") export_single_model(model.model_list[idx], archs[idx], - sub_model_save_path, logger, quanter) + sub_model_save_path, logger, input_shape, + quanter) else: save_path = os.path.join(save_path, "inference") - export_single_model(model, arch_config, save_path, logger, quanter) + export_single_model(model, arch_config, save_path, logger, input_shape, + quanter) if __name__ == "__main__": -- GitLab