提交 a19614d8 编写于 作者: C Channingss

update

上级 d83371a2
...@@ -329,11 +329,29 @@ class BaseAPI: ...@@ -329,11 +329,29 @@ class BaseAPI:
"Model for inference deploy saved in {}.".format(save_dir)) "Model for inference deploy saved in {}.".format(save_dir))
def export_onnx_model(self, save_dir, onnx_name=None): def export_onnx_model(self, save_dir, onnx_name=None):
from fluid.utils import op_io_info, init_name_prefix support_list = ['ResNet18','ResNet34','ResNet50','ResNet101','ResNet50_vd',
from onnx import helper, checker 'ResNet101_vd','ResNet50_vd_ssld','ResNet101_vd_ssld','DarkNet53',
import fluid_onnx.ops as ops 'MobileNetV1','MobileNetV2','MobileNetV3_large','MobileNetV3_small',
from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight 'MobileNetV3_large_ssld','MobileNetV3_small_ssld','Xception41',
from debug.model_check import debug_model, Tracker 'Xception65','DenseNet121','DenseNet161','DenseNet201','ShuffleNetV2']
unsupport_list = []
if self.model_type in unsupport_list:
raise Exception("Model: {} unsupport export to ONNX"
.format(self.model_type)
try:
from fluid.utils import op_io_info, init_name_prefix
from onnx import helper, checker
import fluid_onnx.ops as ops
from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
from debug.model_check import debug_model, Tracker
except Exception as e:
print(e)
print(
"Import Module Failed! Please install paddle2onnx. Related requirements
see https://github.com/PaddlePaddle/paddle2onnx"
)
sys.exit(-1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
inference_scope = fluid.global_scope() inference_scope = fluid.global_scope()
...@@ -392,6 +410,7 @@ class BaseAPI: ...@@ -392,6 +410,7 @@ class BaseAPI:
op_check_list.append(op.type) op_check_list.append(op.type)
print('The operator sets to run test case.') print('The operator sets to run test case.')
print(set(op_check_list)) print(set(op_check_list))
# Create outputs # Create outputs
# Get the new names for outputs if they've been renamed in nodes' making # Get the new names for outputs if they've been renamed in nodes' making
renamed_outputs = op_io_info.get_all_renamed_outputs() renamed_outputs = op_io_info.get_all_renamed_outputs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册