提交 adb538bf 编写于 作者: B Bin Li

Fix quantize info

上级 c1ae5dd3
...@@ -432,7 +432,7 @@ bool HexagonDSPWrapper::ExecuteGraph(const Tensor &input_tensor, ...@@ -432,7 +432,7 @@ bool HexagonDSPWrapper::ExecuteGraph(const Tensor &input_tensor,
} }
MACE_CHECK(output_bytes == output_tensor->raw_size(), MACE_CHECK(output_bytes == output_tensor->raw_size(),
"wrong output bytes inferred."); "wrong output bytes inferred.");
return res == 0; return true;
} }
bool HexagonDSPWrapper::ExecuteGraphNew( bool HexagonDSPWrapper::ExecuteGraphNew(
...@@ -495,6 +495,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew( ...@@ -495,6 +495,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
num_inputs * kNumMetaData, num_inputs * kNumMetaData,
outputs.data(), outputs.data(),
num_outputs * kNumMetaData); num_outputs * kNumMetaData);
MACE_CHECK(res == 0, "execute error");
// handle hexagon output // handle hexagon output
for (size_t i = 0; i < num_outputs; ++i) { for (size_t i = 0; i < num_outputs; ++i) {
...@@ -504,12 +505,12 @@ bool HexagonDSPWrapper::ExecuteGraphNew( ...@@ -504,12 +505,12 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
outputs[index].depth}; outputs[index].depth};
MACE_CHECK(output_shape.size() == output_info_[i].shape.size(), MACE_CHECK(output_shape.size() == output_info_[i].shape.size(),
output_shape.size(), " vs ", output_info_[i].shape.size(), output_shape.size(), " vs ", output_info_[i].shape.size(),
"wrong output shape inferred"); " wrong output shape inferred");
for (size_t j = 0; j < output_shape.size(); ++j) { for (size_t j = 0; j < output_shape.size(); ++j) {
MACE_CHECK(static_cast<index_t>(output_shape[j]) MACE_CHECK(static_cast<index_t>(output_shape[j])
== output_info_[i].shape[j], == output_info_[i].shape[j],
output_shape[j], " vs ", output_info_[i].shape[j], output_shape[j], " vs ", output_info_[i].shape[j],
"wrong output shape inferred"); " wrong output shape[", j, "] inferred");
} }
auto output_tensor = output_tensors->at(output_info_[i].name); auto output_tensor = output_tensors->at(output_info_[i].name);
MACE_CHECK(static_cast<index_t>(outputs[index].data_valid_len) MACE_CHECK(static_cast<index_t>(outputs[index].data_valid_len)
...@@ -518,7 +519,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew( ...@@ -518,7 +519,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
" wrong output bytes inferred."); " wrong output bytes inferred.");
} }
return res == 0; return true;
} }
} // namespace mace } // namespace mace
...@@ -24,6 +24,7 @@ from functools import reduce ...@@ -24,6 +24,7 @@ from functools import reduce
from python.py_proto import mace_pb2 from python.py_proto import mace_pb2
from python.utils.util import mace_check from python.utils.util import mace_check
from python.utils.util import MaceLogger
from . import base_converter from . import base_converter
from .base_converter import ConverterUtil from .base_converter import ConverterUtil
from .base_converter import DeviceType from .base_converter import DeviceType
...@@ -120,9 +121,9 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -120,9 +121,9 @@ class HexagonConverter(base_converter.ConverterInterface):
# convert op node # convert op node
self.convert_ops() self.convert_ops()
self.convert_input_output_node() model_inputs = self.convert_input_output_node()
self.add_node_id() self.add_node_id(model_inputs)
return self._model return self._model
...@@ -234,8 +235,11 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -234,8 +235,11 @@ class HexagonConverter(base_converter.ConverterInterface):
for input_node in self._option.input_nodes.values(): for input_node in self._option.input_nodes.values():
op_name = normalize_name( op_name = normalize_name(
MaceKeyword.mace_input_node_name + '_' + input_node.name) MaceKeyword.mace_input_node_name + '_' + input_node.name)
op = first_quantize_input_op \ if op_name == first_quantize_input_op.name:
if op_name == first_quantize_input_op.name else ops[op_name] op = first_quantize_input_op
quantize_input_op.name = MaceKeyword.mace_input_node_name
else:
op = ops[op_name]
mace_check(op.type == HexagonOp.QuantizeINPUT_f_to_8.name, mace_check(op.type == HexagonOp.QuantizeINPUT_f_to_8.name,
"input node type is: %s" % op.type) "input node type is: %s" % op.type)
quantize_input_op.output.extend(op.output) quantize_input_op.output.extend(op.output)
...@@ -275,7 +279,9 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -275,7 +279,9 @@ class HexagonConverter(base_converter.ConverterInterface):
dequantize_output_op.type = HexagonOp.OUTPUT.name dequantize_output_op.type = HexagonOp.OUTPUT.name
del dequantize_output_op.input[1:] del dequantize_output_op.input[1:]
def add_node_id(self): return quantize_input_op.output
def add_node_id(self, model_inputs):
node_id_counter = 0 node_id_counter = 0
node_id_map = {} node_id_map = {}
for tensor in self._model.tensors: for tensor in self._model.tensors:
...@@ -304,7 +310,11 @@ class HexagonConverter(base_converter.ConverterInterface): ...@@ -304,7 +310,11 @@ class HexagonConverter(base_converter.ConverterInterface):
node_id = node_id_map[tensor_name] node_id = node_id_map[tensor_name]
node_input = op.node_input.add() node_input = op.node_input.add()
node_input.node_id = node_id node_input.node_id = node_id
node_input.output_port = int(port) if tensor_name in model_inputs:
for i in range(len(model_inputs)):
if model_inputs[i] == tensor_name:
port += i * 3
node_input.output_port = port
def convert_ops(self): def convert_ops(self):
print("Convert mace graph to hexagon.") print("Convert mace graph to hexagon.")
......
...@@ -36,6 +36,7 @@ from .base_converter import MaceOp ...@@ -36,6 +36,7 @@ from .base_converter import MaceOp
from .base_converter import MaceKeyword from .base_converter import MaceKeyword
from .base_converter import ConverterUtil from .base_converter import ConverterUtil
from python.utils.util import mace_check from python.utils.util import mace_check
from python.utils.util import MaceLogger
from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph from tensorflow.tools.graph_transforms import TransformGraph
...@@ -1078,6 +1079,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -1078,6 +1079,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name: if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
self._skip_tensor.add(tf_op.inputs[1].name) self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name) self._skip_tensor.add(tf_op.inputs[2].name)
del op.input[1:]
def convert_cumsum(self, tf_op): def convert_cumsum(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
......
...@@ -37,6 +37,7 @@ from .base_converter import ReduceType ...@@ -37,6 +37,7 @@ from .base_converter import ReduceType
from .base_converter import TransformerRule from .base_converter import TransformerRule
from python.quantize import quantize_util from python.quantize import quantize_util
from python.utils.util import mace_check from python.utils.util import mace_check
from python.utils.util import MaceLogger
class Transformer(base_converter.ConverterInterface): class Transformer(base_converter.ConverterInterface):
...@@ -1737,6 +1738,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1737,6 +1738,7 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op: for op in net.op:
if op.type == 'FakeQuantWithMinMaxVars' or \ if op.type == 'FakeQuantWithMinMaxVars' or \
op.type == 'FakeQuantWithMinMaxArgs': op.type == 'FakeQuantWithMinMaxArgs':
if op.input[0] not in self._consts:
producer_op = self._producer[op.input[0]] producer_op = self._producer[op.input[0]]
minval = ConverterUtil.get_arg(op, 'min').f minval = ConverterUtil.get_arg(op, 'min').f
maxval = ConverterUtil.get_arg(op, 'max').f maxval = ConverterUtil.get_arg(op, 'max').f
...@@ -1744,7 +1746,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1744,7 +1746,8 @@ class Transformer(base_converter.ConverterInterface):
self.add_quantize_info(producer_op, minval, maxval) self.add_quantize_info(producer_op, minval, maxval)
self._quantize_activation_info[op.input[0]] = quantize_info self._quantize_activation_info[op.input[0]] = quantize_info
# for add -> fakequant pattern # for add -> fakequant pattern
self._quantize_activation_info[op.output[0]] = quantize_info self._quantize_activation_info[op.output[0]] = \
quantize_info
print(op.input[0], op.output[0]) print(op.input[0], op.output[0])
op.type = MaceOp.Identity.name op.type = MaceOp.Identity.name
...@@ -1853,6 +1856,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1853,6 +1856,8 @@ class Transformer(base_converter.ConverterInterface):
quantize_info.scale = scale quantize_info.scale = scale
quantize_info.zero_point = zero quantize_info.zero_point = zero
self._quantize_activation_info[new_input_name] = quantize_info self._quantize_activation_info[new_input_name] = quantize_info
input_op = self._producer[input_node.name]
input_op.quantize_info.extend([quantize_info])
print("Add default quantize info for ops like Pooling, Softmax") print("Add default quantize info for ops like Pooling, Softmax")
for op in self._model.op: for op in self._model.op:
...@@ -1907,8 +1912,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1907,8 +1912,8 @@ class Transformer(base_converter.ConverterInterface):
elif (op.type == MaceOp.Eltwise.name elif (op.type == MaceOp.Eltwise.name
and not op.quantize_info and not op.quantize_info
and len(op.input) == 2 and len(op.input) == 2
and len(op.input[0]) not in self._consts and op.input[0] not in self._consts
and len(op.input[1]) not in self._consts): and op.input[1] not in self._consts):
producer_op0 = self._producer[op.input[0]] producer_op0 = self._producer[op.input[0]]
producer_op1 = self._producer[op.input[1]] producer_op1 = self._producer[op.input[1]]
if ConverterUtil.get_arg( if ConverterUtil.get_arg(
......
...@@ -65,8 +65,8 @@ class MaceLogger: ...@@ -65,8 +65,8 @@ class MaceLogger:
+ CMDColors.ENDC) + CMDColors.ENDC)
@staticmethod @staticmethod
def error(message): def error(message, level=2):
print(CMDColors.RED + 'ERROR: ' + get_frame_info() + str(message) print(CMDColors.RED + 'ERROR: ' + get_frame_info(level) + str(message)
+ CMDColors.ENDC) + CMDColors.ENDC)
exit(1) exit(1)
...@@ -76,7 +76,7 @@ def mace_check(condition, message): ...@@ -76,7 +76,7 @@ def mace_check(condition, message):
for line in traceback.format_stack(): for line in traceback.format_stack():
print(line.strip()) print(line.strip())
MaceLogger.error(message) MaceLogger.error(message, level=3)
################################ ################################
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册