diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index 0289a1912a6ef91f7ebb185c40ebe4ae5011e61a..29996a076734fbdcba7229104e2be045a7a8063d 100644 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -183,10 +183,10 @@ class ONNXGraph(Graph): return False return True - def fix_unkown_input_shape(self, vi): + def fix_input_shape(self, vi): shape = self.get_symbolic_shape(vi.type.tensor_type.shape.dim) print( - "Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape." + "Input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape." .format(vi.name, shape)) right_shape_been_input = False while not right_shape_been_input: @@ -200,8 +200,8 @@ class ONNXGraph(Graph): print("Only 1 dimension can be -1, type again:)") else: right_shape_been_input = True - if shape == 'N': - break + if shape == 'N': + break shape = [int(dim) for dim in shape.strip().split(',')] assert shape.count(-1) <= 1, "Only one dimension can be -1" self.fixed_input_shape[vi.name] = shape @@ -214,7 +214,7 @@ class ONNXGraph(Graph): for ipt_vi in self.graph.input: if ipt_vi.name not in inner_nodes: if self.define_input_shape: - self.check_input_shape(ipt_vi) + self.fix_input_shape(ipt_vi) self.place_holder_nodes.append(ipt_vi.name) def get_output_nodes(self):