提交 6f5e0613 编写于 作者: L liyin

Fix hexagon patch

上级 320b509c
...@@ -12,10 +12,16 @@ ...@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy import copy
import numpy as np import numpy as np
from enum import Enum from enum import Enum
from operator import mul from operator import mul
from functools import reduce
from py_proto import mace_pb2 from py_proto import mace_pb2
from transform import base_converter from transform import base_converter
...@@ -29,8 +35,6 @@ from transform.base_converter import PoolingType ...@@ -29,8 +35,6 @@ from transform.base_converter import PoolingType
from transform.base_converter import ReduceType from transform.base_converter import ReduceType
from utils.util import mace_check from utils.util import mace_check
from six.moves import reduce
HexagonSupportedOps = [ HexagonSupportedOps = [
'BatchToSpaceND_8', 'BatchToSpaceND_8',
...@@ -143,18 +147,23 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -143,18 +147,23 @@ class HexagonConverter(base_converter.ConverterInterface):
return self._model return self._model
def add_port_for_tensors(self, tensors):
for i in range(len(tensors)):
if ':' not in tensors[i]:
node_name = tensors[i]
tensors[i] += ':0'
if node_name in self._quantize_activation_info:
self._quantize_activation_info[tensors[i]] = \
self._quantize_activation_info[node_name]
def convert_ops(self): def convert_ops(self):
print("Convert mace graph to hexagon.") print("Convert mace graph to hexagon.")
for op in self._model.op: for op in self._model.op:
if not self._hexagon_ops.has_op(op.type): if not self._hexagon_ops.has_op(op.type):
raise Exception('Unsupported op: ', op) raise Exception('Unsupported op: ', op)
for i in range(len(op.input)):
if ':' not in op.input[i]: self.add_port_for_tensors(op.input)
node_name = op.input[i] self.add_port_for_tensors(op.output)
op.input[i] += ':0'
if node_name in self._quantize_activation_info:
self._quantize_activation_info[op.input[i]] = \
self._quantize_activation_info[node_name]
if op.type == MaceOp.Conv2D.name \ if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name: or op.type == MaceOp.DepthwiseConv2d.name:
...@@ -482,13 +491,15 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -482,13 +491,15 @@ class HexagonConverter(base_converter.ConverterInterface):
for tensor in self._model.tensors: for tensor in self._model.tensors:
tensor.node_id = node_id_counter tensor.node_id = node_id_counter
node_id_counter += 1 node_id_counter += 1
tensor_op, port = get_op_and_port_from_tensor(tensor.name) node_id_map[tensor.name] = tensor.node_id
node_id_map[tensor_op] = tensor.node_id
print("Hexagon op:") print("Hexagon op:")
index = 0 index = 0
for op in self._model.op: for op in self._model.op:
op.node_id = node_id_counter op.node_id = node_id_counter
node_id_counter += 1
for output in op.output:
node_id_map[output] = op.node_id
if op.type not in [HexagonOp.QuantizeINPUT_f_to_8, if op.type not in [HexagonOp.QuantizeINPUT_f_to_8,
HexagonOp.DequantizeOUTPUT_8tof.name]: HexagonOp.DequantizeOUTPUT_8tof.name]:
index_str = str(index) index_str = str(index)
...@@ -497,11 +508,10 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -497,11 +508,10 @@ class HexagonConverter(base_converter.ConverterInterface):
index_str = '' index_str = ''
print('Op: %s (%s, node_id:%d, index:%s)' % print('Op: %s (%s, node_id:%d, index:%s)' %
(op.name, op.type, op.node_id, index_str)) (op.name, op.type, op.node_id, index_str))
node_id_counter += 1
node_id_map[op.name] = op.node_id
for ipt in op.input: for ipt in op.input:
op_name, port = get_op_and_port_from_tensor(ipt) op_name, port = get_op_and_port_from_tensor(ipt)
node_id = node_id_map[op_name] tensor_name = ipt if port == 0 else op_name + ':0'
node_id = node_id_map[tensor_name]
node_input = op.node_input.add() node_input = op.node_input.add()
node_input.node_id = node_id node_input.node_id = node_id
node_input.output_port = int(port) node_input.output_port = int(port)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册