diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index 98bb4214c1fd3b6de4d7041449dc0d770cbe77fc..0fbcc51034738c0fab1304fe05d51a61c8039eb6 100644 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference from onnx.checker import ValidationError from onnx.checker import check_model from onnx.utils import polish_model -from onnx import helper +from onnx import helper, shape_inference from onnx.helper import get_attribute_value, make_attribute from onnx.shape_inference import infer_shapes from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE @@ -141,6 +141,10 @@ class ONNXGraph(Graph): print("shape inferencing ...") self.graph = SymbolicShapeInference.infer_shapes( onnx_model, fixed_input_shape=self.fixed_input_shape) + if self.graph is None: + print('[WARNING] Shape inference by ONNX offical interface.') + onnx_model = shape_inference.infer_shapes(onnx_model) + self.graph = onnx_model.graph print("shape inferenced.") self.build() self.collect_value_infos() diff --git a/x2paddle/decoder/onnx_shape_inference.py b/x2paddle/decoder/onnx_shape_inference.py index 7b08be9db4729a047123dc59b471e74288d94a8a..987fab290848efaafa2dc42e1b389f37ec2b978e 100644 --- a/x2paddle/decoder/onnx_shape_inference.py +++ b/x2paddle/decoder/onnx_shape_inference.py @@ -1588,7 +1588,7 @@ class SymbolicShapeInference: assert version.parse(onnx.__version__) >= version.parse("1.5.0") onnx_opset = get_opset(in_mp) if not onnx_opset or onnx_opset < 7: - print('Only support models of onnx opset 7 and above.') + print('[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.') return symbolic_shape_inference = SymbolicShapeInference( int_max, auto_merge, guess_output_rank, verbose) @@ -1601,11 +1601,11 @@ class SymbolicShapeInference: in_mp) symbolic_shape_inference._update_output_from_vi() if not all_shapes_inferred: - print('!' * 10) symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_) + print('[INFO] Complete symbolic shape inference.') except: - print('Stopping at incomplete symbolic shape inference') + print('[WARNING] Incomplete symbolic shape inference.') symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_) return symbolic_shape_inference.out_mp_.graph