diff --git a/tools/validate.py b/tools/validate.py index 139a3ee9150f0cf31ea48514403befb1fad957d6..5a9a34972aa6891295948551fcf348f99d90d1fb 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -21,10 +21,6 @@ import re import common -import onnx -from onnx import helper -from onnx import TensorProto - # Validation Flow: # 1. Generate input data # 2. Use mace_run to run model on phone. @@ -198,6 +194,7 @@ def validate_onnx_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, output_names, output_shapes, validation_threshold, input_data_types, backend): + import onnx if backend == "tensorflow": from onnx_tf.backend import prepare print "valivate on onnx tensorflow backend." @@ -228,9 +225,9 @@ def validate_onnx_model(platform, device_type, model_file, input_file, out_shape[1], out_shape[2], out_shape[3] = \ out_shape[3], out_shape[1], out_shape[2] onnx_outputs.append( - helper.make_tensor_value_info(output_names[i], - TensorProto.FLOAT, - out_shape)) + onnx.helper.make_tensor_value_info(output_names[i], + onnx.TensorProto.FLOAT, + out_shape)) model.graph.output.extend(onnx_outputs) rep = prepare(model)