提交 f754d720 编写于 作者: C Channingss

rm shape_infer of Transpose

上级 c66440ea
......@@ -111,7 +111,7 @@ class ONNXGraphDataNode(GraphNode):
if isinstance(self.layer, ValueInfoProto):
values = self.layer.type.tensor_type.shape.dim
out_shapes = list()
out_shapes.append([dim.dim_value for dim in values])
out_shapes.append([-1 if dim.dim_value == 0 else dim.dim_value for dim in values])
return out_shapes
else:
values = self.layer.dims
......@@ -330,7 +330,7 @@ class ONNXGraph(Graph):
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
[-1 if dim.dim_value == 0 else dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': False
}
......
......@@ -151,7 +151,6 @@ class SymbolicShapeInference:
'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap
}
self.run_ = True
......@@ -731,15 +730,6 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output
......
......@@ -255,11 +255,6 @@ class OpSet9():
self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = {
"dtype": string(node.dtype),
"shape": shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册