diff --git a/mace/python/tools/tf_dsp_converter_lib.py b/mace/python/tools/tf_dsp_converter_lib.py
index 65369b4625a9ead2c176921d661844cbb731a3d8..493544da30cd95c0a10940db1f4b850f5988d824 100644
--- a/mace/python/tools/tf_dsp_converter_lib.py
+++ b/mace/python/tools/tf_dsp_converter_lib.py
@@ -95,12 +95,19 @@ def add_shape_const_node(net_def, op, values, name):
 
 
 def convert_op_outputs(mace_op_def, tf_op):
+    mace_op_def.out_max_byte_size.extend(
+        [max_elem_size(output) for output in tf_op.outputs])
     mace_op_def.output_type.extend(
         [tf_dtype_2_mace_dtype(output.dtype) for output in tf_op.outputs])
     output_shapes = []
     for output in tf_op.outputs:
         output_shape = mace_pb2.OutputShape()
-        output_shape.dims.extend(output.shape.as_list())
+        shape_list = output.shape.as_list()
+        if not shape_list:
+            shape_list = [1]
+        elif len(shape_list) == 2:
+            shape_list = [1, 1, shape_list[0], shape_list[1]]
+        output_shape.dims.extend(shape_list)
         output_shapes.append(output_shape)
 
     mace_op_def.output_shape.extend(output_shapes)
@@ -159,8 +166,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             op_def.input.append(input_tensor.name)
             op_def.input.extend([t.name for t in s2b_op.inputs[1:]])
             op_def.input.extend([min_tensor.name, max_tensor.name])
-            op_def.out_max_byte_size.extend(
-                [max_elem_size(out) for out in quantize_op.outputs])
             convert_op_outputs(op_def, quantize_op)
         elif len(first_op.outputs) > 0 and \
                 first_op.type == 'QuantizedReshape' and \
@@ -193,9 +198,71 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax')
             op_def.input.extend(
                 [input_tensor.name, min_tensor.name, max_tensor.name])
-            op_def.out_max_byte_size.extend(
-                [max_elem_size(out) for out in quantize_reshape_op.outputs])
             convert_op_outputs(op_def, quantize_reshape_op)
+        # remove Squeeze
+        elif len(first_op.outputs) > 0 and \
+                first_op.type == 'Requantize' and \
+                len(first_op.outputs[0].consumers()) > 0 and \
+                first_op.outputs[0].consumers()[0].type == 'Dequantize' and \
+                len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) \
+                > 0 and \
+                first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type \
+                == 'Squeeze':
+            dequantize_op = first_op.outputs[0].consumers()[0]
+            squeeze_op = dequantize_op.outputs[0].consumers()[0]
+            reshape_op = squeeze_op.outputs[0].consumers()[0]
+            min_op = reshape_op.outputs[0].consumers()[0]
+            max_op = reshape_op.outputs[0].consumers()[1]
+            quantize_op = min_op.outputs[0].consumers()[0]
+
+            resolved_ops.add(dequantize_op.name)
+            resolved_ops.add(squeeze_op.name)
+            resolved_ops.add(reshape_op.name)
+            resolved_ops.add(min_op.name)
+            resolved_ops.add(max_op.name)
+            resolved_ops.add(quantize_op.name)
+
+            op_def.name = quantize_op.name
+            op_def.input.extend([t.name for t in first_op.inputs])
+            convert_op_outputs(op_def, quantize_op)
+
+            # Squeeze -> Softmax
+            next_op = quantize_op.outputs[0].consumers()[0] \
+                if len(quantize_op.outputs) > 0 else None
+            dequantize_op = next_op.outputs[0].consumers()[0] \
+                if next_op and len(next_op.outputs) > 0 and \
+                next_op.type == 'QuantizedReshape' and \
+                len(next_op.outputs[0].consumers()) > 0 else None
+            softmax_op = dequantize_op.outputs[0].consumers()[0]\
+                if dequantize_op and len(dequantize_op.outputs) > 0 and \
+                dequantize_op.type == 'Dequantize' and \
+                len(dequantize_op.outputs[0].consumers()) > 0 else None
+            if softmax_op and softmax_op.type == 'Softmax':
+                reshape_op = softmax_op.outputs[0].consumers()[0]
+                min_op = reshape_op.outputs[0].consumers()[0]
+                max_op = reshape_op.outputs[0].consumers()[1]
+                quantize_op = min_op.outputs[0].consumers()[0]
+                quantize_reshape_op = quantize_op.outputs[0].consumers()[0]
+
+                resolved_ops.add(next_op.name)
+                resolved_ops.add(dequantize_op.name)
+                resolved_ops.add(softmax_op.name)
+                resolved_ops.add(reshape_op.name)
+                resolved_ops.add(min_op.name)
+                resolved_ops.add(max_op.name)
+                resolved_ops.add(quantize_op.name)
+                resolved_ops.add(quantize_reshape_op.name)
+
+                softmax_op_def = net_def.op.add()
+                softmax_op_def.padding = padding_mode['NA']
+                softmax_op_def.name = quantize_reshape_op.name
+                softmax_op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax')
+                softmax_op_def.input.extend([
+                    get_tensor_name_from_op(op_def.name, 0),
+                    get_tensor_name_from_op(op_def.name, 1),
+                    get_tensor_name_from_op(op_def.name, 2)])
+                convert_op_outputs(softmax_op_def, quantize_reshape_op)
+
         elif len(first_op.outputs) > 0 and first_op.type == 'Dequantize' and \
                 len(first_op.outputs[0].consumers()) > 0 and \
                 first_op.outputs[0].consumers()[0].type == 'Tanh':
@@ -220,8 +287,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
                 op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type)
                 op_def.input.extend(
                     [input_tensor.name, min_tensor.name, max_tensor.name])
-                op_def.out_max_byte_size.extend(
-                    [max_elem_size(out) for out in quantize_op.outputs])
                 convert_op_outputs(op_def, quantize_op)
             # tanh is last op
             else:
@@ -251,8 +316,6 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
                     get_tensor_name_from_op(op_def.name, 1),
                     get_tensor_name_from_op(op_def.name, 2)
                 ])
-                new_tanh_op_def.out_max_byte_size.extend(
-                    [max_elem_size(tanh_op.outputs[0])])
                 convert_op_outputs(new_tanh_op_def, tanh_op)
         elif has_padding_and_strides(first_op):
             op_def.padding = padding_mode[first_op.get_attr('padding')]
@@ -266,19 +329,13 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             strides_tensor = add_shape_const_node(net_def, first_op,
                                                   strides, 'strides')
             op_def.input.extend([strides_tensor])
-            op_def.out_max_byte_size.extend(
-                [max_elem_size(out) for out in first_op.outputs])
             convert_op_outputs(op_def, first_op)
         elif is_node_flatten_reshape(first_op):
             op_def.type = 'Flatten'
-            op_def.input.extend([t.name for t in first_op.inputs])
-            op_def.out_max_byte_size.extend(
-                [max_elem_size(out) for out in first_op.outputs])
+            op_def.input.extend([first_op.inputs[0].name])
             convert_op_outputs(op_def, first_op)
         elif dsp_ops.has_op(first_op.type):
            op_def.input.extend([t.name for t in first_op.inputs])
-            op_def.out_max_byte_size.extend(
-                [max_elem_size(out) for out in first_op.outputs])
             convert_op_outputs(op_def, first_op)
         else:
             raise Exception('Unsupported op: ', first_op)
@@ -478,7 +535,8 @@ def fuse_quantize(net_def, input_node, output_node):
         skip_ops = skip_ops.union(
             [flatten_op.name, minf_op.name, maxf_op.name])
         skip_tensors = skip_tensors.union(
-            [flatten_op.input[1], minf_op.input[1], maxf_op.input[1]])
+            [minf_op.input[0], maxf_op.input[0],
+             quantize_op.input[1], quantize_op.input[2]])
         quantize_op.type = 'AutoQuantize'
         del quantize_op.input[1:]
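
Side note on the convert_op_outputs change above: the patch normalizes low-rank output shapes before writing OutputShape.dims, apparently so the DSP side always receives rank-1 or rank-4 dims. A minimal standalone sketch of that rule follows; the normalize_shape helper name is illustrative only and does not exist in the patch, which applies the same logic inline.

    def normalize_shape(shape_list):
        # Scalar outputs (empty shape) get a single unit dimension.
        if not shape_list:
            return [1]
        # Rank-2 shapes [a, b] are padded to the 4-D form [1, 1, a, b].
        if len(shape_list) == 2:
            return [1, 1, shape_list[0], shape_list[1]]
        # All other ranks pass through unchanged.
        return shape_list

    assert normalize_shape([]) == [1]
    assert normalize_shape([128, 10]) == [1, 1, 128, 10]
    assert normalize_shape([1, 224, 224, 3]) == [1, 224, 224, 3]

The same hunk also moves the out_max_byte_size computation into convert_op_outputs, which is why every call site in convert_ops drops its own out_max_byte_size.extend(...) lines in the hunks above.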