diff --git a/mace/python/tools/dsp_ops.py b/mace/python/tools/dsp_ops.py index 4ce90a0b27235914188ba34232c7f5557d44ef75..8589879b6fae0d5468795d0ad466269d7d8a4e2a 100644 --- a/mace/python/tools/dsp_ops.py +++ b/mace/python/tools/dsp_ops.py @@ -22,6 +22,7 @@ class DspOps(object): 'QuantizedSpaceToBatchND': 'QuantizedSpaceToBatchND_8', 'QuantizedBatchToSpaceND': 'QuantizedBatchToSpaceND_8', 'QuantizedSoftmax': 'QuantizedSoftmax_8', + 'QuantizedTanh': 'QuantizedTanh_8', 'Min': 'Min_f', 'Max': 'Max_f', 'QuantizeV2': 'Quantize', diff --git a/mace/python/tools/tf_dsp_converter_lib.py b/mace/python/tools/tf_dsp_converter_lib.py index f53c25aa29753593ef21d670a5325a72403347da..7c2da02ab47e215eb2e5c43ec30e83498e2863da 100644 --- a/mace/python/tools/tf_dsp_converter_lib.py +++ b/mace/python/tools/tf_dsp_converter_lib.py @@ -164,6 +164,55 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops): op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name]) op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_reshape_op.outputs]) convert_op_outputs(op_def, quantize_reshape_op) + elif len(first_op.outputs) > 0 and first_op.type == 'Dequantize' \ + and len(first_op.outputs[0].consumers()) > 0 \ + and first_op.outputs[0].consumers()[0].type == 'Tanh': + input_tensor = first_op.inputs[0] + min_tensor = first_op.inputs[1] + max_tensor = first_op.inputs[2] + tanh_op = first_op.outputs[0].consumers()[0] + + # if not last op + resolved_ops.add(tanh_op.name) + if tanh_op.outputs[0].consumers(): + reshape_op = tanh_op.outputs[0].consumers()[0] + min_op = reshape_op.outputs[0].consumers()[0] + max_op = reshape_op.outputs[0].consumers()[1] + quantize_op = min_op.outputs[0].consumers()[0] + resolved_ops.add(reshape_op.name) + resolved_ops.add(min_op.name) + resolved_ops.add(max_op.name) + resolved_ops.add(quantize_op.name) + + op_def.name = quantize_op.name + op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type) + op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name]) + op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_op.outputs]) + convert_op_outputs(op_def, quantize_op) + # tanh is last op + else: + op_def.name = tanh_op.name + '/QuantizedTanh' + op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type) + op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name]) + op_def.out_max_byte_size.extend([max_elem_size(input_tensor), + max_elem_size(min_tensor), + max_elem_size(max_tensor)]) + op_def.output_type.extend([mace_pb2.DT_UINT8, mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT]) + output_shapes = [] + for output in first_op.inputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + + new_tanh_op_def = net_def.op.add() + new_tanh_op_def.name = tanh_op.name + new_tanh_op_def.type = dsp_ops.map_nn_op('Dequantize') + new_tanh_op_def.input.extend([get_tensor_name_from_op(op_def.name, 0), + get_tensor_name_from_op(op_def.name, 1), + get_tensor_name_from_op(op_def.name, 2)]) + new_tanh_op_def.out_max_byte_size.extend([max_elem_size(tanh_op.outputs[0])]) + convert_op_outputs(new_tanh_op_def, tanh_op) elif has_padding_and_strides(first_op): op_def.padding = padding_mode[first_op.get_attr('padding')] op_def.input.extend([t.name for t in first_op.inputs])