提交 06c5b6c9 编写于 作者: W wuchenghui

add softmax support for dsp

上级 0f0ec44d
......@@ -21,6 +21,7 @@ class DspOps(object):
'QuantizedResizeBilinear' : 'QuantizedResizeBilinear_8',
'QuantizedSpaceToBatchND': 'QuantizedSpaceToBatchND_8',
'QuantizedBatchToSpaceND': 'QuantizedBatchToSpaceND_8',
'QuantizedSoftmax': 'QuantizedSoftmax_8',
'Min': 'Min_f',
'Max': 'Max_f',
'QuantizeV2': 'Quantize',
......
......@@ -134,6 +134,35 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
op_def.input.extend([min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_op.outputs])
convert_op_outputs(op_def, quantize_op)
elif len(first_op.outputs) > 0 and first_op.type == 'QuantizedReshape' \
and len(first_op.outputs[0].consumers()) > 0 \
and first_op.outputs[0].consumers()[0].type == 'Dequantize' \
and len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) > 0 \
and first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type == 'Softmax':
input_tensor = first_op.inputs[0]
min_tensor = first_op.inputs[2]
max_tensor = first_op.inputs[3]
dequantize_op = first_op.outputs[0].consumers()[0]
softmax_op = dequantize_op.outputs[0].consumers()[0]
reshape_op = softmax_op.outputs[0].consumers()[0]
min_op = reshape_op.outputs[0].consumers()[0]
max_op = reshape_op.outputs[0].consumers()[1]
quantize_op = min_op.outputs[0].consumers()[0]
quantize_reshape_op = quantize_op.outputs[0].consumers()[0]
resolved_ops.add(dequantize_op.name)
resolved_ops.add(softmax_op.name)
resolved_ops.add(reshape_op.name)
resolved_ops.add(min_op.name)
resolved_ops.add(max_op.name)
resolved_ops.add(quantize_op.name)
resolved_ops.add(quantize_reshape_op.name)
op_def.name = quantize_reshape_op.name
op_def.type = dsp_ops.map_nn_op('QuantizedSoftmax')
op_def.input.extend([input_tensor.name, min_tensor.name, max_tensor.name])
op_def.out_max_byte_size.extend([max_elem_size(out) for out in quantize_reshape_op.outputs])
convert_op_outputs(op_def, quantize_reshape_op)
elif has_padding_and_strides(first_op):
op_def.padding = padding_mode[first_op.get_attr('padding')]
op_def.input.extend([t.name for t in first_op.inputs])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册