diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index 280b5b4f1a27a892e243283c6798164c86d73bb7..2466cac432f6a9efd2f941c41635c7139f0b395d 100644 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -111,7 +111,7 @@ class ONNXGraphDataNode(GraphNode): if isinstance(self.layer, ValueInfoProto): values = self.layer.type.tensor_type.shape.dim out_shapes = list() - out_shapes.append([dim.dim_value for dim in values]) + out_shapes.append([-1 if dim.dim_value == 0 else dim.dim_value for dim in values]) return out_shapes else: values = self.layer.dims @@ -330,7 +330,7 @@ class ONNXGraph(Graph): 'dtype': TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], 'shape': - [dim.dim_value for dim in item.type.tensor_type.shape.dim], + [-1 if dim.dim_value == 0 else dim.dim_value for dim in item.type.tensor_type.shape.dim], 'external': False } diff --git a/x2paddle/decoder/onnx_shape_inference.py b/x2paddle/decoder/onnx_shape_inference.py index ff3fe71c32a6a435f171cc76321ce8acac3c37d3..05fba0cc7966a8aaaa6b5e90278f4eaf5640ac90 100644 --- a/x2paddle/decoder/onnx_shape_inference.py +++ b/x2paddle/decoder/onnx_shape_inference.py @@ -151,7 +151,6 @@ class SymbolicShapeInference: 'TopK': self._infer_TopK, 'Unsqueeze': self._infer_Unsqueeze, 'Where': self._infer_symbolic_compute_ops, - 'Transpose': self._infer_Transpose, 'ZipMap': self._infer_ZipMap } self.run_ = True @@ -731,15 +730,6 @@ class SymbolicShapeInference: helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0))) - def _infer_Transpose(self, node): - input_shape = self._get_shape(node, 0) - perm = get_attribute(node, 'perm') - output_shape = np.array(input_shape)[perm].tolist() - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, output_shape)) - def _infer_Compress(self, node): input_shape = self._get_shape(node, 0) # create a new symbolic dimension for Compress output diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index ef3c670a114902ed6a1a8930dc44993e3ecefe26..a23a3ce247fe2f5f83239ab0b2f5bf2780e42957 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -255,11 +255,6 @@ class OpSet9(): self.input_shapes.append(node.out_shapes[0]) shape = node.out_shapes[0] - for i, dim_shape in enumerate(shape): - if dim_shape == 0 and i == 0: - shape[i] = 1 - if dim_shape == 0 and i != 0: - assert 'shape of input is not assigned' attr = { "dtype": string(node.dtype), "shape": shape,