提交 814f4324 编写于 作者: W wuchenghui

support QuantizedTanh

上级 1c6743fc
......@@ -22,6 +22,7 @@ class DspOps(object):
'QuantizedSpaceToBatchND': 'QuantizedSpaceToBatchND_8',
'QuantizedBatchToSpaceND': 'QuantizedBatchToSpaceND_8',
'QuantizedSoftmax': 'QuantizedSoftmax_8',
'QuantizedTanh': 'QuantizedTanh_8',
'Min': 'Min_f',
'Max': 'Max_f',
'QuantizeV2': 'Quantize',
......
......@@ -164,6 +164,55 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_reshape_op.outputs])
convert_op_outputs(op_def, quantize_reshape_op)
elif len(first_op.outputs) > 0 and first_op.type == 'Dequantize' \
and len(first_op.outputs[0].consumers()) > 0 \
and first_op.outputs[0].consumers()[0].type == 'Tanh':
input_tensor = first_op.inputs[0]
min_tensor = first_op.inputs[1]
max_tensor = first_op.inputs[2]
tanh_op = first_op.outputs[0].consumers()[0]
# if not last op
resolved_ops.add(tanh_op.name)
if tanh_op.outputs[0].consumers():
reshape_op = tanh_op.outputs[0].consumers()[0]
min_op = reshape_op.outputs[0].consumers()[0]
max_op = reshape_op.outputs[0].consumers()[1]
quantize_op = min_op.outputs[0].consumers()[0]
resolved_ops.add(reshape_op.name)
resolved_ops.add(min_op.name)
resolved_ops.add(max_op.name)
resolved_ops.add(quantize_op.name)
op_def.name = quantize_op.name
op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type)
op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_op.outputs])
convert_op_outputs(op_def, quantize_op)
# tanh is last op
else:
op_def.name = tanh_op.name + '/QuantizedTanh'
op_def.type = dsp_ops.map_nn_op('Quantized' + tanh_op.type)
op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend([max_elem_size(input_tensor),
max_elem_size(min_tensor),
max_elem_size(max_tensor)])
op_def.output_type.extend([mace_pb2.DT_UINT8, mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT])
output_shapes = []
for output in first_op.inputs:
output_shape = mace_pb2.OutputShape()
output_shape.dims.extend(output.shape.as_list())
output_shapes.append(output_shape)
op_def.output_shape.extend(output_shapes)
new_tanh_op_def = net_def.op.add()
new_tanh_op_def.name = tanh_op.name
new_tanh_op_def.type = dsp_ops.map_nn_op('Dequantize')
new_tanh_op_def.input.extend([get_tensor_name_from_op(op_def.name, 0),
get_tensor_name_from_op(op_def.name, 1),
get_tensor_name_from_op(op_def.name, 2)])
new_tanh_op_def.out_max_byte_size.extend([max_elem_size(tanh_op.outputs[0])])
convert_op_outputs(new_tanh_op_def, tanh_op)
elif has_padding_and_strides(first_op):
op_def.padding = padding_mode[first_op.get_attr('padding')]
op_def.input.extend([t.name for t in first_op.inputs])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册