提交 f754d720 编写于 作者: C Channingss

rm shape_infer of Transpose

上级 c66440ea
...@@ -111,7 +111,7 @@ class ONNXGraphDataNode(GraphNode): ...@@ -111,7 +111,7 @@ class ONNXGraphDataNode(GraphNode):
if isinstance(self.layer, ValueInfoProto): if isinstance(self.layer, ValueInfoProto):
values = self.layer.type.tensor_type.shape.dim values = self.layer.type.tensor_type.shape.dim
out_shapes = list() out_shapes = list()
out_shapes.append([dim.dim_value for dim in values]) out_shapes.append([-1 if dim.dim_value == 0 else dim.dim_value for dim in values])
return out_shapes return out_shapes
else: else:
values = self.layer.dims values = self.layer.dims
...@@ -330,7 +330,7 @@ class ONNXGraph(Graph): ...@@ -330,7 +330,7 @@ class ONNXGraph(Graph):
'dtype': 'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape': 'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim], [-1 if dim.dim_value == 0 else dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': False 'external': False
} }
......
...@@ -151,7 +151,6 @@ class SymbolicShapeInference: ...@@ -151,7 +151,6 @@ class SymbolicShapeInference:
'TopK': self._infer_TopK, 'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze, 'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops, 'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap 'ZipMap': self._infer_ZipMap
} }
self.run_ = True self.run_ = True
...@@ -731,15 +730,6 @@ class SymbolicShapeInference: ...@@ -731,15 +730,6 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], output_type, helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0))) self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node): def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0) input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output # create a new symbolic dimension for Compress output
......
...@@ -255,11 +255,6 @@ class OpSet9(): ...@@ -255,11 +255,6 @@ class OpSet9():
self.input_shapes.append(node.out_shapes[0]) self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0] shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = { attr = {
"dtype": string(node.dtype), "dtype": string(node.dtype),
"shape": shape, "shape": shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册