提交 adb538bf 编写于 作者: B Bin Li

Fix quantize info

上级 c1ae5dd3
......@@ -432,7 +432,7 @@ bool HexagonDSPWrapper::ExecuteGraph(const Tensor &input_tensor,
}
MACE_CHECK(output_bytes == output_tensor->raw_size(),
"wrong output bytes inferred.");
return res == 0;
return true;
}
bool HexagonDSPWrapper::ExecuteGraphNew(
......@@ -495,6 +495,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
num_inputs * kNumMetaData,
outputs.data(),
num_outputs * kNumMetaData);
MACE_CHECK(res == 0, "execute error");
// handle hexagon output
for (size_t i = 0; i < num_outputs; ++i) {
......@@ -504,12 +505,12 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
outputs[index].depth};
MACE_CHECK(output_shape.size() == output_info_[i].shape.size(),
output_shape.size(), " vs ", output_info_[i].shape.size(),
"wrong output shape inferred");
" wrong output shape inferred");
for (size_t j = 0; j < output_shape.size(); ++j) {
MACE_CHECK(static_cast<index_t>(output_shape[j])
== output_info_[i].shape[j],
output_shape[j], " vs ", output_info_[i].shape[j],
"wrong output shape inferred");
" wrong output shape[", j, "] inferred");
}
auto output_tensor = output_tensors->at(output_info_[i].name);
MACE_CHECK(static_cast<index_t>(outputs[index].data_valid_len)
......@@ -518,7 +519,7 @@ bool HexagonDSPWrapper::ExecuteGraphNew(
" wrong output bytes inferred.");
}
return res == 0;
return true;
}
} // namespace mace
......@@ -24,6 +24,7 @@ from functools import reduce
from python.py_proto import mace_pb2
from python.utils.util import mace_check
from python.utils.util import MaceLogger
from . import base_converter
from .base_converter import ConverterUtil
from .base_converter import DeviceType
......@@ -120,9 +121,9 @@ class HexagonConverter(base_converter.ConverterInterface):
# convert op node
self.convert_ops()
self.convert_input_output_node()
model_inputs = self.convert_input_output_node()
self.add_node_id()
self.add_node_id(model_inputs)
return self._model
......@@ -234,8 +235,11 @@ class HexagonConverter(base_converter.ConverterInterface):
for input_node in self._option.input_nodes.values():
op_name = normalize_name(
MaceKeyword.mace_input_node_name + '_' + input_node.name)
op = first_quantize_input_op \
if op_name == first_quantize_input_op.name else ops[op_name]
if op_name == first_quantize_input_op.name:
op = first_quantize_input_op
quantize_input_op.name = MaceKeyword.mace_input_node_name
else:
op = ops[op_name]
mace_check(op.type == HexagonOp.QuantizeINPUT_f_to_8.name,
"input node type is: %s" % op.type)
quantize_input_op.output.extend(op.output)
......@@ -275,7 +279,9 @@ class HexagonConverter(base_converter.ConverterInterface):
dequantize_output_op.type = HexagonOp.OUTPUT.name
del dequantize_output_op.input[1:]
def add_node_id(self):
return quantize_input_op.output
def add_node_id(self, model_inputs):
node_id_counter = 0
node_id_map = {}
for tensor in self._model.tensors:
......@@ -304,7 +310,11 @@ class HexagonConverter(base_converter.ConverterInterface):
node_id = node_id_map[tensor_name]
node_input = op.node_input.add()
node_input.node_id = node_id
node_input.output_port = int(port)
if tensor_name in model_inputs:
for i in range(len(model_inputs)):
if model_inputs[i] == tensor_name:
port += i * 3
node_input.output_port = port
def convert_ops(self):
print("Convert mace graph to hexagon.")
......
......@@ -36,6 +36,7 @@ from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
from python.utils.util import MaceLogger
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph
......@@ -1078,6 +1079,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
del op.input[1:]
def convert_cumsum(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -37,6 +37,7 @@ from .base_converter import ReduceType
from .base_converter import TransformerRule
from python.quantize import quantize_util
from python.utils.util import mace_check
from python.utils.util import MaceLogger
class Transformer(base_converter.ConverterInterface):
......@@ -1737,6 +1738,7 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if op.type == 'FakeQuantWithMinMaxVars' or \
op.type == 'FakeQuantWithMinMaxArgs':
if op.input[0] not in self._consts:
producer_op = self._producer[op.input[0]]
minval = ConverterUtil.get_arg(op, 'min').f
maxval = ConverterUtil.get_arg(op, 'max').f
......@@ -1744,7 +1746,8 @@ class Transformer(base_converter.ConverterInterface):
self.add_quantize_info(producer_op, minval, maxval)
self._quantize_activation_info[op.input[0]] = quantize_info
# for add -> fakequant pattern
self._quantize_activation_info[op.output[0]] = quantize_info
self._quantize_activation_info[op.output[0]] = \
quantize_info
print(op.input[0], op.output[0])
op.type = MaceOp.Identity.name
......@@ -1853,6 +1856,8 @@ class Transformer(base_converter.ConverterInterface):
quantize_info.scale = scale
quantize_info.zero_point = zero
self._quantize_activation_info[new_input_name] = quantize_info
input_op = self._producer[input_node.name]
input_op.quantize_info.extend([quantize_info])
print("Add default quantize info for ops like Pooling, Softmax")
for op in self._model.op:
......@@ -1907,8 +1912,8 @@ class Transformer(base_converter.ConverterInterface):
elif (op.type == MaceOp.Eltwise.name
and not op.quantize_info
and len(op.input) == 2
and len(op.input[0]) not in self._consts
and len(op.input[1]) not in self._consts):
and op.input[0] not in self._consts
and op.input[1] not in self._consts):
producer_op0 = self._producer[op.input[0]]
producer_op1 = self._producer[op.input[1]]
if ConverterUtil.get_arg(
......
......@@ -65,8 +65,8 @@ class MaceLogger:
+ CMDColors.ENDC)
@staticmethod
def error(message):
print(CMDColors.RED + 'ERROR: ' + get_frame_info() + str(message)
def error(message, level=2):
print(CMDColors.RED + 'ERROR: ' + get_frame_info(level) + str(message)
+ CMDColors.ENDC)
exit(1)
......@@ -76,7 +76,7 @@ def mace_check(condition, message):
for line in traceback.format_stack():
print(line.strip())
MaceLogger.error(message)
MaceLogger.error(message, level=3)
################################
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册