From 0c6ec590668e93fb921400809bca9283fce5cc01 Mon Sep 17 00:00:00 2001
From: 李寅
Date: Mon, 10 Sep 2018 18:14:26 +0800
Subject: [PATCH] Move quantize info into op's output

---
 mace/core/tensor.h                             |  31 +++-
 mace/core/workspace.cc                         |  26 +--
 mace/kernels/conv_2d.h                         |  11 ++
 mace/proto/mace.proto                          |  19 +-
 mace/python/tools/converter.py                 |  12 ++
 .../tools/converter_tool/base_converter.py     |  20 +-
 .../converter_tool/tensorflow_converter.py     |  20 ++
 .../tools/converter_tool/transformer.py        | 172 +++++++++++++++---
 mace/python/tools/memory_optimizer.py          |  16 +-
 mace/python/tools/model.jinja2                 |  16 --
 mace/python/tools/operator.jinja2              |  10 +
 tools/converter.py                             |   1 +
 tools/sh_commands.py                           |   2 +
 13 files changed, 283 insertions(+), 73 deletions(-)

diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 5c4c807b..a497309f 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -109,7 +109,9 @@ class Tensor {
         name_(""),
         is_weight_(is_weight),
         scale_(0.f),
-        zero_point_(0) {}
+        zero_point_(0),
+        minval_(0.f),
+        maxval_(0.f) {}
 
   Tensor(BufferBase *buffer, DataType dtype,
          bool is_weight = false)
@@ -120,7 +122,9 @@ class Tensor {
         name_(""),
         is_weight_(is_weight),
         scale_(0.f),
-        zero_point_(0) {}
+        zero_point_(0),
+        minval_(0.f),
+        maxval_(0.f) {}
 
   Tensor(const BufferSlice &buffer_slice,
          DataType dtype,
@@ -132,7 +136,9 @@ class Tensor {
         name_(""),
         is_weight_(is_weight),
         scale_(0.f),
-        zero_point_(0) {
+        zero_point_(0),
+        minval_(0.f),
+        maxval_(0.f) {
     buffer_ = &buffer_slice_;
   }
 
@@ -391,6 +397,15 @@ class Tensor {
     return zero_point_;
   }
 
+  // hexagon now uses min/max instead of scale and zero
+  inline float minval() const {
+    return minval_;
+  }
+
+  inline float maxval() const {
+    return maxval_;
+  }
+
   inline void SetScale(float scale) {
     scale_ = scale;
   }
@@ -403,6 +418,14 @@ class Tensor {
     is_weight_ = is_weight;
   }
 
+  inline void SetMinVal(float minval) {
+    minval_ = minval;
+  }
+
+  inline void SetMaxVal(float maxval) {
+    maxval_ = maxval;
+  }
+
  private:
   Allocator *allocator_;
   DataType dtype_;
@@ -416,6 +439,8 @@ class Tensor {
   bool is_weight_;
   float scale_;
   int32_t zero_point_;
+  float minval_;
+  float maxval_;
 
   MACE_DISABLE_COPY_AND_ASSIGN(Tensor);
 };
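
The new minval/maxval fields carry the same information as scale/zero_point in
a different encoding: the CPU runtime consumes the affine pair, while Hexagon
consumes the raw float range. Under the usual asymmetric uint8 scheme the two
are interconvertible; below is a minimal sketch of that relation (MACE's real
conversion is quantize_util.adjust_range, which additionally adjusts the range
so that real 0.0 maps to an exact integer code, a detail this sketch ignores):

    # Sketch of the textbook uint8 affine mapping between a float range
    # and a (scale, zero_point) pair; not the exact MACE implementation.
    def range_to_scale_zero(minval, maxval):
        scale = (maxval - minval) / 255.0         # one uint8 step, in real units
        zero_point = int(round(-minval / scale))  # uint8 code of real 0.0
        return scale, zero_point

    def quantize(x, scale, zero_point):
        q = int(round(x / scale)) + zero_point
        return min(255, max(0, q))                # clamp to [0, 255]

    scale, zero = range_to_scale_zero(-1.0, 1.0)
    assert quantize(0.0, scale, zero) == zero     # real 0.0 lands on zero_point
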
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 4c9204cb..07480ad4 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -178,16 +178,19 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
       if (status != MaceStatus::MACE_SUCCESS) return status;
     }
 
-    if (device_type == DeviceType::CPU && net_def.has_quantize_info()) {
-      for (const auto
-           &activation_info: net_def.quantize_info().activation_info()) {
-        if (HasTensor(activation_info.tensor_name())) {
-          Tensor *tensor = GetTensor(activation_info.tensor_name());
-          tensor->SetScale(activation_info.scale());
-          tensor->SetZeroPoint(activation_info.zero_point());
-        } else {
-          LOG(WARNING) << "Quantize info exists for non-existed tensor: "
-                       << activation_info.tensor_name();
+    if (device_type == DeviceType::CPU) {
+      for (const auto &op : net_def.op()) {
+        VLOG(2) << "Add quantize info for op: " << op.name();
+        MACE_CHECK(op.quantize_info().empty()
+                       || op.quantize_info().size() == op.output().size(),
+                   "quantize info size must be equal to output size or empty");
+        for (int i = 0; i < op.quantize_info().size(); ++i) {
+          auto &quantize_info = op.quantize_info(i);
+          Tensor *tensor = GetTensor(op.output(i));
+          tensor->SetScale(quantize_info.scale());
+          tensor->SetZeroPoint(quantize_info.zero_point());
+          tensor->SetMinVal(quantize_info.minval());
+          tensor->SetMaxVal(quantize_info.maxval());
         }
       }
     }
@@ -233,8 +236,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
         std::unique_ptr<Buffer> tensor_buf(
             new Buffer(GetCPUAllocator()));
         MACE_RETURN_IF_ERROR(tensor_buf->Allocate(
-            mem_block.x() * GetEnumTypeSize(dtype)
-                + MACE_EXTRA_BUFFER_PAD_SIZE));
+            mem_block.x() + MACE_EXTRA_BUFFER_PAD_SIZE));
         preallocated_allocator_.SetBuffer(mem_block.mem_id(),
                                           std::move(tensor_buf));
       } else if (mem_block.mem_type() == MemoryType::GPU_IMAGE) {
diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index ce9bb11d..024644f3 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -879,6 +879,17 @@ struct Conv2dFunctor<DeviceType::CPU, uint8_t> : Conv2dFunctorBase {
     const index_t depth = input_channels * filter_h * filter_w;
     const index_t columns = batch * height * width;
 
+    VLOG(2) << "input scale/zero: " << input->scale() << ", "
+            << input->zero_point();
+    VLOG(2) << "filter scale/zero: " << filter->scale() << ", "
+            << filter->zero_point();
+    if (bias) {
+      VLOG(2) << "bias scale/zero: " << bias->scale() << ", "
+              << bias->zero_point();
+    }
+    VLOG(2) << "output scale/zero: " << output->scale() << ", "
+            << output->zero_point();
+
     MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels);
     MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3),
                " != ", input_channels);
diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto
index 4e4b6a07..9ec3d96e 100644
--- a/mace/proto/mace.proto
+++ b/mace/proto/mace.proto
@@ -59,6 +59,13 @@ message OutputShape {
   repeated int64 dims = 1;
 }
 
+message QuantizeActivationInfo {
+  optional float scale = 1;
+  optional int32 zero_point = 2;
+  optional float minval = 3;  // hexagon uses min/max
+  optional float maxval = 4;
+}
+
 message OperatorDef {
   repeated string input = 1;
   repeated string output = 2;
@@ -67,6 +74,7 @@ message OperatorDef {
   repeated Argument arg = 5;
   repeated OutputShape output_shape = 6;
   repeated DataType output_type = 7;
+  repeated QuantizeActivationInfo quantize_info = 8;
 
   repeated int32 mem_id = 10;
 
@@ -106,23 +114,12 @@ message OutputInfo {
   optional DataType data_type = 5 [default = DT_FLOAT];
 }
 
-message QuantizeActivationInfo {
-  optional string tensor_name = 1;
-  optional float scale = 2;
-  optional int32 zero_point = 3;
-}
-
-message QuantizeInfo {
-  repeated QuantizeActivationInfo activation_info = 1;
-}
-
 message NetDef {
   optional string name = 1;
   repeated OperatorDef op = 2;
   optional string version = 3;
   repeated Argument arg = 4;
   repeated ConstTensor tensors = 5;
-  optional QuantizeInfo quantize_info = 6;
 
   // for mem optimization
   optional MemoryArena mem_arena = 10;
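
Note how the moved message changes the keying: QuantizeActivationInfo no longer
names a tensor; op.quantize_info(i) implicitly describes op.output(i), which is
exactly the invariant the new MACE_CHECK in workspace.cc enforces. A sketch of
what one serialized operator now carries (op name and values are made up):

    # Sketch using the generated protobuf module; values are illustrative.
    from mace.proto import mace_pb2

    op = mace_pb2.OperatorDef()
    op.name = 'conv1'                      # hypothetical op
    op.type = 'Conv2D'
    op.output.append('conv1:0')

    info = op.quantize_info.add()          # index-parallel with op.output
    info.minval, info.maxval = 0.0, 6.0    # e.g. a ReLU6-bounded activation
    info.scale = (info.maxval - info.minval) / 255.0
    info.zero_point = 0                    # minval is 0.0, so 0.0 -> code 0

    assert len(op.quantize_info) == len(op.output)
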
diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py
index 9549b833..18b4cea4 100644
--- a/mace/python/tools/converter.py
+++ b/mace/python/tools/converter.py
@@ -64,6 +64,10 @@ def parse_int_array_from_str(ints_str):
     return [int(int_str) for int_str in ints_str.split(',')]
 
 
+def parse_float_array_from_str(floats_str):
+    return [float(float_str) for float_str in floats_str.split(',')]
+
+
 def main(unused_args):
     if not os.path.isfile(FLAGS.model_file):
         print("Input graph file '" + FLAGS.model_file + "' does not exist!")
@@ -105,12 +109,18 @@ def main(unused_args):
 
     input_node_names = FLAGS.input_node.split(',')
     input_node_shapes = FLAGS.input_shape.split(':')
+    if FLAGS.input_range:
+        input_node_ranges = FLAGS.input_range.split(':')
+    else:
+        input_node_ranges = []
     if len(input_node_names) != len(input_node_shapes):
         raise Exception('input node count and shape count do not match.')
     for i in xrange(len(input_node_names)):
         input_node = cvt.NodeInfo()
         input_node.name = input_node_names[i]
         input_node.shape = parse_int_array_from_str(input_node_shapes[i])
+        if len(input_node_ranges) > i:
+            input_node.range = parse_float_array_from_str(input_node_ranges[i])
         option.add_input_node(input_node)
 
     output_node_names = FLAGS.output_node.split(',')
@@ -276,6 +286,8 @@
         "--dsp_mode", type=int, default=0, help="dsp run mode, defalut=0")
     parser.add_argument(
         "--input_shape", type=str, default="", help="input shape.")
+    parser.add_argument(
+        "--input_range", type=str, default="", help="input range.")
     parser.add_argument(
         "--platform", type=str, default="tensorflow", help="tensorflow/caffe")
     parser.add_argument(
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index a46af224..dfe3158a 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -203,6 +203,8 @@ class TransformerRule(Enum):
     QUANTIZE_WEIGHTS = 25
     TRANSFORM_LSTMCELL_ZEROSTATE = 26
     TRANSFORM_BASIC_LSTMCELL = 27
+    TRANSFORM_FAKE_QUANTIZE = 28
+    CHECK_QUANTIZE_INFO = 29
 
 
 class ConverterInterface(object):
@@ -218,6 +220,7 @@ class NodeInfo(object):
     def __init__(self):
         self._name = None
         self._shape = []
+        self._range = [-1.0, 1.0]
 
     @property
     def name(self):
@@ -227,6 +230,10 @@
     def shape(self):
         return self._shape
 
+    @property
+    def range(self):
+        return self._range
+
     @name.setter
     def name(self, name):
         self._name = name
@@ -235,6 +242,10 @@
     def shape(self, shape):
         self._shape = shape
 
+    @range.setter
+    def range(self, range):
+        self._range = range
+
     def __str__(self):
         return '%s %s' % (self._name, str(self._shape))
 
@@ -339,6 +350,7 @@ class ConverterOption(object):
         else:
             self._transformer_option = [
                 # Model structure related transformation
+                TransformerRule.TRANSFORM_FAKE_QUANTIZE,
                 TransformerRule.REMOVE_IDENTITY_OP,
                 TransformerRule.TRANSFORM_GLOBAL_POOLING,
                 TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
@@ -368,15 +380,17 @@
             # Transform finalization
             TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
             # for quantization entropy calibration use
-            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
             TransformerRule.SORT_BY_EXECUTION,
+            # Need to be put after SORT_BY_EXECUTION
+            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
         ]
         if self._quantize:
-            self._transformer_option = self._transformer_option[:-1] + [
+            self._transformer_option = self._transformer_option + [
+                # need to be put after ADD_QUANTIZE_TENSOR_RANGE
                 TransformerRule.QUANTIZE_NODES,
-                TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
                 TransformerRule.QUANTIZE_WEIGHTS,
                 TransformerRule.SORT_BY_EXECUTION,
+                TransformerRule.CHECK_QUANTIZE_INFO,
             ]
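
The new --input_range flag mirrors --input_shape: ranges for multiple input
nodes are separated by ':', each range being a comma-separated min,max pair,
and NodeInfo falls back to its [-1.0, 1.0] default when the flag is omitted.
A sketch of the parse for a hypothetical two-input model:

    # Follows the splitting logic above; the flag value is made up.
    input_range = "-1,1:0,255"

    input_node_ranges = input_range.split(':') if input_range else []
    parsed = [[float(v) for v in r.split(',')] for r in input_node_ranges]
    assert parsed == [[-1.0, 1.0], [0.0, 255.0]]
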
diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py
index 24799631..09489f12 100644
--- a/mace/python/tools/converter_tool/tensorflow_converter.py
+++ b/mace/python/tools/converter_tool/tensorflow_converter.py
@@ -102,6 +102,7 @@
     'Cast',
     'ArgMax',
     'Split',
+    'FakeQuantWithMinMaxVars',
 ]
 
 TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
@@ -205,6 +206,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Cast.name: self.convert_cast,
             TFOpType.ArgMax.name: self.convert_argmax,
             TFOpType.Split.name: self.convert_split,
+            TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
@@ -841,3 +843,21 @@ class TensorflowConverter(base_converter.ConverterInterface):
             num_split_arg.i = tf_op.get_attr('num_split')
 
         self._skip_tensor.add(tf_op.inputs[0].name)
+
+    def convert_fake_quantize(self, tf_op):
+        op = self.convert_general_op(tf_op)
+        min_arg = op.arg.add()
+        min_arg.name = 'min'
+        min_arg.f = tf_op.inputs[1].eval()
+        max_arg = op.arg.add()
+        max_arg.name = 'max'
+        max_arg.f = tf_op.inputs[2].eval()
+        narrow_range_arg = op.arg.add()
+        narrow_range_arg.name = 'narrow_range'
+        narrow_range_arg.i = int(tf_op.get_attr('narrow_range'))
+        num_bits_arg = op.arg.add()
+        num_bits_arg.name = 'num_bits'
+        num_bits_arg.i = int(tf_op.get_attr('num_bits'))
+
+        self._skip_tensor.add(tf_op.inputs[1].name)
+        self._skip_tensor.add(tf_op.inputs[2].name)
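
FakeQuantWithMinMaxVars is the node that TensorFlow's quantization-aware
training rewrite (tf.contrib.quantize in TF 1.x) inserts after weights and
activations; its second and third inputs are the learned range variables,
which is why the converter can simply eval() them into constant min/max
arguments and then skip those tensors. For reference, such a node comes from
a TF 1.x graph along these lines (a sketch; the variable names and shapes are
illustrative):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [1, 224, 224, 3], name='input')
    act_min = tf.Variable(-6.0, name='act_min')   # learned during training
    act_max = tf.Variable(6.0, name='act_max')
    y = tf.fake_quant_with_min_max_vars(x, act_min, act_max,
                                        num_bits=8, narrow_range=False)
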
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index ccd8d0ef..2302626d 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -47,6 +47,8 @@ class Transformer(base_converter.ConverterInterface):
         # Dependencies
         # (TRANSFORM_MATMUL_TO_FC, TRANSFORM_GLOBAL_CONV_TO_FC) -> RESHAPE_FC_WEIGHT  # noqa
         self._registered_transformers = {
+            TransformerRule.TRANSFORM_FAKE_QUANTIZE:
+                self.transform_fake_quantize,
             TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
             TransformerRule.TRANSFORM_GLOBAL_POOLING:
                 self.transform_global_pooling,
@@ -91,6 +93,8 @@
             TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES:
                 self.add_mace_input_and_output_nodes,
             TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
+            TransformerRule.CHECK_QUANTIZE_INFO:
+                self.check_quantize_info,
         }
 
         self._option = option
@@ -774,16 +778,22 @@
     def transform_add_to_biasadd(self):
         net = self._model
         for op in net.op:
-            if op.type == 'Add' \
-                    and len(op.input) == 2 \
-                    and op.input[1] in self._consts \
-                    and len(self._consts[op.input[1]].dims) == 1:
+            if (op.type == 'Eltwise'
+                    and ConverterUtil.get_arg(op, MaceKeyword.mace_element_type_str).i == EltwiseType.SUM.value  # noqa
+                    and len(op.input) == 2
+                    and op.input[1] in self._consts
+                    and len(self._consts[op.input[1]].dims) == 1):
                 print("Transform add to biasadd: %s(%s)" % (op.name, op.type))
                 op.type = MaceOp.BiasAdd.name
                 return True
 
         return False
 
+    def replace_quantize_info(self, op, replace_op):
+        if len(replace_op.quantize_info) > 0:
+            del op.quantize_info[:]
+            op.quantize_info.extend(replace_op.quantize_info)
+
     def fold_biasadd(self):
         net = self._model
         for op in net.op:
@@ -799,6 +809,7 @@
                     if consumer_op.type == MaceOp.BiasAdd.name:
                         print("Fold biasadd: %s(%s)" % (op.name, op.type))
                         op.input.append(consumer_op.input[1])
+                        self.replace_quantize_info(op, consumer_op)
                         self.safe_remove_node(consumer_op, op)
                         return True
 
@@ -886,6 +897,7 @@
                             or arg.name == MaceKeyword.mace_activation_max_limit_str:  # noqa
                         op.arg.extend([arg])
 
+                    self.replace_quantize_info(op, consumer_op)
                     self.safe_remove_node(consumer_op, op)
                     return True
 
@@ -1163,7 +1175,8 @@
         transposed_filter = set()
         transposed_deconv_filter = set()
 
-        if self._option.quantize:
+        if self._option.quantize and \
+                self._option.device == DeviceType.CPU.value:
             print("Transpose filters to OHWI")
             if filter_format == FilterFormat.HWIO:
                 transpose_order = [3, 0, 1, 2]
@@ -1601,6 +1614,9 @@
         return False
 
     def quantize_nodes(self):
+        if not self._option.quantize:
+            return False
+
         print("Add mace quantize and dequantize nodes")
 
         for op in self._model.op:
@@ -1647,28 +1663,13 @@
 
         self._input_output_added = True
 
-    def add_quantize_tensor_range(self):
-        print("Add quantize tensor range")
-        net = self._model
-        range_file = self._option.quantize_range_file
-        if not range_file:
-            return
-
-        with open(range_file) as f:
-            for line in f:
-                tensor_name, minmax = line.split("@@")
-                min_val, max_val = [float(i) for i in
-                                    minmax.strip().split(",")]
-                scale, zero = quantize_util.adjust_range(min_val, max_val,
-                                                         non_zero=False)
-                activation_info = net.quantize_info.activation_info.add()
-                activation_info.tensor_name = tensor_name
-                activation_info.scale = scale
-                activation_info.zero_point = zero
-                self._quantize_activation_info[tensor_name] = activation_info
+        return False
 
     def quantize_tensor(self, tensor):
         """Assume biasadd has been already folded with convolution and fc"""
+        if not self._option.quantize:
+            return False
+
         if tensor.data_type == mace_pb2.DT_FLOAT:
             ops = self._consumers.get(tensor.name, None)
             if len(ops) == 1 and ops[0].type in [MaceOp.Conv2D.name,
@@ -1698,8 +1699,131 @@
                 tensor.zero_point = quantized_tensor.zero
                 self._quantized_tensor.update([tensor.name])
 
+        return False
+
     def quantize_weights(self):
         print("Quantize weights")
         net = self._model
         for tensor in net.tensors:
             self.quantize_tensor(tensor)
+
+        return False
+
+    def add_quantize_info(self, op, minval, maxval):
+        scale, zero = quantize_util.adjust_range(minval, maxval,
+                                                 non_zero=False)
+        quantize_info = op.quantize_info.add()
+        quantize_info.minval = minval
+        quantize_info.maxval = maxval
+        quantize_info.scale = scale
+        quantize_info.zero_point = zero
+
+        return quantize_info
+
+    def transform_fake_quantize(self):
+        if not self._option.quantize:
+            return False
+
+        # Quantize info from fixpoint fine tune
+        print("Transform fake quantize")
+        range_file = self._option.quantize_range_file
+        if range_file:
+            return False
+
+        net = self._model
+        for op in net.op:
+            if op.type == 'FakeQuantWithMinMaxVars':
+                producer_op = self._producer[op.input[0]]
+                minval = ConverterUtil.get_arg(op, 'min').f
+                maxval = ConverterUtil.get_arg(op, 'max').f
+                quantize_info = \
+                    self.add_quantize_info(producer_op, minval, maxval)
+                self._quantize_activation_info[op.input[0]] = quantize_info
+                op.type = MaceOp.Identity.name
+
+        return False
+
+    def add_quantize_tensor_range(self):
+        if not self._option.quantize:
+            return False
+
+        # Quantize info from range statistics
+        print("Add quantize tensor range")
+        range_file = self._option.quantize_range_file
+        if range_file:
+            with open(range_file) as f:
+                for line in f:
+                    tensor_name, minmax = line.split("@@")
+                    min_val, max_val = [float(i) for i in
+                                        minmax.strip().split(",")]
+                    scale, zero = quantize_util.adjust_range(min_val, max_val,
+                                                             non_zero=False)
+                    activation_info = mace_pb2.QuantizeActivationInfo()
+                    activation_info.minval = min_val
+                    activation_info.maxval = max_val
+                    activation_info.scale = scale
+                    activation_info.zero_point = zero
+                    self._quantize_activation_info[tensor_name] = activation_info  # noqa
+
+        for op in self._model.op:
+            if op.name.find(MaceKeyword.mace_output_node_name) >= 0:
+                continue
+            for output in op.output:
+                mace_check(output in self._quantize_activation_info,
+                           "%s does not have quantize activation info"
+                           % op)
+            op.quantize_info.extend([
+                self._quantize_activation_info[output]
+                for output in op.output])
+
+        print("Add default quantize info for ops like Pooling, Softmax")
+        for op in self._model.op:
+            if op.type in [MaceOp.Pooling.name,
+                           MaceOp.Squeeze.name,
+                           MaceOp.Concat.name,
+                           MaceOp.ResizeBilinear.name,
+                           MaceOp.BatchToSpaceND.name,
+                           MaceOp.SpaceToBatchND.name]:
+                del op.quantize_info[:]
+                producer_op = self._producer[op.input[0]]
+                quantize_info = op.quantize_info.add()
+                quantize_info.minval = producer_op.quantize_info[0].minval
+                quantize_info.maxval = producer_op.quantize_info[0].maxval
+                quantize_info.scale = producer_op.quantize_info[0].scale
+                quantize_info.zero_point = \
+                    producer_op.quantize_info[0].zero_point
+                self._quantize_activation_info[op.output[0]] = quantize_info
+            elif op.type == MaceOp.Softmax.name:
+                del op.quantize_info[:]
+                quantize_info = \
+                    self.add_quantize_info(op, 0.0, 1.0)
+                self._quantize_activation_info[op.output[0]] = quantize_info
+
+        print("Add default quantize info for input")
+        for input_node in self._option.input_nodes.values():
+            if input_node.name not in self._quantize_activation_info:
+                print("Input range %s: %s" % (input_node.name,
+                                              str(input_node.range)))
+                scale, zero = quantize_util.adjust_range(input_node.range[0],
+                                                         input_node.range[1],
+                                                         non_zero=False)
+                quantize_info = mace_pb2.QuantizeActivationInfo()
+                quantize_info.minval = input_node.range[0]
+                quantize_info.maxval = input_node.range[1]
+                quantize_info.scale = scale
+                quantize_info.zero_point = zero
+                self._quantize_activation_info[input_node.name] = quantize_info
+
+        return False
+
+    def check_quantize_info(self):
+        if not self._option.quantize:
+            return False
+
+        for op in self._model.op:
+            if (op.name.find(MaceKeyword.mace_input_node_name) == -1
+                    and op.name.find(MaceKeyword.mace_output_node_name) == -1  # noqa
+                    and op.type != MaceOp.Quantize.name
+                    and op.type != MaceOp.Dequantize.name):
+                mace_check(len(op.output) == len(op.quantize_info),
+                           "missing quantize info: %s" % op)
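
Two things are easy to miss in the transformer changes. First, the ordering
comments in base_converter.py exist because the default quantize info for ops
like Pooling and Concat is copied from their producers, so
ADD_QUANTIZE_TENSOR_RANGE has to run after SORT_BY_EXECUTION, once producers
are guaranteed to be visited first. Second, the calibration path keeps its old
file format: one record per line, tensor name and min,max joined by "@@". A
sketch of the parse (the tensor names are hypothetical):

    # The range-file format consumed by add_quantize_tensor_range.
    lines = [
        "conv1/Relu:0@@0.0,5.92",
        "pool1/MaxPool:0@@0.0,5.92",
    ]
    ranges = {}
    for line in lines:
        tensor_name, minmax = line.split("@@")
        ranges[tensor_name] = [float(i) for i in minmax.strip().split(",")]
    assert ranges["conv1/Relu:0"] == [0.0, 5.92]
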
activation info" + % op) + op.quantize_info.extend([ + self._quantize_activation_info[output] + for output in op.output]) + + print ("Add default quantize info for ops like Pooling, Softmax") + for op in self._model.op: + if op.type in [MaceOp.Pooling.name, + MaceOp.Squeeze.name, + MaceOp.Concat.name, + MaceOp.ResizeBilinear.name, + MaceOp.BatchToSpaceND.name, + MaceOp.SpaceToBatchND.name]: + del op.quantize_info[:] + producer_op = self._producer[op.input[0]] + quantize_info = op.quantize_info.add() + quantize_info.minval = producer_op.quantize_info[0].minval + quantize_info.maxval = producer_op.quantize_info[0].maxval + quantize_info.scale = producer_op.quantize_info[0].scale + quantize_info.zero_point = \ + producer_op.quantize_info[0].zero_point + self._quantize_activation_info[op.output[0]] = quantize_info + elif op.type == MaceOp.Softmax.name: + del op.quantize_info[:] + quantize_info = \ + self.add_quantize_info(op, 0.0, 1.0) + self._quantize_activation_info[op.output[0]] = quantize_info + + print ("Add default quantize info for input") + for input_node in self._option.input_nodes.values(): + if input_node.name not in self._quantize_activation_info: + print("Input range %s: %s" % (input_node.name, + str(input_node.range))) + scale, zero = quantize_util.adjust_range(input_node.range[0], + input_node.range[1], + non_zero=False) + quantize_info = mace_pb2.QuantizeActivationInfo() + quantize_info.minval = input_node.range[0] + quantize_info.maxval = input_node.range[1] + quantize_info.scale = scale + quantize_info.zero_point = zero + self._quantize_activation_info[input_node.name] = quantize_info + + return False + + def check_quantize_info(self): + if not self._option.quantize: + return False + + for op in self._model.op: + if (op.name.find(MaceKeyword.mace_input_node_name) == -1 + and op.name.find(MaceKeyword.mace_output_node_name) == -1 + and op.type != MaceOp.Quantize.name + and op.type != MaceOp.Dequantize.name): # noqa + mace_check(len(op.output) == len(op.quantize_info), + "missing quantize info: %s" % op) diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 52b1867b..53c56b98 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -80,9 +80,13 @@ class MemoryOptimizer(object): def op_need_optimize_memory(self, op): return True - def get_op_mem_block(self, op_type, output_shape): + def get_op_mem_block(self, op_type, output_shape, output_type): + data_type_size = 4 + if output_type == mace_pb2.DT_UINT8: + data_type_size = 1 return MemoryBlock(mace_pb2.CPU_BUFFER, - [reduce(operator.mul, output_shape, 1)]) + [reduce(operator.mul, output_shape, 1) * + data_type_size]) def mem_size(self, memory_block): return memory_block.block[0] @@ -143,9 +147,13 @@ class MemoryOptimizer(object): # make these ops reuse memory of input tensor mem_id = self.op_mem.get(op.input[0], -1) else: + output_type = mace_pb2.DT_FLOAT + if len(op.output_type) > i: + output_type = op.output_type[i] op_mem_block = self.get_op_mem_block( op.type, - op.output_shape[i].dims) + op.output_shape[i].dims, + output_type) mem_id = -1 if len(self.idle_mem) > 0: best_mem_add_size = sys.maxint @@ -221,7 +229,7 @@ class GPUMemoryOptimizer(MemoryOptimizer): return False return op.type != 'ImageToBuffer' - def get_op_mem_block(self, op_type, output_shape): + def get_op_mem_block(self, op_type, output_shape, output_type): if op_type == 'WinogradTransform' or op_type == 'MatMul': buffer_shape = list(output_shape) + [1] mem_block = MemoryBlock( 
diff --git a/mace/python/tools/model.jinja2 b/mace/python/tools/model.jinja2
index dcfe5434..9e8c521b 100644
--- a/mace/python/tools/model.jinja2
+++ b/mace/python/tools/model.jinja2
@@ -138,20 +138,6 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) {
 }
 {% endif %}
 
-void AddQuantizeInfo(NetDef *net_def) {
-  MACE_LATENCY_LOGGER(1, "Add quantize info");
-  (void) net_def;
-
-  {% for i in range(net.quantize_info.activation_info|length) %}
-  mace::QuantizeActivationInfo *activation_info{{i}} =
-      net_def->mutable_quantize_info()->add_activation_info();
-  activation_info{{i}}->set_tensor_name("{{net.quantize_info.activation_info[i].tensor_name}}");
-  activation_info{{i}}->set_scale({{net.quantize_info.activation_info[i].scale}});
-  activation_info{{i}}->set_zero_point({{net.quantize_info.activation_info[i].zero_point}});
-
-  {% endfor %}
-}
-
 }  // namespace
 
 namespace {{tag}} {
@@ -179,8 +165,6 @@ const std::shared_ptr<NetDef> CreateNet() {
     CreateOutputInfo(net_def.get());
   {% endif %}
 
-  AddQuantizeInfo(net_def.get());
-
   return net_def;
 }
 
diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2
index 6f682256..fc77b1e6 100644
--- a/mace/python/tools/operator.jinja2
+++ b/mace/python/tools/operator.jinja2
@@ -122,6 +122,16 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
                     {{ net.op[i].node_id }},
                     { {{ net.op[i].mem_id | join(', ') }} });
 
+  op->mutable_quantize_info()->Reserve({{ net.op[i].quantize_info | length }});
+  {% for j in range(net.op[i].quantize_info|length) %}
+  auto quantize_info{{j}} = op->add_quantize_info();
+  quantize_info{{j}}->set_scale({{ net.op[i].quantize_info[j].scale }});
+  quantize_info{{j}}->set_zero_point({{ net.op[i].quantize_info[j].zero_point }});
+  quantize_info{{j}}->set_minval({{ net.op[i].quantize_info[j].minval }});
+  quantize_info{{j}}->set_maxval({{ net.op[i].quantize_info[j].maxval }});
+  {% endfor %}
+
   {% if runtime == 'dsp' %}
   op->set_padding({{ net.op[i].padding }});
   {% if net.op[i].node_input | length > 0 %}
diff --git a/tools/converter.py b/tools/converter.py
index ebe87405..2a9aa30b 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -760,6 +760,7 @@ def convert_model(configs):
             runtime,
             model_name,
             ":".join(subgraphs[0][YAMLKeyword.input_shapes]),
+            ":".join(subgraphs[0][YAMLKeyword.input_ranges]),
             model_config[YAMLKeyword.nnlib_graph_mode],
             embed_model_data,
             model_config[YAMLKeyword.winograd],
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index 00473c56..8a266ee8 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -551,6 +551,7 @@ def gen_model_code(model_codegen_dir,
                    runtime,
                    model_tag,
                    input_shapes,
+                   input_ranges,
                    dsp_mode,
                    embed_model_data,
                    winograd,
@@ -579,6 +580,7 @@ def gen_model_code(model_codegen_dir,
             "--template=%s" % "mace/python/tools",
             "--model_tag=%s" % model_tag,
             "--input_shape=%s" % input_shapes,
+            "--input_range=%s" % input_ranges,
             "--dsp_mode=%s" % dsp_mode,
             "--embed_model_data=%s" % embed_model_data,
             "--winograd=%s" % winograd,
-- 
GitLab
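
The last two files close the loop from the deployment config to the new
converter flag: tools/converter.py joins each subgraph's per-input ranges with
':' and sh_commands.py forwards the result as --input_range. A sketch of that
hand-off (the dict layout stands in for the YAML config and is an assumption,
as are the values):

    subgraph = {
        "input_shapes": ["1,224,224,3", "1,10"],   # hypothetical 2-input model
        "input_ranges": ["-1,1", "0,255"],
    }
    input_shapes = ":".join(subgraph["input_shapes"])  # "1,224,224,3:1,10"
    input_ranges = ":".join(subgraph["input_ranges"])  # "-1,1:0,255"
    args = ["--input_shape=%s" % input_shapes,
            "--input_range=%s" % input_ranges]
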