diff --git a/tools/python/transform/hexagon_converter.py b/tools/python/transform/hexagon_converter.py index 0988097895ccd1a1f9d6781b612cee2c8e942c46..a83b7f10b679d0d83df6ad4f5f731ea95bf20e10 100644 --- a/tools/python/transform/hexagon_converter.py +++ b/tools/python/transform/hexagon_converter.py @@ -440,36 +440,58 @@ class HexagonConverter(base_converter.ConverterInterface): mace_check( quantize_input_op.type == HexagonOp.QuantizeINPUT_f_to_8.name, "Not started with Quantize op.") + first_quantize_input_op = copy.deepcopy(quantize_input_op) del quantize_input_op.input[:] + del quantize_input_op.output[:] + del quantize_input_op.output_shape[:] + del quantize_input_op.output_type[:] + del quantize_input_op.out_max_byte_size[:] dequantize_output_op = self._model.op[-1] mace_check(dequantize_output_op.type == HexagonOp.DequantizeOUTPUT_8tof.name, "Not ended with Dequantize op.") - dequantize_input = [input for input in dequantize_output_op.input] + last_dequantize_output_op = copy.deepcopy(dequantize_output_op) del dequantize_output_op.input[:] + del dequantize_output_op.output[:] del dequantize_output_op.output_shape[:] del dequantize_output_op.output_type[:] del dequantize_output_op.out_max_byte_size[:] + # Combine multiple inputs/outputs to one hexagon input/output node, + # in input_info/output_info order + ops = {} + for op in self._model.op: + ops[op.name] = op + for input_node in self._option.input_nodes.values(): + op_name = normalize_name( + MaceKeyword.mace_input_node_name + '_' + input_node.name) + op = first_quantize_input_op \ + if op_name == first_quantize_input_op.name else ops[op_name] + mace_check(op.type == HexagonOp.QuantizeINPUT_f_to_8.name, + "input node type is: %s" % op.type) + quantize_input_op.output.extend(op.output) + quantize_input_op.output_shape.extend(op.output_shape) + quantize_input_op.output_type.extend(op.output_type) + quantize_input_op.out_max_byte_size.extend( + op.out_max_byte_size) + for output_node in self._option.check_nodes.values(): + op_name = normalize_name(output_node.name) + op = last_dequantize_output_op \ + if op_name == last_dequantize_output_op.name else ops[op_name] + mace_check(op.type == HexagonOp.DequantizeOUTPUT_8tof.name, + "output node type is: %s" % op.type) + dequantize_output_op.input.extend(op.input) + + # Delete redundant inputs/outputs nodes index = 1 while index < len(self._model.op) - 1: op = self._model.op[index] - if op.type == HexagonOp.QuantizeINPUT_f_to_8.name: - quantize_input_op.output.extend(op.output) - quantize_input_op.output_shape.extend(op.output_shape) - quantize_input_op.output_type.extend(op.output_type) - quantize_input_op.out_max_byte_size.extend( - op.out_max_byte_size) - del self._model.op[index] - - elif op.type == HexagonOp.DequantizeOUTPUT_8tof.name: - dequantize_output_op.input.extend(op.input) + if op.type == HexagonOp.QuantizeINPUT_f_to_8.name \ + or op.type == HexagonOp.DequantizeOUTPUT_8tof.name: del self._model.op[index] - - index += 1 - # input order matters - dequantize_output_op.input.extend(dequantize_input) + else: + index += 1 if self._option.device == DeviceType.HTA.value: # replace QuantizeINPUT_f_to_8 with INPUT