From a19614d882a40a9eaa5389ce5bd148be99268629 Mon Sep 17 00:00:00 2001 From: Channingss Date: Thu, 14 May 2020 03:29:45 +0000 Subject: [PATCH] update --- paddlex/cv/models/base.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index a034eeb..6cfd64e 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -329,11 +329,29 @@ class BaseAPI: "Model for inference deploy saved in {}.".format(save_dir)) def export_onnx_model(self, save_dir, onnx_name=None): - from fluid.utils import op_io_info, init_name_prefix - from onnx import helper, checker - import fluid_onnx.ops as ops - from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight - from debug.model_check import debug_model, Tracker + support_list = ['ResNet18','ResNet34','ResNet50','ResNet101','ResNet50_vd', + 'ResNet101_vd','ResNet50_vd_ssld','ResNet101_vd_ssld','DarkNet53', + 'MobileNetV1','MobileNetV2','MobileNetV3_large','MobileNetV3_small', + 'MobileNetV3_large_ssld','MobileNetV3_small_ssld','Xception41', + 'Xception65','DenseNet121','DenseNet161','DenseNet201','ShuffleNetV2'] + unsupport_list = [] + if self.model_type in unsupport_list: + raise Exception("Model: {} unsupport export to ONNX" + .format(self.model_type) + try: + from fluid.utils import op_io_info, init_name_prefix + from onnx import helper, checker + import fluid_onnx.ops as ops + from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight + from debug.model_check import debug_model, Tracker + except Exception as e: + print(e) + print( + "Import Module Failed! Please install paddle2onnx. Related requirements + see https://github.com/PaddlePaddle/paddle2onnx" + ) + sys.exit(-1) + place = fluid.CPUPlace() exe = fluid.Executor(place) inference_scope = fluid.global_scope() @@ -392,6 +410,7 @@ class BaseAPI: op_check_list.append(op.type) print('The operator sets to run test case.') print(set(op_check_list)) + # Create outputs # Get the new names for outputs if they've been renamed in nodes' making renamed_outputs = op_io_info.get_all_renamed_outputs() -- GitLab