未验证 提交 eeb538dc 编写于 作者: J Jason 提交者: GitHub

Merge pull request #430 from Channingss/fix_bug

fix bug of symbolic shape inference
...@@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference ...@@ -18,7 +18,7 @@ from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
from onnx.checker import ValidationError from onnx.checker import ValidationError
from onnx.checker import check_model from onnx.checker import check_model
from onnx.utils import polish_model from onnx.utils import polish_model
from onnx import helper from onnx import helper, shape_inference
from onnx.helper import get_attribute_value, make_attribute from onnx.helper import get_attribute_value, make_attribute
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
...@@ -141,6 +141,10 @@ class ONNXGraph(Graph): ...@@ -141,6 +141,10 @@ class ONNXGraph(Graph):
print("shape inferencing ...") print("shape inferencing ...")
self.graph = SymbolicShapeInference.infer_shapes( self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape) onnx_model, fixed_input_shape=self.fixed_input_shape)
if self.graph is None:
print('[WARNING] Shape inference by ONNX offical interface.')
onnx_model = shape_inference.infer_shapes(onnx_model)
self.graph = onnx_model.graph
print("shape inferenced.") print("shape inferenced.")
self.build() self.build()
self.collect_value_infos() self.collect_value_infos()
......
...@@ -1588,7 +1588,7 @@ class SymbolicShapeInference: ...@@ -1588,7 +1588,7 @@ class SymbolicShapeInference:
assert version.parse(onnx.__version__) >= version.parse("1.5.0") assert version.parse(onnx.__version__) >= version.parse("1.5.0")
onnx_opset = get_opset(in_mp) onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7: if not onnx_opset or onnx_opset < 7:
print('Only support models of onnx opset 7 and above.') print('[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.')
return return
symbolic_shape_inference = SymbolicShapeInference( symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose) int_max, auto_merge, guess_output_rank, verbose)
...@@ -1601,11 +1601,11 @@ class SymbolicShapeInference: ...@@ -1601,11 +1601,11 @@ class SymbolicShapeInference:
in_mp) in_mp)
symbolic_shape_inference._update_output_from_vi() symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred: if not all_shapes_inferred:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_) symbolic_shape_inference.out_mp_)
print('[INFO] Complete symbolic shape inference.')
except: except:
print('Stopping at incomplete symbolic shape inference') print('[WARNING] Incomplete symbolic shape inference.')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_) symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph return symbolic_shape_inference.out_mp_.graph
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册