提交 7c10b286 编写于 作者: B Bin Li

Support unet for hexagon

上级 46d3bd53
......@@ -32,10 +32,20 @@ Post training quantization
---------------------------
MACE supports post-training quantization if you want to take a chance to quantize model directly without fine tuning.
This method requires developer to calculate tensor range of each activation layer statistically using sample inputs.
MACE provides tools to do statistics with following steps:
MACE provides tools to do statistics with following steps(using `inception-v3` from `MACE Model Zoo <https://github.com/XiaoMi/mace-models>`__ as an example):
1. Convert original model to run on CPU host without obfuscation (by setting `target_abis` to `host`, `runtime` to `cpu`,
and `obfuscate` to `0`, appending `:0` to `output_tensors` if missing in yaml config).
and `obfuscate` to `0`, appending `:0` to `output_tensors` if missing in yaml config).
.. code-block:: sh
# For CMake users:
python tools/python/convert.py --config ../mace-models/inception-v3/inception-v3.yml
--quantize_stat
# For Bazel users:
python tools/converter.py convert --config ../mace-models/inception-v3/inception-v3.yml
--quantize_stat
2. Log tensor range of each activation layer by inferring several samples on CPU host. Sample inputs should be
representative to calculate the ranges of each layer properly.
......
......@@ -414,6 +414,7 @@ class YAMLKeyword(object):
quantize = 'quantize'
quantize_large_weights = 'quantize_large_weights'
quantize_range_file = 'quantize_range_file'
quantize_stat = 'quantize_stat'
change_concat_ranges = 'change_concat_ranges'
validation_inputs_data = 'validation_inputs_data'
validation_threshold = 'validation_threshold'
......
......@@ -793,6 +793,9 @@ def convert_func(flags):
if os.path.exists(ENGINE_CODEGEN_DIR):
sh.rm("-rf", ENGINE_CODEGEN_DIR)
if flags.quantize_stat:
configs[YAMLKeyword.quantize_stat] = flags.quantize_stat
if flags.model_data_format:
model_data_format = flags.model_data_format
else:
......@@ -1017,6 +1020,10 @@ def parse_args():
'--address_sanitizer',
action="store_true",
help="Whether to use address sanitizer to check memory error")
convert_run_parent_parser.add_argument(
"--quantize_stat",
action="store_true",
help="whether to stat quantization range.")
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
......@@ -1121,10 +1128,6 @@ def parse_args():
type=float,
default=0.0,
help="[mock runtime failure ratio].")
run.add_argument(
"--quantize_stat",
action="store_true",
help="whether to stat quantization range.")
run.add_argument(
"--input_dir",
type=str,
......
......@@ -46,6 +46,10 @@ def transpose_shape(shape, dst_order):
def convert(conf, output):
if ModelKeys.quantize_stat in conf:
quantize_stat = conf[ModelKeys.quantize_stat]
else:
quantize_stat = False
for model_name, model_conf in conf["models"].items():
model_output = output + "/" + model_name + "/model"
org_model_dir = output + "/" + model_name + "/org_model"
......@@ -72,7 +76,7 @@ def convert(conf, output):
"", model_output)
model_conf[ModelKeys.quantize_range_file] = range_file
mace_model = convert_model(model_conf)
mace_model = convert_model(model_conf, quantize_stat)
try:
visualizer = visualize_model.ModelVisualizer(model_name,
......@@ -95,9 +99,10 @@ def convert(conf, output):
f.write(str(model))
def convert_model(conf):
def convert_model(conf, quantize_stat):
option = cvt.ConverterOption()
option.quantize_stat = quantize_stat
if ModelKeys.graph_optimize_options in conf:
option.transformer_option = conf[ModelKeys.graph_optimize_options]
if ModelKeys.winograd in conf:
......
......@@ -404,6 +404,7 @@ class ConverterOption(object):
self._change_concat_ranges = False
self._transformer_option = None
self._cl_mem_type = "image"
self._quantize_stat = False
@property
def input_nodes(self):
......@@ -453,6 +454,10 @@ class ConverterOption(object):
def cl_mem_type(self):
return self._cl_mem_type
@property
def quantize_stat(self):
return self._quantize_stat
@input_nodes.setter
def input_nodes(self, input_nodes):
for node in input_nodes.values():
......@@ -513,6 +518,10 @@ class ConverterOption(object):
def cl_mem_type(self, cl_mem_type):
self._cl_mem_type = cl_mem_type
@quantize_stat.setter
def quantize_stat(self, quantize_stat):
self._quantize_stat = quantize_stat
def disable_transpose_filters(self):
if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)
......
......@@ -25,12 +25,14 @@ from functools import reduce
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import ActivationType
from transform.base_converter import ConverterUtil
from transform.base_converter import DeviceType
from transform.base_converter import EltwiseType
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.base_converter import PaddingMode
from transform.base_converter import PadType
from transform.base_converter import PoolingType
from transform.base_converter import ReduceType
from utils.util import mace_check
......@@ -47,10 +49,19 @@ HexagonSupportedOps = [
'QuantizedAvgPool_8',
'QuantizedConcat_8',
'QuantizedMaxPool_8',
'QuantizedMaximum_8',
'QuantizedMinimum_8',
'QuantizedMul_8x8to8',
'QuantizedPad_8',
'QuantizedRelu_8',
'QuantizedReluX_8',
'QuantizedReshape',
'QuantizedResizeBilinear_8',
'QuantizedSigmoid_8',
'QuantizedSoftmax_8',
'QuantizedStridedSlice_8',
'QuantizedSub_8p8to8',
'QuantizedTanh_8',
'QuantizedTransposeConv2d_8x8p32to8',
'QuantizeINPUT_f_to_8',
'SpaceToBatchND_8',
......@@ -88,12 +99,28 @@ def normalize_name(name):
class HexagonConverter(base_converter.ConverterInterface):
activation_type = {
ActivationType.RELU.name: HexagonOp.QuantizedRelu_8.name,
ActivationType.RELUX.name: HexagonOp.QuantizedReluX_8.name,
ActivationType.TANH.name: HexagonOp.QuantizedTanh_8.name,
ActivationType.SIGMOID.name: HexagonOp.QuantizedSigmoid_8.name,
}
eltwise_type = {
EltwiseType.SUM.value: HexagonOp.QuantizedAdd_8p8to8.name,
EltwiseType.SUB.value: HexagonOp.QuantizedSub_8p8to8.name,
EltwiseType.PROD.value: HexagonOp.QuantizedMul_8x8to8.name,
EltwiseType.MIN.value: HexagonOp.QuantizedMinimum_8.name,
EltwiseType.MAX.value: HexagonOp.QuantizedMaximum_8.name,
}
def __init__(self, option, model, quantize_activation_info):
self._option = option
self._model = model
self._consts = {}
self._quantize_activation_info = quantize_activation_info
self._op_converters = {
MaceOp.Activation.name: self.convert_activation,
MaceOp.BatchToSpaceND.name: self.convert_batchspace,
MaceOp.Concat.name: self.convert_concat,
MaceOp.Conv2D.name: self.convert_conv2d,
......@@ -102,11 +129,14 @@ class HexagonConverter(base_converter.ConverterInterface):
MaceOp.DepthwiseConv2d.name: self.convert_conv2d,
MaceOp.Dequantize.name: self.convert_dequantize,
MaceOp.Eltwise.name: self.convert_elementwise,
MaceOp.ExpandDims.name: self.convert_expanddims,
MaceOp.Pad.name: self.convert_pad,
MaceOp.Pooling.name: self.convert_pooling,
MaceOp.Quantize.name: self.convert_quantize,
MaceOp.Reduce.name: self.convert_reduce,
MaceOp.ResizeBilinear.name: self.convert_resizebilinear,
MaceOp.Softmax.name: self.convert_softmax,
MaceOp.StridedSlice.name: self.convert_stridedslice,
MaceOp.SpaceToBatchND.name: self.convert_batchspace,
MaceOp.SpaceToDepth.name: self.convert_depthspace,
}
......@@ -133,7 +163,10 @@ class HexagonConverter(base_converter.ConverterInterface):
self._quantize_activation_info[tensors[i]] = \
self._quantize_activation_info[node_name]
def add_scalar_const_node(self, name, val):
def add_scalar_const_node(self, name, val, op=None):
if op is not None:
name = op.name + name
op.input.append(name)
if name not in self._consts:
tensor = self._model.tensors.add()
self._consts[name] = tensor
......@@ -332,6 +365,21 @@ class HexagonConverter(base_converter.ConverterInterface):
if arg is not None:
op.padding = padding_mode[PaddingMode(arg.i)]
def convert_activation(self, op):
self.add_min_max_const_node(op, op.input[0])
act_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_activation_type_str).s.decode()
if act_type == ActivationType.RELUX.name:
x = ConverterUtil.get_arg(
op, MaceKeyword.mace_activation_max_limit_str).f
self.add_scalar_const_node("/x:0", x, op)
try:
op.type = self.activation_type[act_type]
except KeyError:
mace_check(False,
"Hexagon does not support activation %s" % act_type)
def convert_batchspace(self, op):
strides_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_space_batch_block_shape_str)
......@@ -484,27 +532,53 @@ class HexagonConverter(base_converter.ConverterInterface):
op.type = HexagonOp.DequantizeOUTPUT_8tof.name
def convert_elementwise(self, op):
if len(op.input) == 1:
scalar_input_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_scalar_input_str)
self.add_scalar_const_node("/b:0", scalar_input_arg.i, op)
self.add_min_max_const_node(op, op.input[0])
self.add_min_max_const_node(op, op.input[1])
element_type = \
ConverterUtil.get_arg(op,
MaceKeyword.mace_element_type_str).i
if element_type == EltwiseType.SUM.value:
element_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_element_type_str).i
if element_type in [EltwiseType.SUM.value,
EltwiseType.SUB.value,
EltwiseType.MIN.value,
EltwiseType.MAX.value]:
self.add_min_max_const_node(
op, op.output[0], True, True, False)
op.type = HexagonOp.QuantizedAdd_8p8to8.name
elif element_type == EltwiseType.SUB.value:
self.add_min_max_const_node(
op, op.output[0], True, True, False)
op.type = HexagonOp.QuantizedSub_8p8to8.name
elif element_type == EltwiseType.PROD.value:
op.type = HexagonOp.QuantizedMul_8x8to8.name
else:
try:
op.type = self.eltwise_type[element_type]
except KeyError:
mace_check(False,
"Hexagon does not support elementwise %s"
% EltwiseType(element_type).name)
def convert_expanddims(self, op):
shape = op.output_shape[0].dims
self.add_arg_const_node(op, '/shape:0', [len(shape)], shape)
self.add_min_max_const_node(op, op.input[0])
# Convert to reshape as hexagon does not support quantized expanddims
op.type = HexagonOp.QuantizedReshape.name
def convert_pad(self, op):
self.add_min_max_const_node(op, op.input[0])
paddings = ConverterUtil.get_arg(
op, MaceKeyword.mace_paddings_str).ints
self.add_arg_const_node(
op, '/paddings:0', [1, 1, len(paddings) // 2, 2], paddings)
pad_type = ConverterUtil.get_arg(op, MaceKeyword.mace_pad_type_str).i
mace_check(pad_type == PadType.CONSTANT.value,
"Hexagon only supports constant pad")
constant_value = ConverterUtil.get_arg(
op, MaceKeyword.mace_constant_value_str).f
self.add_scalar_const_node('/constant_value:0', constant_value, op)
op.type = HexagonOp.QuantizedPad_8.name
def convert_pooling(self, op):
self.add_min_max_const_node(op, op.input[0])
......@@ -579,3 +653,17 @@ class HexagonConverter(base_converter.ConverterInterface):
self.add_min_max_const_node(op, op.input[0])
op.type = HexagonOp.QuantizedSoftmax_8.name
def convert_stridedslice(self, op):
beigin_mask = ConverterUtil.get_arg(
op, MaceKeyword.mace_begin_mask_str).i
end_mask = ConverterUtil.get_arg(
op, MaceKeyword.mace_end_mask_str).i
shrink_mask = ConverterUtil.get_arg(
op, MaceKeyword.mace_shrink_axis_mask_str).i
self.add_scalar_const_node("/begin_mask:0", beigin_mask, op)
self.add_scalar_const_node("/end_mask:0", end_mask, op)
self.add_scalar_const_node("/shrink_mask:0", shrink_mask, op)
self.add_min_max_const_node(op, op.input[0])
op.type = HexagonOp.QuantizedStridedSlice_8.name
......@@ -20,6 +20,7 @@ import six
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import ActivationType
from transform.base_converter import ConverterUtil
from transform.base_converter import DataFormat
from transform.base_converter import DeviceType
......@@ -960,10 +961,17 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.BatchNorm.name) \
and len(self._consumers.get(op.output[0], [])) == 1:
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.Activation.name \
and ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_activation_type_str).s != b'PRELU': # noqa
if consumer_op.type == MaceOp.Activation.name:
act_type_arg = ConverterUtil.get_arg(
consumer_op, MaceKeyword.mace_activation_type_str)
act_type = act_type_arg.s.decode()
if act_type == ActivationType.PRELU.name:
continue
# during quantization, only fold relu/relux
if (self._option.quantize_stat or self._option.quantize) \
and act_type not in [ActivationType.RELU.name,
ActivationType.RELUX.name]:
continue
print("Fold activation: %s(%s)" % (op.name, op.type))
op.name = consumer_op.name
op.output[0] = consumer_op.output[0]
......@@ -1886,11 +1894,14 @@ class Transformer(base_converter.ConverterInterface):
print("Add default quantize info for ops like Pooling, Softmax")
for op in self._model.op:
if op.type in [MaceOp.Pooling.name,
if op.type in [MaceOp.ExpandDims.name,
MaceOp.Pad.name,
MaceOp.Pooling.name,
MaceOp.Reduce.name,
MaceOp.Squeeze.name,
MaceOp.Reshape.name,
MaceOp.ResizeBilinear.name,
MaceOp.Squeeze.name,
MaceOp.StridedSlice.name,
MaceOp.BatchToSpaceND.name,
MaceOp.SpaceToBatchND.name,
MaceOp.SpaceToDepth.name,
......@@ -1929,6 +1940,18 @@ class Transformer(base_converter.ConverterInterface):
self.copy_quantize_info(producer_op, quantize_info)
self._quantize_activation_info[producer_op.output[0]] \
= producer_op.quantize_info[0]
elif op.type == MaceOp.Activation.name:
act_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_activation_type_str).s.decode()
if act_type not in [ActivationType.TANH.name,
ActivationType.SIGMOID.name]:
continue
del op.quantize_info[:]
if act_type == ActivationType.TANH.name:
quantize_info = self.add_quantize_info(op, -1.0, 1.0)
else:
quantize_info = self.add_quantize_info(op, 0.0, 1.0)
self._quantize_activation_info[op.output[0]] = quantize_info
elif op.type == MaceOp.Softmax.name:
del op.quantize_info[:]
quantize_info = \
......
......@@ -92,6 +92,7 @@ class ModelKeys(object):
quantize_range_file = "quantize_range_file"
quantize = "quantize"
quantize_large_weights = "quantize_large_weights"
quantize_stat = "quantize_stat"
change_concat_ranges = "change_concat_ranges"
winograd = "winograd"
cl_mem_type = "cl_mem_type"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册