提交 6f5e0613 编写于 作者: L liyin

Fix hexagon patch

上级 320b509c
......@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from enum import Enum
from operator import mul
from functools import reduce
from py_proto import mace_pb2
from transform import base_converter
......@@ -29,8 +35,6 @@ from transform.base_converter import PoolingType
from transform.base_converter import ReduceType
from utils.util import mace_check
from six.moves import reduce
HexagonSupportedOps = [
'BatchToSpaceND_8',
......@@ -143,18 +147,23 @@ class HexagonConverter(base_converter.ConverterInterface):
return self._model
def add_port_for_tensors(self, tensors):
for i in range(len(tensors)):
if ':' not in tensors[i]:
node_name = tensors[i]
tensors[i] += ':0'
if node_name in self._quantize_activation_info:
self._quantize_activation_info[tensors[i]] = \
self._quantize_activation_info[node_name]
def convert_ops(self):
print("Convert mace graph to hexagon.")
for op in self._model.op:
if not self._hexagon_ops.has_op(op.type):
raise Exception('Unsupported op: ', op)
for i in range(len(op.input)):
if ':' not in op.input[i]:
node_name = op.input[i]
op.input[i] += ':0'
if node_name in self._quantize_activation_info:
self._quantize_activation_info[op.input[i]] = \
self._quantize_activation_info[node_name]
self.add_port_for_tensors(op.input)
self.add_port_for_tensors(op.output)
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
......@@ -482,13 +491,15 @@ class HexagonConverter(base_converter.ConverterInterface):
for tensor in self._model.tensors:
tensor.node_id = node_id_counter
node_id_counter += 1
tensor_op, port = get_op_and_port_from_tensor(tensor.name)
node_id_map[tensor_op] = tensor.node_id
node_id_map[tensor.name] = tensor.node_id
print("Hexagon op:")
index = 0
for op in self._model.op:
op.node_id = node_id_counter
node_id_counter += 1
for output in op.output:
node_id_map[output] = op.node_id
if op.type not in [HexagonOp.QuantizeINPUT_f_to_8,
HexagonOp.DequantizeOUTPUT_8tof.name]:
index_str = str(index)
......@@ -497,11 +508,10 @@ class HexagonConverter(base_converter.ConverterInterface):
index_str = ''
print('Op: %s (%s, node_id:%d, index:%s)' %
(op.name, op.type, op.node_id, index_str))
node_id_counter += 1
node_id_map[op.name] = op.node_id
for ipt in op.input:
op_name, port = get_op_and_port_from_tensor(ipt)
node_id = node_id_map[op_name]
tensor_name = ipt if port == 0 else op_name + ':0'
node_id = node_id_map[tensor_name]
node_input = op.node_input.add()
node_input.node_id = node_id
node_input.output_port = int(port)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册