Commit 135eb45e authored by Channingss

update

Parent a99c32db
@@ -168,8 +168,10 @@ class ONNXGraph(Graph):
         print('shape inferencing ...')
         infered_graph = SymbolicShapeInference.infer_shapes(
             self.model, fixed_input_shape=self.fixed_input_shape)
+        #infered_graph = None
         if infered_graph is None:
             infered_model = shape_inference.infer_shapes(self.model)
+            onnx.save(infered_model, 'infered_model.onnx')
             self.graph = infered_model.graph
         else:
             self.graph = infered_graph
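The hunk above keeps the existing fallback: if the symbolic pass returns None, onnx's built-in shape inference is used instead. A minimal sketch of that pattern, assuming the `SymbolicShapeInference` class from this repo is importable; the `load_and_infer` helper and the model path are illustrative, not part of this commit:

```python
import onnx
from onnx import shape_inference

def load_and_infer(model_path, fixed_input_shape=None):
    # Hypothetical helper mirroring the fallback above: prefer the symbolic
    # pass, fall back to onnx.shape_inference when it returns None.
    model = onnx.load(model_path)
    inferred_graph = SymbolicShapeInference.infer_shapes(
        model, fixed_input_shape=fixed_input_shape)
    if inferred_graph is None:
        inferred_model = shape_inference.infer_shapes(model)
        return inferred_model.graph
    return inferred_graph
```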
@@ -196,15 +198,21 @@ class ONNXGraph(Graph):
                 except:
                     shape = input(
                         "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
-                if shape.count("-1") > 1:
-                    print("Only 1 dimension can be -1, type again:)")
-                else:
-                    right_shape_been_input = True
+                #if shape.count("-1") > 1:
+                #    print("Only 1 dimension can be -1, type again:)")
+                #else:
+                right_shape_been_input = True
             if shape == 'N':
-                break
-            shape = [int(dim) for dim in shape.strip().split(',')]
-            assert shape.count(-1) <= 1, "Only one dimension can be -1"
-            self.fixed_input_shape[vi.name] = shape
+                return
+            #shape = [int(dim) for dim in shape.strip().split(',')]
+            shape_ = []
+            for dim in shape.strip().split(','):
+                if dim.isdigit():
+                    shape_.append(int(dim))
+                else:
+                    shape_.append(dim)
+            #assert shape.count(-1) <= 1, "Only one dimension can be -1"
+            self.fixed_input_shape[vi.name] = shape_
 
     def get_place_holder_nodes(self):
         """
...
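The rewritten prompt handling above no longer forces every dimension to an integer: `str.isdigit()` is False for `'-1'` and for symbolic names, so those entries stay as strings. A small sketch of that parsing behaviour; the `parse_shape` helper name is illustrative only:

```python
def parse_shape(shape_str):
    """Parse a comma-separated shape string the way the loop above does:
    purely numeric dims become ints, everything else (e.g. '-1', 'N',
    'batch') stays a string, since str.isdigit() is False for them."""
    parsed = []
    for dim in shape_str.strip().split(','):
        parsed.append(int(dim) if dim.isdigit() else dim)
    return parsed

print(parse_shape("-1,3,224,224"))     # ['-1', 3, 224, 224]
print(parse_shape("batch,3,224,224"))  # ['batch', 3, 224, 224]
```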
@@ -151,7 +151,7 @@ class SymbolicShapeInference:
             'TopK': self._infer_TopK,
             'Unsqueeze': self._infer_Unsqueeze,
             'Where': self._infer_symbolic_compute_ops,
-            'Transpose': self._infer_Transpose,
+            #'Transpose': self._infer_Transpose,
             'ZipMap': self._infer_ZipMap
         }
         self.run_ = True
@@ -731,14 +731,16 @@ class SymbolicShapeInference:
             helper.make_tensor_value_info(node.output[0], output_type,
                                           self._get_shape(node, 0)))
 
-    def _infer_Transpose(self, node):
-        input_shape = self._get_shape(node, 0)
-        perm = get_attribute(node, 'perm')
-        output_shape = np.array(input_shape)[perm].tolist()
-        vi = self.known_vi_[node.output[0]]
-        vi.CopyFrom(
-            helper.make_tensor_value_info(node.output[0], self.known_vi_[
-                node.input[0]].type.tensor_type.elem_type, output_shape))
+    #def _infer_Transpose(self, node):
+    #    input_shape = self._get_shape(node, 0)
+    #    perm = get_attribute(node, 'perm')
+    #    output_shape = np.array(input_shape)[perm].tolist()
+    #    print(input_shape)
+    #    print(out_shape)
+    #    vi = self.known_vi_[node.output[0]]
+    #    vi.CopyFrom(
+    #        helper.make_tensor_value_info(node.output[0], self.known_vi_[
+    #            node.input[0]].type.tensor_type.elem_type, output_shape))
 
     def _infer_Compress(self, node):
         input_shape = self._get_shape(node, 0)
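For reference, the commented-out `_infer_Transpose` computed the output shape by permuting the input shape with the node's `perm` attribute. A minimal standalone version of that computation; `transpose_output_shape` is an illustrative name, not one used in the repo:

```python
def transpose_output_shape(input_shape, perm=None):
    # Same effect as np.array(input_shape)[perm].tolist() in the body above;
    # when perm is absent, ONNX Transpose reverses the axes by default.
    if perm is None:
        perm = list(range(len(input_shape)))[::-1]
    return [input_shape[p] for p in perm]

print(transpose_output_shape([1, 3, 224, 224], perm=[0, 2, 3, 1]))  # [1, 224, 224, 3]
```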
@@ -856,21 +858,28 @@ class SymbolicShapeInference:
         vi.CopyFrom(
             helper.make_tensor_value_info(node.output[
                 0], vi.type.tensor_type.elem_type, new_shape))
+        #if node.output[0] == '173':
+        #    print('yyyy')
         if node.input[0] in self.sympy_data_:
             assert 0 == get_attribute(node, 'axis',
                                       0)  # only handle 1D sympy compute
             idx = self._get_value(node, 1)
             data = self.sympy_data_[node.input[0]]
+            print(data)
+            print(node.output[0])
             if type(data) == list:
                 if type(idx) == np.ndarray and len(idx.shape) == 1:
                     self.sympy_data_[node.output[0]] = [
                         data[int(i)] for i in idx
                     ]
                 else:
+                    print(node.output[0], 'else')
+                    print(int(idx))
                     self.sympy_data_[node.output[0]] = data[int(idx)]
             else:
                 assert idx == 0
                 self.sympy_data_[node.output[0]] = data
 
     def _infer_GatherElements(self, node):
         indices_shape = self._get_shape(node, 1)
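The sympy-data branch above only handles axis-0 Gather on 1-D constant data, indexing a plain Python list either with a 1-D index array or with a scalar. A standalone illustration of those two cases; the `gather_1d` name is illustrative:

```python
import numpy as np

def gather_1d(data, idx):
    # 1-D index array -> select several elements; scalar -> select one,
    # mirroring the two branches handled in _infer_Gather above.
    if isinstance(idx, np.ndarray) and idx.ndim == 1:
        return [data[int(i)] for i in idx]
    return data[int(idx)]

print(gather_1d([2, 3, 5, 7], np.array([0, 2])))  # [2, 5]
print(gather_1d([2, 3, 5, 7], 3))                 # 7
```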
@@ -1419,7 +1428,7 @@ class SymbolicShapeInference:
             if self.verbose_ > 2:
                 print(node.op_type + ': ' + node.name)
                 for i, name in enumerate(node.input):
                     print('  Input {}: {} {}'.format(
                         i, name, 'initializer'
                         if name in self.initializers_ else ''))
@@ -1458,7 +1467,6 @@ class SymbolicShapeInference:
                 if node.output[i_o] in self.sympy_data_:
                     print('  Sympy Data: ' + str(self.sympy_data_[
                         node.output[i_o]]))
-
             if None in out_shape or out_type_undefined:
                 if self.auto_merge_:
                     if node.op_type in [
@@ -1570,7 +1578,7 @@ class SymbolicShapeInference:
                      fixed_input_shape=None,
                      auto_merge=True,
                      guess_output_rank=False,
-                     verbose=0):
+                     verbose=3):
         if get_opset(in_mp) < 7:
             print('Only support shape inferencing models of opset 7 and above.')
             return
@@ -1578,16 +1586,16 @@ class SymbolicShapeInference:
             int_max, auto_merge, guess_output_rank, verbose)
         all_shapes_inferred = False
         symbolic_shape_inference._preprocess(
             in_mp, input_shapes=fixed_input_shape)
         try:
             while symbolic_shape_inference.run_:
                 all_shapes_inferred = symbolic_shape_inference._infer_impl(
                     in_mp)
             symbolic_shape_inference._update_output_from_vi()
-            if not all_shapes_inferred:
-                symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
-                    symbolic_shape_inference.out_mp_)
-            #onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
+            #if not all_shapes_inferred:
+            #    symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
+            #        symbolic_shape_inference.out_mp_)
+            onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
         except:
-            return None
+            pass
         return symbolic_shape_inference.out_mp_.graph
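A hedged usage sketch of the entry point above, mirroring how `ONNXGraph` calls it in the first hunk of this commit; the model path and the input name/shape are placeholders. With the changes above, a None result now mainly indicates an early exit (e.g. opset < 7) rather than an inference failure, since the except branch falls through to the final return:

```python
import onnx
from onnx import shape_inference

model = onnx.load('model.onnx')  # placeholder path
inferred_graph = SymbolicShapeInference.infer_shapes(
    model, fixed_input_shape={'image': [-1, 3, 224, 224]})  # placeholder input
if inferred_graph is None:
    # same fallback as in ONNXGraph: use onnx's built-in inference
    inferred_graph = shape_inference.infer_shapes(model).graph
```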
@@ -257,7 +257,7 @@ class OpSet9():
         shape = node.out_shapes[0]
         for i, dim_shape in enumerate(shape):
             if dim_shape == 0 and i == 0:
-                shape[i] = 1
+                shape[i] = -1
             if dim_shape == 0 and i != 0:
                 assert 'shape of input is not assigned'
         attr = {
@@ -1142,19 +1142,21 @@ class OpSet9():
         x_shape = val_x.out_shapes[0]
         y_shape = val_y.out_shapes[0]
         inputs = {"x": val_x, "y": val_y}
-        if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
-            y_squeeze = val_y.layer_name + '_squeeze'
-            node.fluid_code.add_layer(
-                "squeeze",
-                inputs=val_y,
-                output=y_squeeze,
-                param_attr={'axes': [0]})
-            inputs['y'] = y_squeeze
-            node.fluid_code.add_layer(
-                "matmul", inputs=inputs, output=node, param_attr=None)
-        else:
-            node.fluid_code.add_layer(
-                "matmul", inputs=inputs, output=node, param_attr=None)
+        node.fluid_code.add_layer(
+            "matmul", inputs=inputs, output=node, param_attr=None)
+        #if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
+        #    y_squeeze = val_y.layer_name + '_squeeze'
+        #    node.fluid_code.add_layer(
+        #        "squeeze",
+        #        inputs=val_y,
+        #        output=y_squeeze,
+        #        param_attr={'axes': [0]})
+        #    inputs['y'] = y_squeeze
+        #    node.fluid_code.add_layer(
+        #        "matmul", inputs=inputs, output=node, param_attr=None)
+        #else:
+        #    node.fluid_code.add_layer(
+        #        "matmul", inputs=inputs, output=node, param_attr=None)
 
     @print_mapping_info
     def BatchNormalization(self, node):
...
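For context on the MatMul shapes involved in the commented-out squeeze branch: under numpy-style matmul semantics, a leading singleton dimension on `y` broadcasts against the batch dimension of `x`, so the broadcast and squeezed forms agree. This is purely a numpy illustration, not a claim about the fluid `matmul` layer:

```python
import numpy as np

x = np.random.rand(8, 4, 5)   # batched lhs: (8, 4, 5)
y = np.random.rand(1, 5, 3)   # rhs with a leading 1: (1, 5, 3)

out_broadcast = np.matmul(x, y)            # (8, 4, 3), batch dim broadcast
out_squeezed = np.matmul(x, y.squeeze(0))  # (8, 4, 3), same values
assert np.allclose(out_broadcast, out_squeezed)
```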