未验证 提交 f9b0aa8d 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix (#7586)

上级 1c0e4965
...@@ -151,17 +151,24 @@ def main(): ...@@ -151,17 +151,24 @@ def main():
arch_config = config["Architecture"] arch_config = config["Architecture"]
arch_config = config["Architecture"] if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
"name"] != 'MultiHead':
input_shape = config["Eval"]["dataset"]["transforms"][-2][
'SVTRRecResizeImg']['image_shape']
else:
input_shape = None
if arch_config["algorithm"] in ["Distillation", ]: # distillation model if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values()) archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list): for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference") sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx], export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger, quanter) sub_model_save_path, logger, input_shape,
quanter)
else: else:
save_path = os.path.join(save_path, "inference") save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger, quanter) export_single_model(model, arch_config, save_path, logger, input_shape,
quanter)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册