提交 135eb45e 编写于 作者: C Channingss

update

上级 a99c32db
......@@ -168,8 +168,10 @@ class ONNXGraph(Graph):
print('shape inferencing ...')
infered_graph = SymbolicShapeInference.infer_shapes(
self.model, fixed_input_shape=self.fixed_input_shape)
#infered_graph = None
if infered_graph is None:
infered_model = shape_inference.infer_shapes(self.model)
onnx.save(infered_model, 'infered_model.onnx')
self.graph = infered_model.graph
else:
self.graph = infered_graph
......@@ -196,15 +198,21 @@ class ONNXGraph(Graph):
except:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: ")
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
#if shape.count("-1") > 1:
# print("Only 1 dimension can be -1, type again:)")
#else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
return
#shape = [int(dim) for dim in shape.strip().split(',')]
shape_ = []
for dim in shape.strip().split(','):
if dim.isdigit():
shape_.append(int(dim))
else:
shape_.append(dim)
#assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape_
def get_place_holder_nodes(self):
"""
......
......@@ -151,7 +151,7 @@ class SymbolicShapeInference:
'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
#'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap
}
self.run_ = True
......@@ -731,14 +731,16 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
#def _infer_Transpose(self, node):
# input_shape = self._get_shape(node, 0)
# perm = get_attribute(node, 'perm')
# output_shape = np.array(input_shape)[perm].tolist()
# print(input_shape)
# print(out_shape)
# vi = self.known_vi_[node.output[0]]
# vi.CopyFrom(
# helper.make_tensor_value_info(node.output[0], self.known_vi_[
# node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
......@@ -856,21 +858,28 @@ class SymbolicShapeInference:
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, new_shape))
#if node.output[0] == '173':
# print('yyyy')
if node.input[0] in self.sympy_data_:
assert 0 == get_attribute(node, 'axis',
0) # only handle 1D sympy compute
idx = self._get_value(node, 1)
data = self.sympy_data_[node.input[0]]
print(data)
print(node.output[0])
if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[0]] = [
data[int(i)] for i in idx
]
else:
print(node.output[0], 'else')
print(int(idx))
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
assert idx == 0
self.sympy_data_[node.output[0]] = data
def _infer_GatherElements(self, node):
indices_shape = self._get_shape(node, 1)
......@@ -1419,7 +1428,7 @@ class SymbolicShapeInference:
if self.verbose_ > 2:
print(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input):
print(' Input {}: {} {}€5€5€5€5€5'.format(
print(' Input {}: {} {}'.format(
i, name, 'initializer'
if name in self.initializers_ else ''))
......@@ -1458,7 +1467,6 @@ class SymbolicShapeInference:
if node.output[i_o] in self.sympy_data_:
print(' Sympy Data: ' + str(self.sympy_data_[
node.output[i_o]]))
if None in out_shape or out_type_undefined:
if self.auto_merge_:
if node.op_type in [
......@@ -1570,7 +1578,7 @@ class SymbolicShapeInference:
fixed_input_shape=None,
auto_merge=True,
guess_output_rank=False,
verbose=0):
verbose=3):
if get_opset(in_mp) < 7:
print('Only support shape inferencing models of opset 7 and above.')
return
......@@ -1578,16 +1586,16 @@ class SymbolicShapeInference:
int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(
in_mp, input_shapes=fixed_input_shape)
in_mp, input_shapes=fixed_input_shape)
try:
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
in_mp)
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
#if not all_shapes_inferred:
# symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
# symbolic_shape_inference.out_mp_)
onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except:
return None
pass
return symbolic_shape_inference.out_mp_.graph
......@@ -257,7 +257,7 @@ class OpSet9():
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
shape[i] = -1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = {
......@@ -1142,19 +1142,21 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
inputs=val_y,
output=y_squeeze,
param_attr={'axes': [0]})
inputs['y'] = y_squeeze
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
else:
node.fluid_code.add_layer(
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
#if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
# y_squeeze = val_y.layer_name + '_squeeze'
# node.fluid_code.add_layer(
# "squeeze",
# inputs=val_y,
# output=y_squeeze,
# param_attr={'axes': [0]})
# inputs['y'] = y_squeeze
# node.fluid_code.add_layer(
# "matmul", inputs=inputs, output=node, param_attr=None)
#else:
# node.fluid_code.add_layer(
# "matmul", inputs=inputs, output=node, param_attr=None)
@print_mapping_info
def BatchNormalization(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册