提交 4c535c33 编写于 作者: W wuchenghui

fix dsp converter

上级 1c874957
...@@ -95,12 +95,19 @@ def add_shape_const_node(net_def, op, values, name): ...@@ -95,12 +95,19 @@ def add_shape_const_node(net_def, op, values, name):
def convert_op_outputs(mace_op_def, tf_op): def convert_op_outputs(mace_op_def, tf_op):
mace_op_def.out_max_byte_size.extend(
[max_elem_size(output) for output in tf_op.outputs])
mace_op_def.output_type.extend( mace_op_def.output_type.extend(
[tf_dtype_2_mace_dtype(output.dtype) for output in tf_op.outputs]) [tf_dtype_2_mace_dtype(output.dtype) for output in tf_op.outputs])
output_shapes = [] output_shapes = []
for output in tf_op.outputs: for output in tf_op.outputs:
output_shape = mace_pb2.OutputShape() output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list()) shape_list = output.shape.as_list()
if not shape_list:
shape_list = [1]
elif len(shape_list) == 2:
shape_list = [1, 1, shape_list[0], shape_list[1]]
output_shape.dims.extend(shape_list)
output_shapes.append(output_shape) output_shapes.append(output_shape)
mace_op_def.output_shape.extend(output_shapes) mace_op_def.output_shape.extend(output_shapes)
...@@ -159,8 +166,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): ...@@ -159,8 +166,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
op_def.input.append(input_tensor.name) op_def.input.append(input_tensor.name)
op_def.input.extend([t.name for t in s2b_op.inputs[1:]]) op_def.input.extend([t.name for t in s2b_op.inputs[1:]])
op_def.input.extend([min_tensor.name, max_tensor.name]) op_def.input.extend([min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in quantize_op.outputs])
convert_op_outputs(op_def, quantize_op) convert_op_outputs(op_def, quantize_op)
elif len(first_op.outputs) > 0 and \ elif len(first_op.outputs) > 0 and \
first_op.type == 'QuantizedReshape' and \ first_op.type == 'QuantizedReshape' and \
...@@ -193,9 +198,71 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): ...@@ -193,9 +198,71 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax') op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax')
op_def.input.extend( op_def.input.extend(
[input_tensor.name, min_tensor.name, max_tensor.name]) [input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in quantize_reshape_op.outputs])
convert_op_outputs(op_def, quantize_reshape_op) convert_op_outputs(op_def, quantize_reshape_op)
# remove Squeeze
elif len(first_op.outputs) > 0 and \
first_op.type == 'Requantize' and \
len(first_op.outputs[0].consumers()) > 0 and \
first_op.outputs[0].consumers()[0].type == 'Dequantize' and \
len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) \
> 0 and \
first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type \
== 'Squeeze':
dequantize_op = first_op.outputs[0].consumers()[0]
squeeze_op = dequantize_op.outputs[0].consumers()[0]
reshape_op = squeeze_op.outputs[0].consumers()[0]
min_op = reshape_op.outputs[0].consumers()[0]
max_op = reshape_op.outputs[0].consumers()[1]
quantize_op = min_op.outputs[0].consumers()[0]
resolved_ops.add(dequantize_op.name)
resolved_ops.add(squeeze_op.name)
resolved_ops.add(reshape_op.name)
resolved_ops.add(min_op.name)
resolved_ops.add(max_op.name)
resolved_ops.add(quantize_op.name)
op_def.name = quantize_op.name
op_def.input.extend([t.name for t in first_op.inputs])
convert_op_outputs(op_def, quantize_op)
# Squeeze -> Softmax
next_op = quantize_op.outputs[0].consumers()[0] \
if len(quantize_op.outputs) > 0 else None
dequantize_op = next_op.outputs[0].consumers()[0] \
if next_op and len(next_op.outputs) > 0 and \
next_op.type == 'QuantizedReshape' and \
len(next_op.outputs[0].consumers()) > 0 else None
softmax_op = dequantize_op.outputs[0].consumers()[0]\
if dequantize_op and len(dequantize_op.outputs) > 0 and \
dequantize_op.type == 'Dequantize' and \
len(dequantize_op.outputs[0].consumers()) > 0 else None
if softmax_op and softmax_op.type == 'Softmax':
reshape_op = softmax_op.outputs[0].consumers()[0]
min_op = reshape_op.outputs[0].consumers()[0]
max_op = reshape_op.outputs[0].consumers()[1]
quantize_op = min_op.outputs[0].consumers()[0]
quantize_reshape_op = quantize_op.outputs[0].consumers()[0]
resolved_ops.add(next_op.name)
resolved_ops.add(dequantize_op.name)
resolved_ops.add(softmax_op.name)
resolved_ops.add(reshape_op.name)
resolved_ops.add(min_op.name)
resolved_ops.add(max_op.name)
resolved_ops.add(quantize_op.name)
resolved_ops.add(quantize_reshape_op.name)
softmax_op_def = net_def.op.add()
softmax_op_def.padding = padding_mode['NA']
softmax_op_def.name = quantize_reshape_op.name
softmax_op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax')
softmax_op_def.input.extend([
get_tensor_name_from_op(op_def.name, 0),
get_tensor_name_from_op(op_def.name, 1),
get_tensor_name_from_op(op_def.name, 2)])
convert_op_outputs(softmax_op_def, quantize_reshape_op)
elif len(first_op.outputs) > 0 and first_op.type == 'Dequantize' and \ elif len(first_op.outputs) > 0 and first_op.type == 'Dequantize' and \
len(first_op.outputs[0].consumers()) > 0 and \ len(first_op.outputs[0].consumers()) > 0 and \
first_op.outputs[0].consumers()[0].type == 'Tanh': first_op.outputs[0].consumers()[0].type == 'Tanh':
...@@ -220,8 +287,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): ...@@ -220,8 +287,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type) op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type)
op_def.input.extend( op_def.input.extend(
[input_tensor.name, min_tensor.name, max_tensor.name]) [input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in quantize_op.outputs])
convert_op_outputs(op_def, quantize_op) convert_op_outputs(op_def, quantize_op)
# tanh is last op # tanh is last op
else: else:
...@@ -251,8 +316,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): ...@@ -251,8 +316,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
get_tensor_name_from_op(op_def.name, 1), get_tensor_name_from_op(op_def.name, 1),
get_tensor_name_from_op(op_def.name, 2) get_tensor_name_from_op(op_def.name, 2)
]) ])
new_tanh_op_def.out_max_byte_size.extend(
[max_elem_size(tanh_op.outputs[0])])
convert_op_outputs(new_tanh_op_def, tanh_op) convert_op_outputs(new_tanh_op_def, tanh_op)
elif has_padding_and_strides(first_op): elif has_padding_and_strides(first_op):
op_def.padding = padding_mode[first_op.get_attr('padding')] op_def.padding = padding_mode[first_op.get_attr('padding')]
...@@ -266,19 +329,13 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): ...@@ -266,19 +329,13 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
strides_tensor = add_shape_const_node(net_def, first_op, strides, strides_tensor = add_shape_const_node(net_def, first_op, strides,
'strides') 'strides')
op_def.input.extend([strides_tensor]) op_def.input.extend([strides_tensor])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in first_op.outputs])
convert_op_outputs(op_def, first_op) convert_op_outputs(op_def, first_op)
elif is_node_flatten_reshape(first_op): elif is_node_flatten_reshape(first_op):
op_def.type = 'Flatten' op_def.type = 'Flatten'
op_def.input.extend([t.name for t in first_op.inputs]) op_def.input.extend([first_op.inputs[0].name])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in first_op.outputs])
convert_op_outputs(op_def, first_op) convert_op_outputs(op_def, first_op)
elif dsp_ops.has_op(first_op.type): elif dsp_ops.has_op(first_op.type):
op_def.input.extend([t.name for t in first_op.inputs]) op_def.input.extend([t.name for t in first_op.inputs])
op_def.out_max_byte_size.extend(
[max_elem_size(out) for out in first_op.outputs])
convert_op_outputs(op_def, first_op) convert_op_outputs(op_def, first_op)
else: else:
raise Exception('Unsupported op: ', first_op) raise Exception('Unsupported op: ', first_op)
...@@ -478,7 +535,8 @@ def fuse_quantize(net_def, input_node, output_node): ...@@ -478,7 +535,8 @@ def fuse_quantize(net_def, input_node, output_node):
skip_ops = skip_ops.union( skip_ops = skip_ops.union(
[flatten_op.name, minf_op.name, maxf_op.name]) [flatten_op.name, minf_op.name, maxf_op.name])
skip_tensors = skip_tensors.union( skip_tensors = skip_tensors.union(
[flatten_op.input[1], minf_op.input[1], maxf_op.input[1]]) [minf_op.input[0], maxf_op.input[0],
quantize_op.input[1], quantize_op.input[2]])
quantize_op.type = 'AutoQuantize' quantize_op.type = 'AutoQuantize'
del quantize_op.input[1:] del quantize_op.input[1:]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册