Commit 9fd509dc authored by wuchenghui

fix graph convert

Parent 48a038ca
@@ -77,7 +77,7 @@ message OperatorDef {
   optional string name = 3;
   optional string type = 4;
   repeated Argument arg = 5;
-  optional OutputShape output_shape = 6;
+  repeated OutputShape output_shape = 6;
   // Memory optimization: only support one single output op
   optional int32 mem_id = 10 [default = -1];
...
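Making output_shape a repeated field lets an OperatorDef carry one shape per output instead of at most one overall, which is what the converter changes below rely on. A minimal sketch of reading the field back, assuming mace_pb2 is the generated protobuf module the converter already imports and net_def is a NetDef message:

def print_output_shapes(net_def):
  # Each op now carries one OutputShape entry per output (repeated field),
  # populated by the converter via op_def.output_shape.extend(...).
  for op_def in net_def.op:
    for i, output_shape in enumerate(op_def.output_shape):
      print('%s output %d: %s' % (op_def.name, i, list(output_shape.dims)))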
@@ -18,15 +18,6 @@ def convert_tensor(op, tensor):
   tensor.name = op.outputs[0].name
   shape = list(tf_tensor.shape)
-  if (op.name.find('pointwise_kernel') != -1 or
-      op.name.find('depthwise_kernel') != -1 or
-      op.name.endswith('weights') or
-      op.name.endswith('kernel')) \
-      and op.outputs[0].consumers()[0].type.find('Conv') != -1:
-    if op.outputs[0].consumers()[0].get_attr('data_format') == 'NHWC':
-      tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
-      shape = [shape[3], shape[2], shape[0], shape[1]]
-    # print (tensor.name, shape)
   tensor.dims.extend(shape)
   tf_dt = op.get_attr('dtype')
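The deleted block transposed convolution weights from TensorFlow's HWIO filter layout to OIHW and reordered the recorded dims to match; with the converter now emitting data_format 'NHWC' (see the changes below), weights stay in the TF layout and the transpose is dropped. A small sketch of what the removed step did, using hypothetical filter sizes:

import numpy as np

hwio = np.zeros((3, 3, 16, 32))               # TF Conv2D filter: [H, W, in_channels, out_channels]
oihw = np.transpose(hwio, axes=(3, 2, 0, 1))  # removed step: reorder to [out, in, H, W]
assert oihw.shape == (32, 16, 3, 3)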
@@ -66,6 +57,12 @@ def convert_ops(unresolved_ops, net_def):
     op_def.type = first_op.type
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([output.name for output in first_op.outputs])
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
     padding_arg = op_def.arg.add()
     padding_arg.name = 'padding'
     padding_arg.i = padding_mode[first_op.get_attr('padding')]
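The six lines that build and attach OutputShape entries recur after every op emitted in this and the following hunks; a small helper (hypothetical, not part of this commit) would capture the pattern:

def add_output_shapes(op_def, outputs):
  # outputs is a list of tf.Tensor, e.g. first_op.outputs or add_1_op.outputs.
  output_shapes = []
  for output in outputs:
    output_shape = mace_pb2.OutputShape()
    output_shape.dims.extend(output.shape.as_list())
    output_shapes.append(output_shape)
  op_def.output_shape.extend(output_shapes)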
@@ -74,7 +71,7 @@ def convert_ops(unresolved_ops, net_def):
     strides_arg.ints.extend(first_op.get_attr('strides')[1:3])
     data_format_arg = op_def.arg.add()
     data_format_arg.name = 'data_format'
-    data_format_arg.s = 'NCHW'
+    data_format_arg.s = 'NHWC'
     if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
       bias_add_op = unresolved_ops[1]
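Only the advertised data_format string changes here; the existing [1:3] slices stay valid because TensorFlow's NHWC attributes put the spatial dimensions second and third. A one-line check with hypothetical values:

# TF 'strides' (and 'ksize') attrs in NHWC are [batch, height, width, channels],
# so the converter's [1:3] slice picks the spatial pair.
strides_nhwc = [1, 2, 2, 1]             # hypothetical stride-2 convolution
assert strides_nhwc[1:3] == [2, 2]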
@@ -105,6 +102,12 @@ def convert_ops(unresolved_ops, net_def):
       op_def.type = 'BatchNorm'
       op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon])
       op_def.output.extend([output.name for output in add_1_op.outputs])
+      output_shapes = []
+      for output in add_1_op.outputs:
+        output_shape = mace_pb2.OutputShape()
+        output_shape.dims.extend(output.shape.as_list())
+        output_shapes.append(output_shape)
+      op_def.output_shape.extend(output_shapes)
       resolved_count = 7
   elif first_op.type == 'Relu6':
@@ -113,6 +116,12 @@ def convert_ops(unresolved_ops, net_def):
     op_def.type = 'Relu'
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([output.name for output in first_op.outputs])
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
     max_limit_arg = op_def.arg.add()
     max_limit_arg.name = 'max_limit'
     max_limit_arg.f = 6
@@ -122,6 +131,12 @@ def convert_ops(unresolved_ops, net_def):
     op_def.type = 'Pooling'
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([output.name for output in first_op.outputs])
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
     pooling_type_arg = op_def.arg.add()
     pooling_type_arg.name = 'pooling_type'
     pooling_type_arg.i = pooling_type_mode[first_op.type]
@@ -136,21 +151,46 @@ def convert_ops(unresolved_ops, net_def):
     kernels_arg.ints.extend(first_op.get_attr('ksize')[1:3])
     data_format_arg = op_def.arg.add()
     data_format_arg.name = 'data_format'
-    data_format_arg.s = 'NCHW'
+    data_format_arg.s = 'NHWC'
   elif first_op.type == 'Add':
     op_def = net_def.op.add()
     op_def.name = first_op.name
     op_def.type = "AddN"
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([output.name for output in first_op.outputs])
-  elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND', 'BatchToSpaceND']:
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
+  elif first_op.type == 'ConcatV2':
+    op_def = net_def.op.add()
+    op_def.name = first_op.name
+    op_def.type = "Concat"
+    op_def.input.extend([input.name for input in first_op.inputs])
+    op_def.output.extend([output.name for output in first_op.outputs])
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
+  elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND',
+                         'BatchToSpaceND', 'BiasAdd', 'FusedBatchNorm']:
     op_def = net_def.op.add()
     op_def.name = first_op.name
     op_def.type = first_op.type
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([output.name for output in first_op.outputs])
+    output_shapes = []
+    for output in first_op.outputs:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+    op_def.output_shape.extend(output_shapes)
   else:
-    raise Exception('Unknown Op: ' + first_op.name)
+    raise Exception('Unknown Op: %s, type: %s' % (first_op.name, first_op.type))
     pass
   for i in range(resolved_count):
...
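The new ConcatV2 branch maps TensorFlow's ConcatV2 node to a MACE "Concat" op and forwards all of its inputs verbatim, including the concat axis tensor that ConcatV2 carries as its last input; the Exception message now also reports the op type, which makes unsupported nodes easier to identify. A hedged TF 1.x graph-mode illustration of the node this branch matches (names and shapes hypothetical):

import tensorflow as tf

a = tf.placeholder(tf.float32, [1, 8, 8, 16], name='branch_a')
b = tf.placeholder(tf.float32, [1, 8, 8, 16], name='branch_b')
c = tf.concat([a, b], axis=3)                 # graph node type is 'ConcatV2'
print(c.op.type)                              # -> ConcatV2
print([inp.name for inp in c.op.inputs])      # the axis constant appears as the last input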