提交 a99c32db 编写于 作者: C Channingss

update

上级 7b73b705
......@@ -183,10 +183,10 @@ class ONNXGraph(Graph):
return False
return True
def fix_unkown_input_shape(self, vi):
def fix_input_shape(self, vi):
shape = self.get_symbolic_shape(vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
"Input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
......@@ -200,8 +200,8 @@ class ONNXGraph(Graph):
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
......@@ -214,7 +214,7 @@ class ONNXGraph(Graph):
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
if self.define_input_shape:
self.check_input_shape(ipt_vi)
self.fix_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册