提交 2ca2e71b 编写于 作者: C channingss

update onnx symbolic shape inference

上级 3aa8e577
...@@ -346,8 +346,12 @@ class ONNXGraph(Graph): ...@@ -346,8 +346,12 @@ class ONNXGraph(Graph):
#if len(value_info['shape']) == 0 or value_info[ #if len(value_info['shape']) == 0 or value_info[
# 'dtype'] is None or 0 in value_info['shape']: # 'dtype'] is None or 0 in value_info['shape']:
# #TODO add node shape inference # #TODO add node shape inference
shape = value_info['shape']
for idx in range(len(shape)):
if shape[idx] == 0:
shape[idx] = -1
node.out_shapes.append(shape)
node.dtype = value_info['dtype'] node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else: else:
node.out_shapes.append([]) node.out_shapes.append([])
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# Reference Code from https://github.com/microsoft/onnxruntime, Licensed under the MIT License. # Reference Code from https://github.com/microsoft/onnxruntime, Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import argparse import argparse
import numpy as np import numpy as np
import onnx import onnx
...@@ -23,7 +22,6 @@ from onnx import helper, numpy_helper, shape_inference ...@@ -23,7 +22,6 @@ from onnx import helper, numpy_helper, shape_inference
import sympy import sympy
from packaging import version from packaging import version
assert version.parse(onnx.__version__) >= version.parse("1.5.0")
def get_attribute(node, attr_name, default_value=None): def get_attribute(node, attr_name, default_value=None):
...@@ -45,17 +43,15 @@ def get_shape_from_type_proto(type_proto): ...@@ -45,17 +43,15 @@ def get_shape_from_type_proto(type_proto):
def get_shape_from_sympy_shape(sympy_shape): def get_shape_from_sympy_shape(sympy_shape):
sympy_shape = [ return [
None if i is None else (int(i) if is_literal(i) else str(i)) None if i is None else (int(i) if is_literal(i) else str(i))
for i in sympy_shape for i in sympy_shape
] ]
return sympy_shape
def is_literal(dim): def is_literal(dim):
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or ( return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(
hasattr(dim, 'is_number') and dim, 'is_number') and dim.is_number)
dim.is_number) # or (hasattr(dim, 'is_integer') and dim.is_integer)
def handle_negative_axis(axis, rank): def handle_negative_axis(axis, rank):
...@@ -119,6 +115,7 @@ class SymbolicShapeInference: ...@@ -119,6 +115,7 @@ class SymbolicShapeInference:
'Div': self._infer_symbolic_compute_ops, 'Div': self._infer_symbolic_compute_ops,
'Expand': self._infer_Expand, 'Expand': self._infer_Expand,
'Equal': self._infer_symbolic_compute_ops, 'Equal': self._infer_symbolic_compute_ops,
'Floor': self._infer_symbolic_compute_ops,
'Gather': self._infer_Gather, 'Gather': self._infer_Gather,
'GatherElements': self._infer_GatherElements, 'GatherElements': self._infer_GatherElements,
'GatherND': self._infer_GatherND, 'GatherND': self._infer_GatherND,
...@@ -145,13 +142,13 @@ class SymbolicShapeInference: ...@@ -145,13 +142,13 @@ class SymbolicShapeInference:
'Size': self._infer_Size, 'Size': self._infer_Size,
'Slice': self._infer_Slice, 'Slice': self._infer_Slice,
'Split': self._infer_Split, 'Split': self._infer_Split,
'SplitToSequence': self._infer_SplitToSequence,
'Squeeze': self._infer_Squeeze, 'Squeeze': self._infer_Squeeze,
'Sub': self._infer_symbolic_compute_ops, 'Sub': self._infer_symbolic_compute_ops,
'Tile': self._infer_Tile, 'Tile': self._infer_Tile,
'TopK': self._infer_TopK, 'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze, 'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops, 'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap 'ZipMap': self._infer_ZipMap
} }
self.run_ = True self.run_ = True
...@@ -267,8 +264,9 @@ class SymbolicShapeInference: ...@@ -267,8 +264,9 @@ class SymbolicShapeInference:
if pending_nodes and self.verbose_ > 0: if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ') print('SymbolicShapeInference: orphaned nodes discarded: ')
for n in pending_nodes: print(
print(n.op_type + ': ' + n.output[0]) *[n.op_type + ': ' + n.output[0] for n in pending_nodes],
sep='\n')
if input_shapes is not None: if input_shapes is not None:
for input_name, shape in input_shapes.items(): for input_name, shape in input_shapes.items():
...@@ -280,6 +278,7 @@ class SymbolicShapeInference: ...@@ -280,6 +278,7 @@ class SymbolicShapeInference:
helper.make_tensor_value_info( helper.make_tensor_value_info(
value_info.name, value_info.name,
value_info.type.tensor_type.elem_type, shape)) value_info.type.tensor_type.elem_type, shape))
self.initializers_ = dict( self.initializers_ = dict(
[(i.name, i) for i in self.out_mp_.graph.initializer]) [(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict( self.known_vi_ = dict(
...@@ -351,21 +350,11 @@ class SymbolicShapeInference: ...@@ -351,21 +350,11 @@ class SymbolicShapeInference:
def _get_shape(self, node, idx): def _get_shape(self, node, idx):
name = node.input[idx] name = node.input[idx]
shape = []
if name in self.known_vi_: if name in self.known_vi_:
shape = get_shape_from_type_proto(self.known_vi_[name].type) return get_shape_from_type_proto(self.known_vi_[name].type)
elif name in self.initializers_:
assert name in self.initializers_
shape = list(self.initializers_[name].dims)
return shape
def _get_initializer_value(self, node, idx):
name = node.input[idx]
if name in self.initializers_:
value = numpy_helper.to_array(self.initializers_[name])
return value
else: else:
return False assert name in self.initializers_
return list(self.initializers_[name].dims)
def _get_shape_rank(self, node, idx): def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx)) return len(self._get_shape(node, idx))
...@@ -373,7 +362,7 @@ class SymbolicShapeInference: ...@@ -373,7 +362,7 @@ class SymbolicShapeInference:
def _get_sympy_shape(self, node, idx): def _get_sympy_shape(self, node, idx):
sympy_shape = [] sympy_shape = []
for d in self._get_shape(node, idx): for d in self._get_shape(node, idx):
if type(d) is str: if type(d) == str:
sympy_shape.append(self.symbolic_dims_[d] if d in sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol( self.symbolic_dims_ else sympy.Symbol(
d, integer=True)) d, integer=True))
...@@ -416,10 +405,14 @@ class SymbolicShapeInference: ...@@ -416,10 +405,14 @@ class SymbolicShapeInference:
# run single node inference with self.known_vi_ shapes # run single node inference with self.known_vi_ shapes
# note that inference rely on initializer values is not handled # note that inference rely on initializer values is not handled
# as we don't copy initializer weights to tmp_graph for inference speed purpose # as we don't copy initializer weights to tmp_graph for inference speed purpose
if node.op_type == 'SplitToSequence':
make_value_info_func = helper.make_sequence_value_info
else:
make_value_info_func = helper.make_tensor_value_info
tmp_graph = helper.make_graph( tmp_graph = helper.make_graph(
[node], 'tmp', [self.known_vi_[i] for i in node.input if i], [ [node], 'tmp', [self.known_vi_[i] for i in node.input if i], [
helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED, make_value_info_func(i, onnx.TensorProto.UNDEFINED, None)
None) for i in node.output for i in node.output
]) ])
self.tmp_mp_.graph.CopyFrom(tmp_graph) self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
...@@ -465,6 +458,12 @@ class SymbolicShapeInference: ...@@ -465,6 +458,12 @@ class SymbolicShapeInference:
self.verbose_) self.verbose_)
all_shapes_inferred = False all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_) symbolic_shape_inference._preprocess(self.tmp_mp_)
# note that after _preprocess, Constant node will be converted to initializer and should be appended to subgraph.initializer
subgraph.initializer.extend([
i for i in symbolic_shape_inference.out_mp_.graph.initializer
if i.name not in subgraph_implicit_input and i.name not in
subgraph_inputs
])
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_: while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl( all_shapes_inferred = symbolic_shape_inference._infer_impl(
...@@ -533,18 +532,13 @@ class SymbolicShapeInference: ...@@ -533,18 +532,13 @@ class SymbolicShapeInference:
assert len(node.output) == 1 assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True) values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]): if all([v is not None for v in values]):
new_shape = []
is_list = [type(v) == list for v in values] is_list = [type(v) == list for v in values]
as_list = any(is_list) as_list = any(is_list)
if as_list: if as_list:
data = [op_func(vs) for vs in zip(*values)] self.sympy_data_[node.output[
self.sympy_data_[node.output[0]] = data 0]] = [op_func(vs) for vs in zip(*values)]
new_shape = np.array(data).shape
else: else:
data = op_func(values) self.sympy_data_[node.output[0]] = op_func(values)
self.sympy_data_[node.output[0]] = data
new_shape = np.array(data).shape
vi = self.known_vi_[node.output[0]]
def _pass_on_sympy_data(self, node): def _pass_on_sympy_data(self, node):
assert len(node.input) == 1 or node.op_type == 'Reshape' assert len(node.input) == 1 or node.op_type == 'Reshape'
...@@ -677,8 +671,8 @@ class SymbolicShapeInference: ...@@ -677,8 +671,8 @@ class SymbolicShapeInference:
lhs_reduce_dim = -1 lhs_reduce_dim = -1
rhs_reduce_dim = -2 rhs_reduce_dim = -2
new_shape = self._broadcast_shapes( new_shape = self._broadcast_shapes(
lhs_shape[:-2], rhs_shape[:-2]) + [lhs_shape[-2] lhs_shape[:-2],
] + [rhs_shape[-1]] rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]]
# merge reduce dim # merge reduce dim
self._check_merged_dims( self._check_merged_dims(
[lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
...@@ -706,6 +700,7 @@ class SymbolicShapeInference: ...@@ -706,6 +700,7 @@ class SymbolicShapeInference:
'Add': lambda l: l[0] + l[1], 'Add': lambda l: l[0] + l[1],
'Div': lambda l: l[0] // l[1], # integer div in sympy 'Div': lambda l: l[0] // l[1], # integer div in sympy
'Equal': lambda l: l[0] == l[1], 'Equal': lambda l: l[0] == l[1],
'Floor': lambda l: sympy.floor(l[0]),
'Max': 'Max':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min': 'Min':
...@@ -731,15 +726,6 @@ class SymbolicShapeInference: ...@@ -731,15 +726,6 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], output_type, helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0))) self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node): def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0) input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output # create a new symbolic dimension for Compress output
...@@ -851,11 +837,11 @@ class SymbolicShapeInference: ...@@ -851,11 +837,11 @@ class SymbolicShapeInference:
axis = handle_negative_axis( axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(data_shape)) get_attribute(node, 'axis', 0), len(data_shape))
indices_shape = self._get_shape(node, 1) indices_shape = self._get_shape(node, 1)
new_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
vi = self.known_vi_[node.output[0]] vi = self.known_vi_[node.output[0]]
vi.CopyFrom( vi.CopyFrom(
helper.make_tensor_value_info(node.output[ helper.make_tensor_value_info(
0], vi.type.tensor_type.elem_type, new_shape)) node.output[0], vi.type.tensor_type.elem_type, data_shape[:axis]
+ indices_shape + data_shape[axis + 1:]))
if node.input[0] in self.sympy_data_: if node.input[0] in self.sympy_data_:
assert 0 == get_attribute(node, 'axis', assert 0 == get_attribute(node, 'axis',
0) # only handle 1D sympy compute 0) # only handle 1D sympy compute
...@@ -863,9 +849,8 @@ class SymbolicShapeInference: ...@@ -863,9 +849,8 @@ class SymbolicShapeInference:
data = self.sympy_data_[node.input[0]] data = self.sympy_data_[node.input[0]]
if type(data) == list: if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1: if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[0]] = [ self.sympy_data_[node.output[
data[int(i)] for i in idx 0]] = [data[int(i)] for i in idx]
]
else: else:
self.sympy_data_[node.output[0]] = data[int(idx)] self.sympy_data_[node.output[0]] = data[int(idx)]
else: else:
...@@ -896,8 +881,8 @@ class SymbolicShapeInference: ...@@ -896,8 +881,8 @@ class SymbolicShapeInference:
def _infer_If(self, node): def _infer_If(self, node):
# special case for constant condition, in case there are mismatching shape from the non-executed branch # special case for constant condition, in case there are mismatching shape from the non-executed branch
subgraphs = [ subgraphs = [
get_attribute(node, 'then_branch'), get_attribute(node, 'then_branch'), get_attribute(node,
get_attribute(node, 'else_branch') 'else_branch')
] ]
cond = self._try_get_value(node, 0) cond = self._try_get_value(node, 0)
if cond is not None: if cond is not None:
...@@ -976,11 +961,14 @@ class SymbolicShapeInference: ...@@ -976,11 +961,14 @@ class SymbolicShapeInference:
0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) 0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
def _infer_OneHot(self, node): def _infer_OneHot(self, node):
shape = self._get_shape(node, 0) sympy_shape = self._get_sympy_shape(node, 0)
depth = self._try_get_value(node, 1)
axis = get_attribute(node, 'axis', -1) axis = get_attribute(node, 'axis', -1)
axis = handle_negative_axis(axis, len(shape) + 1) axis = handle_negative_axis(axis, len(sympy_shape) + 1)
new_shape = shape[:axis] + [self._new_symbolic_dim_from_output(node) new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [
] + shape[axis:] self._new_symbolic_dim_from_output(node)
if not is_literal(depth) else depth
] + sympy_shape[axis:])
vi = self.known_vi_[node.output[0]] vi = self.known_vi_[node.output[0]]
vi.CopyFrom( vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[ helper.make_tensor_value_info(node.output[0], self.known_vi_[
...@@ -1125,11 +1113,16 @@ class SymbolicShapeInference: ...@@ -1125,11 +1113,16 @@ class SymbolicShapeInference:
sympy.simplify(sympy.floor(s)) for s in sizes sympy.simplify(sympy.floor(s)) for s in sizes
] ]
self._update_computed_dims(new_sympy_shape) self._update_computed_dims(new_sympy_shape)
elif roi is not None and scales is not None: elif scales is not None:
rank = len(scales) rank = len(scales)
assert len(roi) == 2 * rank if get_attribute(node, 'coordinate_transformation_mode'
roi_start = list(roi)[:rank] ) == 'tf_crop_and_resize':
roi_end = list(roi)[rank:] assert len(roi) == 2 * rank
roi_start = list(roi)[:rank]
roi_end = list(roi)[rank:]
else:
roi_start = [0] * rank
roi_end = [1] * rank
scales = list(scales) scales = list(scales)
new_sympy_shape = [ new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale)) sympy.simplify(sympy.floor(d * (end - start) * scale))
...@@ -1265,8 +1258,8 @@ class SymbolicShapeInference: ...@@ -1265,8 +1258,8 @@ class SymbolicShapeInference:
e = new_sympy_shape[i] e = new_sympy_shape[i]
except Exception: except Exception:
print( print(
'Unable to determine if {} <= {}, treat as equal' 'Unable to determine if {} <= {}, treat as equal'.
.format(e, new_sympy_shape[i])) format(e, new_sympy_shape[i]))
e = new_sympy_shape[i] e = new_sympy_shape[i]
if is_literal(s) and int(s) < 0: if is_literal(s) and int(s) < 0:
...@@ -1290,7 +1283,7 @@ class SymbolicShapeInference: ...@@ -1290,7 +1283,7 @@ class SymbolicShapeInference:
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][ self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][
starts[0]:ends[0]] starts[0]:ends[0]]
def _infer_Split(self, node): def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0) input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis( axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(input_sympy_shape)) get_attribute(node, 'axis', 0), len(input_sympy_shape))
...@@ -1306,14 +1299,20 @@ class SymbolicShapeInference: ...@@ -1306,14 +1299,20 @@ class SymbolicShapeInference:
for i_o in range(len(split)): for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]] vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom( vi.CopyFrom(
helper.make_tensor_value_info( make_value_info_func(node.output[i_o], self.known_vi_[
node.output[i_o], self.known_vi_[node.input[ node.input[0]].type.tensor_type.elem_type,
0]].type.tensor_type.elem_type, get_shape_from_sympy_shape(
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [ input_sympy_shape[:axis] + [
split[i_o] split[i_o]
] + input_sympy_shape[axis + 1:]))) ] + input_sympy_shape[axis + 1:])))
self.known_vi_[vi.name] = vi self.known_vi_[vi.name] = vi
def _infer_Split(self, node):
self._infer_Split_Common(node, helper.make_tensor_value_info)
def _infer_SplitToSequence(self, node):
self._infer_Split_Common(node, helper.make_sequence_value_info)
def _infer_Squeeze(self, node): def _infer_Squeeze(self, node):
self._pass_on_sympy_data(node) self._pass_on_sympy_data(node)
...@@ -1416,6 +1415,14 @@ class SymbolicShapeInference: ...@@ -1416,6 +1415,14 @@ class SymbolicShapeInference:
self._onnx_infer_single_node(node) self._onnx_infer_single_node(node)
if node.op_type in self.dispatcher_: if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node) self.dispatcher_[node.op_type](node)
elif node.op_type in ['ConvTranspose']:
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
vi = self.known_vi_[node.output[0]]
if len(vi.type.tensor_type.shape.dim) == 0:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if self.verbose_ > 2: if self.verbose_ > 2:
print(node.op_type + ': ' + node.name) print(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input): for i, name in enumerate(node.input):
...@@ -1443,6 +1450,7 @@ class SymbolicShapeInference: ...@@ -1443,6 +1450,7 @@ class SymbolicShapeInference:
] ]
if len(in_dims) > 1: if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True) self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)): for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]] vi = self.known_vi_[node.output[i_o]]
out_type = vi.type out_type = vi.type
...@@ -1473,16 +1481,22 @@ class SymbolicShapeInference: ...@@ -1473,16 +1481,22 @@ class SymbolicShapeInference:
if node.op_type in [ if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16' 'MatMul', 'MatMulInteger', 'MatMulInteger16'
]: ]:
# only support auto merge for MatMul for dim < rank-2 when rank > 2 if None in out_shape:
assert len(shapes[0]) > 2 and dim_idx[0] < len( idx = out_shape.index(None)
shapes[0]) - 2 dim_idx = [
assert len(shapes[1]) > 2 and dim_idx[1] < len( len(s) - len(out_shape) + idx
shapes[1]) - 2 for s in shapes
]
# only support auto merge for MatMul for dim < rank-2 when rank > 2
assert len(shapes[0]) > 2 and dim_idx[
0] < len(shapes[0]) - 2
assert len(shapes[1]) > 2 and dim_idx[
1] < len(shapes[1]) - 2
elif node.op_type == 'Expand': elif node.op_type == 'Expand':
# auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
shapes = [ shapes = [
self._get_shape(node, 0), self._get_shape(node, 0), self._get_value(node,
self._get_value(node, 1) 1)
] ]
else: else:
shapes = [] shapes = []
...@@ -1531,9 +1545,9 @@ class SymbolicShapeInference: ...@@ -1531,9 +1545,9 @@ class SymbolicShapeInference:
if self.verbose_ > 0: if self.verbose_ > 0:
if is_unknown_op: if is_unknown_op:
print( print(
"Possible unknown op: {} node: {}, guessing {} shape" "Possible unknown op: {} node: {}, guessing {} shape".
.format(node.op_type, node.name, format(node.op_type, node.name,
vi.name)) vi.name))
if self.verbose_ > 2: if self.verbose_ > 2:
print(' {}: {} {}'.format( print(' {}: {} {}'.format(
node.output[i_o], node.output[i_o],
...@@ -1544,7 +1558,7 @@ class SymbolicShapeInference: ...@@ -1544,7 +1558,7 @@ class SymbolicShapeInference:
continue # continue the inference after guess, no need to stop as no merge is needed continue # continue the inference after guess, no need to stop as no merge is needed
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
print('Stopping at incomplete symbolic shape inference at ' + print('Stopping at incomplete shape inference at ' +
node.op_type + ': ' + node.name) node.op_type + ': ' + node.name)
print('node inputs:') print('node inputs:')
for i in node.input: for i in node.input:
...@@ -1555,31 +1569,32 @@ class SymbolicShapeInference: ...@@ -1555,31 +1569,32 @@ class SymbolicShapeInference:
if self.auto_merge_ and not out_type_undefined: if self.auto_merge_ and not out_type_undefined:
print('Merging: ' + str(self.suggested_merge_)) print('Merging: ' + str(self.suggested_merge_))
return False return False
self.run_ = False self.run_ = False
return True return True
def _update_output_from_vi(self): def _update_output_from_vi(self):
for output in self.out_mp_.graph.output: for output in self.out_mp_.graph.output:
if output.name in self.known_vi_: if output.name in self.known_vi_:
tmp_output = self.known_vi_[output.name] output.CopyFrom(self.known_vi_[output.name])
output.CopyFrom(tmp_output)
@staticmethod @staticmethod
def infer_shapes(in_mp, def infer_shapes(in_mp,
int_max=2**31 - 1,
fixed_input_shape=None, fixed_input_shape=None,
auto_merge=True, int_max=2**31 - 1,
auto_merge=False,
guess_output_rank=False, guess_output_rank=False,
verbose=0): verbose=0):
if get_opset(in_mp) < 7: assert version.parse(onnx.__version__) >= version.parse("1.5.0")
print('Only support shape inferencing models of opset 7 and above.') onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7:
print('Only support models of onnx opset 7 and above.')
return return
symbolic_shape_inference = SymbolicShapeInference( symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose) int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False all_shapes_inferred = False
symbolic_shape_inference._preprocess( symbolic_shape_inference._preprocess(
in_mp, input_shapes=fixed_input_shape) in_mp, input_shapes=fixed_input_shape)
try: try:
while symbolic_shape_inference.run_: while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl( all_shapes_inferred = symbolic_shape_inference._infer_impl(
...@@ -1592,5 +1607,5 @@ class SymbolicShapeInference: ...@@ -1592,5 +1607,5 @@ class SymbolicShapeInference:
except: except:
print('Stopping at incomplete symbolic shape inference') print('Stopping at incomplete symbolic shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes( symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
in_mp) symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph return symbolic_shape_inference.out_mp_.graph
...@@ -57,6 +57,7 @@ def _is_static_shape(shape): ...@@ -57,6 +57,7 @@ def _is_static_shape(shape):
return False return False
return True return True
def _get_same_padding(in_size, kernel_size, stride): def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size pad_size = (new_size - 1) * stride + kernel_size - in_size
...@@ -348,6 +349,7 @@ class OpSet9(): ...@@ -348,6 +349,7 @@ class OpSet9():
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear' 'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
) )
fluid_op = 'resize_bilinear' fluid_op = 'resize_bilinear'
attr['align_corners'] = False
node.fluid_code.add_layer( node.fluid_code.add_layer(
fluid_op, inputs=inputs, output=node, param_attr=attr) fluid_op, inputs=inputs, output=node, param_attr=attr)
...@@ -736,53 +738,59 @@ class OpSet9(): ...@@ -736,53 +738,59 @@ class OpSet9():
param_attr=None) param_attr=None)
else: else:
input_inner_indices = node.layer_name + '_input_inner_indices' input_inner_indices = node.layer_name + '_input_inner_indices'
shape = val_x.out_shapes[0]
node.fluid_code.add_layer(
'reshape',
inputs=indices.layer_name,
output=indices.layer_name,
param_attr={'shape': indices.out_shapes[0]})
zeros_like_val_x = val_x.layer_name + '_zeros'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'scatter_nd', 'zeros_like',
inputs=val_x,
output=zeros_like_val_x,
param_attr=None)
node.fluid_code.add_layer(
'scatter_nd_add',
inputs={ inputs={
'shape': val_x.out_shapes[0], 'ref': zeros_like_val_x,
'index': indices, 'index': indices,
'updates': updates 'updates': updates
}, },
output=input_inner_indices, output=input_inner_indices,
param_attr=None) param_attr=None)
indices_mask = node.layer_name + '_indices_mask'
constant_minus_one = node.layer_name + '_constant_minus_one' constant_minus_one = node.layer_name + '_constant_minus_one'
# full_like support create tensor shape like input tensor
node.fluid_code.add_layer( node.fluid_code.add_layer(
'fill_constant', 'full_like',
inputs=None, inputs=updates,
output=constant_minus_one, output=constant_minus_one,
param_attr={ param_attr={'dtype': string(updates.dtype),
'shape': updates.out_shapes[0], 'fill_value': -1})
'dtype': string(updates.dtype),
'value': -1
})
indices_mask = node.layer_name + '_indices_mask'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'scatter_nd', 'scatter_nd_add',
inputs={ inputs={
'shape': val_x.out_shapes[0], 'ref': zeros_like_val_x,
'index': indices, 'index': indices,
'updates': constant_minus_one 'updates': constant_minus_one
}, },
output=indices_mask, output=indices_mask,
param_attr=None) param_attr=None)
constant_one = node.layer_name + '_constant_1'
constant_1 = node.layer_name + '_constant_1' # full_like support create tensor shape like input tensor
node.fluid_code.add_layer( node.fluid_code.add_layer(
'fill_constant', 'full_like',
inputs=None, inputs=val_x,
output=constant_1, output=constant_one,
param_attr={ param_attr={'dtype': string(val_x.dtype),
'shape': val_x.out_shapes[0], 'fill_value': 1})
'dtype': string(val_x.dtype),
'value': 1
})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask' input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_add", "elementwise_add",
inputs={"x": indices_mask, inputs={"x": indices_mask,
"y": constant_1}, "y": constant_one},
output=input_out_indices_mask, output=input_out_indices_mask,
param_attr=None) param_attr=None)
...@@ -841,11 +849,15 @@ class OpSet9(): ...@@ -841,11 +849,15 @@ class OpSet9():
self.omit_nodes.append(ends.layer_name) self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy() starts_value = starts_value.copy()
ends_value = ends_value.copy() ends_value = ends_value.copy()
#for idx in range(len(ends_value)):
# if ends_value[idx] > 2**31 - 1:
# ends_value[idx] = 2**31 - 1
#print(val_x.out_shapes)
for idx in range(len(ends_value)): for idx in range(len(ends_value)):
if starts_value[idx] > val_x.out_shapes[0][axes[idx]]: if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]] ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]]-1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] > 2**31 - 1: elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1 ends_value[idx] = 2**31 - 1
attr = { attr = {
...@@ -882,10 +894,10 @@ class OpSet9(): ...@@ -882,10 +894,10 @@ class OpSet9():
if steps is not None: if steps is not None:
attr['strides'] = steps attr['strides'] = steps
node.fluid_code.add_layer( node.fluid_code.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr) 'strided_slice', inputs=val_x, output=node, param_attr=attr)
else: else:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr) 'slice', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def ConstantOfShape(self, node): def ConstantOfShape(self, node):
...@@ -928,15 +940,12 @@ class OpSet9(): ...@@ -928,15 +940,12 @@ class OpSet9():
min_value = _const_weight_or_none(min_ipt) min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name) self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name) self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1,): if max_value.shape == (1, ):
max_value = max_value[0] max_value = max_value[0]
if min_value.shape == (1,): if min_value.shape == (1, ):
min_value = min_value[0] min_value = min_value[0]
if max_value is not None and min_value is not None: if max_value is not None and min_value is not None:
attr = { attr = {'max': max_value, 'min': min_value}
'max': max_value,
'min': min_value
}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr) 'clip', inputs=val_x, output=node, param_attr=attr)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册