From 6f5e0613dec8ca0cded06bd9634619115f80f166 Mon Sep 17 00:00:00 2001 From: liyin Date: Mon, 12 Aug 2019 15:42:58 +0800 Subject: [PATCH] Fix hexagon patch --- tools/python/transform/hexagon_converter.py | 38 +++++++++++++-------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/tools/python/transform/hexagon_converter.py b/tools/python/transform/hexagon_converter.py index ca5aff20..3cbd6f2f 100644 --- a/tools/python/transform/hexagon_converter.py +++ b/tools/python/transform/hexagon_converter.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + import copy import numpy as np from enum import Enum from operator import mul +from functools import reduce from py_proto import mace_pb2 from transform import base_converter @@ -29,8 +35,6 @@ from transform.base_converter import PoolingType from transform.base_converter import ReduceType from utils.util import mace_check -from six.moves import reduce - HexagonSupportedOps = [ 'BatchToSpaceND_8', @@ -143,18 +147,23 @@ class HexagonConverter(base_converter.ConverterInterface): return self._model + def add_port_for_tensors(self, tensors): + for i in range(len(tensors)): + if ':' not in tensors[i]: + node_name = tensors[i] + tensors[i] += ':0' + if node_name in self._quantize_activation_info: + self._quantize_activation_info[tensors[i]] = \ + self._quantize_activation_info[node_name] + def convert_ops(self): print("Convert mace graph to hexagon.") for op in self._model.op: if not self._hexagon_ops.has_op(op.type): raise Exception('Unsupported op: ', op) - for i in range(len(op.input)): - if ':' not in op.input[i]: - node_name = op.input[i] - op.input[i] += ':0' - if node_name in self._quantize_activation_info: - self._quantize_activation_info[op.input[i]] = \ - self._quantize_activation_info[node_name] + + self.add_port_for_tensors(op.input) + self.add_port_for_tensors(op.output) if op.type == MaceOp.Conv2D.name \ or op.type == MaceOp.DepthwiseConv2d.name: @@ -482,13 +491,15 @@ class HexagonConverter(base_converter.ConverterInterface): for tensor in self._model.tensors: tensor.node_id = node_id_counter node_id_counter += 1 - tensor_op, port = get_op_and_port_from_tensor(tensor.name) - node_id_map[tensor_op] = tensor.node_id + node_id_map[tensor.name] = tensor.node_id print("Hexagon op:") index = 0 for op in self._model.op: op.node_id = node_id_counter + node_id_counter += 1 + for output in op.output: + node_id_map[output] = op.node_id if op.type not in [HexagonOp.QuantizeINPUT_f_to_8, HexagonOp.DequantizeOUTPUT_8tof.name]: index_str = str(index) @@ -497,11 +508,10 @@ class HexagonConverter(base_converter.ConverterInterface): index_str = '' print('Op: %s (%s, node_id:%d, index:%s)' % (op.name, op.type, op.node_id, index_str)) - node_id_counter += 1 - node_id_map[op.name] = op.node_id for ipt in op.input: op_name, port = get_op_and_port_from_tensor(ipt) - node_id = node_id_map[op_name] + tensor_name = ipt if port == 0 else op_name + ':0' + node_id = node_id_map[tensor_name] node_input = op.node_input.add() node_input.node_id = node_id node_input.output_port = int(port) -- GitLab