From 3d1951765cc134d53b65b7a50f5fd97fe801122f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Mon, 30 Jul 2018 16:14:06 +0800
Subject: [PATCH] Init quantize project:

1. Quantize weights
2. Add quantize-dequantize nodes

---
 mace/core/net.cc                              |  20 ++
 mace/core/workspace.cc                        | 133 +++++----
 mace/kernels/conv_2d.h                        |   4 +
 mace/kernels/depthwise_conv2d.h               |   4 +
 mace/kernels/quantize.h                       |  50 ++--
 mace/proto/mace.proto                         |  13 +
 mace/python/tools/BUILD                       |   9 +
 mace/python/tools/converter.py                |  22 +-
 .../tools/converter_tool/base_converter.py    | 115 ++++++--
 .../tools/converter_tool/transformer.py       | 118 +++++++-
 mace/python/tools/model.jinja2                |  15 +
 mace/python/tools/model_saver.py              |  26 --
 mace/python/tools/quantization/__init__.py    |   0
 .../tools/quantization/quantize_stat.py       |  51 ++++
 .../tools/quantization/quantize_util.py       | 108 +++++++
 .../tools/quantization/quantize_util_test.py  |  16 ++
 mace/python/tools/tensor_source.jinja2        |   2 +
 mace/tools/quantization/BUILD                 |  18 ++
 mace/tools/quantization/quantize_stat.cc      | 264 ++++++++++++++++++
 mace/utils/utils.h                            |   6 +
 tools/converter.py                            | 106 ++++++-
 tools/sh_commands.py                          |   4 +
 22 files changed, 977 insertions(+), 127 deletions(-)
 create mode 100644 mace/python/tools/quantization/__init__.py
 create mode 100644 mace/python/tools/quantization/quantize_stat.py
 create mode 100644 mace/python/tools/quantization/quantize_util.py
 create mode 100644 mace/python/tools/quantization/quantize_util_test.py
 create mode 100644 mace/tools/quantization/BUILD
 create mode 100644 mace/tools/quantization/quantize_stat.cc

diff --git a/mace/core/net.cc b/mace/core/net.cc
index 259a9423..4922bc94 100644
--- a/mace/core/net.cc
+++ b/mace/core/net.cc
@@ -13,6 +13,8 @@
 // limitations under the License.

 #include
+#include <algorithm>
+#include <limits>

 #include "mace/core/macros.h"
 #include "mace/core/net.h"
@@ -125,6 +127,24 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
       VLOG(3) << "Operator " << op->debug_def().name()
               << " has shape: " << MakeString(op->Output(0)->shape());
+
+      if (EnvEnabled("MACE_LOG_TENSOR_RANGE") && device_type_ == CPU) {
+        for (int i = 0; i < op->OutputSize(); ++i) {
+          int data_type = op->GetOptionalArg<int>("T",
+                                                  static_cast<int>(DT_FLOAT));
+          if (data_type == static_cast<int>(DT_FLOAT)) {
+            float max_v = std::numeric_limits<float>::lowest();
+            float min_v = std::numeric_limits<float>::max();
+            Tensor::MappingGuard guard(op->Output(i));
+            const float *output_data = op->Output(i)->data<float>();
+            for (index_t j = 0; j < op->Output(i)->size(); ++j) {
+              max_v = std::max(max_v, output_data[j]);
+              min_v = std::min(min_v, output_data[j]);
+            }
+            LOG(INFO) << "Tensor range @@" << op->debug_def().output(i)
+                      << "@@" << min_v << "," << max_v;
+          }
+        }
+      }
     }

   return MACE_SUCCESS;

diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index fd083504..46966284 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -161,6 +161,8 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
                          const_tensor.data_type()));
       tensor->Reshape(dims);
+      tensor->SetScale(const_tensor.scale());
+      tensor->SetZeroPoint(const_tensor.zero_point());
       tensor_map_[const_tensor.name()] = std::move(tensor);
     }
   }
@@ -170,37 +172,48 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
     MaceStatus status = CreateOutputTensorBuffer(net_def, type);
     if (status != MaceStatus::MACE_SUCCESS) return status;
   }
+
+  if (type == DeviceType::CPU && net_def.has_quantize_info()) {
+    for (const auto
+         &activation_info : net_def.quantize_info().activation_info()) {
+      MACE_CHECK(HasTensor(activation_info.tensor_name()),
"Quantize info exist for non-existed tensor", + activation_info.tensor_name()); + Tensor *tensor = GetTensor(activation_info.tensor_name()); + tensor->SetScale(activation_info.scale()); + tensor->SetZeroPoint(activation_info.zero_point()); + } + } + return MaceStatus::MACE_SUCCESS; } MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, DeviceType device_type) { - if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) { - return MaceStatus::MACE_SUCCESS; - } - DataType dtype = DataType::DT_INVALID; - // We use the data type of the first op with mem id, - // as CPU&GPU have consistent data type for each layer for now. - // As DSP may have different data output type for each op, - // we stick to the same concept. - for (auto &op : net_def.op()) { - // TODO(liuqi): refactor to add device_type to OperatorDef - const int op_device = - ProtoArgHelper::GetOptionalArg( - op, "device", static_cast(device_type)); - if (op_device == device_type && !op.mem_id().empty()) { - const DataType op_dtype = static_cast( + if (net_def.mem_arena().mem_block_size() > 0) { + // We use the data type of the first op with mem id, + // as CPU&GPU have consistent data type for each layer for now. + // As DSP may have different data output type for each op, + // we stick to the same concept. + for (auto &op : net_def.op()) { + // TODO(liuqi): refactor to add device_type to OperatorDef + const int op_device = ProtoArgHelper::GetOptionalArg( - op, "T", static_cast(DT_FLOAT))); - if (op_dtype != DataType::DT_INVALID) { - dtype = op_dtype; - // find first valid data type, break - break; + op, "device", static_cast(device_type)); + if (op_device == device_type && !op.mem_id().empty()) { + const DataType op_dtype = static_cast( + ProtoArgHelper::GetOptionalArg( + op, "T", static_cast(DT_FLOAT))); + if (op_dtype != DataType::DT_INVALID) { + dtype = op_dtype; + // find first valid data type, break + break; + } } } + MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid."); } - MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid."); // TODO(liyin): memory block should not have concept of type, but to be // consistent with gpu, all memory block use float/half as unit for (auto &mem_block : net_def.mem_arena().mem_block()) { @@ -239,36 +252,58 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, const int op_device = ProtoArgHelper::GetOptionalArg( op, "device", static_cast(device_type)); - if (op_device == device_type && !op.mem_id().empty() - && ShouldPreallocateMemoryForOp(op)) { - auto mem_ids = op.mem_id(); - int count = mem_ids.size(); - for (int i = 0; i < count; ++i) { - DataType output_type; - if (i < op.output_type_size()) { - output_type = op.output_type(i); - } else { - output_type = dtype; + if (op_device == device_type) { + if (!op.mem_id().empty() + && ShouldPreallocateMemoryForOp(op)) { + auto mem_ids = op.mem_id(); + int count = mem_ids.size(); + for (int i = 0; i < count; ++i) { + DataType output_type; + if (i < op.output_type_size()) { + output_type = op.output_type(i); + } else { + output_type = dtype; + } + std::unique_ptr tensor + (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), + output_type)); + tensor->SetSourceOpName(op.name()); + if (device_type == DeviceType::GPU) { + VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")" + << " Mem: " << mem_ids[i] + << " Image shape: " + << dynamic_cast(tensor->UnderlyingBuffer()) + ->image_shape()[0] + << ", " + << dynamic_cast(tensor->UnderlyingBuffer()) + 
-                      ->image_shape()[0]
-                  << ", "
-                  << dynamic_cast<Image *>(tensor->UnderlyingBuffer())
-                      ->image_shape()[1];
-        } else if (device_type == DeviceType::CPU) {
-          VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
-                  << " Mem: " << mem_ids[i]
-                  << ", Buffer size: " << tensor->UnderlyingBuffer()->size();
-        }
-        tensor_map_[op.output(i)] = std::move(tensor);
-      }
-    }
+    if (op_device == device_type) {
+      if (!op.mem_id().empty()
+          && ShouldPreallocateMemoryForOp(op)) {
+        auto mem_ids = op.mem_id();
+        int count = mem_ids.size();
+        for (int i = 0; i < count; ++i) {
+          DataType output_type;
+          if (i < op.output_type_size()) {
+            output_type = op.output_type(i);
+          } else {
+            output_type = dtype;
+          }
+          std::unique_ptr<Tensor> tensor
+              (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]),
+                          output_type));
+          tensor->SetSourceOpName(op.name());
+          if (device_type == DeviceType::GPU) {
+            VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
+                    << " Mem: " << mem_ids[i]
+                    << " Image shape: "
+                    << dynamic_cast<Image *>(tensor->UnderlyingBuffer())
+                        ->image_shape()[0]
+                    << ", "
+                    << dynamic_cast<Image *>(tensor->UnderlyingBuffer())
+                        ->image_shape()[1];
+          } else if (device_type == DeviceType::CPU) {
+            VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
+                    << " Mem: " << mem_ids[i]
+                    << ", Buffer size: " << tensor->UnderlyingBuffer()->size();
+          }
+          tensor_map_[op.output(i)] = std::move(tensor);
+        }
+      } else {
+        for (int i = 0; i < op.output().size(); ++i) {
+          MACE_CHECK(
+              op.output_type_size() == 0
+                  || op.output_size()
+                      == op.output_type_size(),
+              "operator output size != operator output type size",
+              op.output_size(),
+              op.output_type_size());
+          DataType output_type;
+          if (i < op.output_type_size()) {
+            output_type = op.output_type(i);
+          } else {
+            output_type = static_cast<DataType>(
+                ProtoArgHelper::GetOptionalArg(
+                    op, "T", static_cast<int>(DT_FLOAT)));
+          }
+          CreateTensor(op.output(i),
+                       GetDeviceAllocator(device_type),
+                       output_type);
+        }
+      }
+    }
+  }
 }

diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index ecbd6608..9654d967 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -488,6 +488,10 @@ struct Conv2dFunctor : Conv2dFunctorBase {
     const index_t extra_output_shape[4] =
         {batch, channels, extra_output_height, extra_output_width};
+
+    // make host compiler happy
+    MACE_UNUSED(extra_input_shape);
+    MACE_UNUSED(extra_output_shape);
+
     // decide which convolution function to call
     if (use_winograd) {
       transformed_input.Reshape(transformed_input_shape);

diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h
index dd63be6f..14c83042 100644
--- a/mace/kernels/depthwise_conv2d.h
+++ b/mace/kernels/depthwise_conv2d.h
@@ -217,6 +217,10 @@ struct DepthwiseConv2dFunctor
     const index_t input_shape[4] =
         {batch, input_channels, input_height, input_width};
+
+    // make host compiler happy
+    MACE_UNUSED(pad_hw);
+    MACE_UNUSED(input_shape);
+
     if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
         && dilation_h == 1 && dilation_w == 1) {
       conv_func = [=](const float *input, float *output) {

diff --git a/mace/kernels/quantize.h b/mace/kernels/quantize.h
index 46f50c6a..975d7dbb 100644
--- a/mace/kernels/quantize.h
+++ b/mace/kernels/quantize.h
@@ -58,10 +58,11 @@ inline void AdjustRange(const float in_min_data,
   if (fabs(quantized_zero - quantized_zero_near_int) > kEps) {
     if (quantized_zero < quantized_zero_near_int || non_zero) {
       // keep out_max fixed, and move out_min
-      *scale = out_max / (quantized_max - quantized_zero_near_int);
+      *zero_point = static_cast<int32_t>(std::ceil(quantized_zero));
+      *scale = out_max / (quantized_max - *zero_point);
     } else {
       // keep out_min fixed, and move out_max
-      *scale = -out_min / quantized_zero_near_int;
+      *scale = out_min / (quantized_min - *zero_point);
     }
   }
 } else if (out_min > -kEps) {
@@ -96,6 +97,18 @@ inline void FindMinMax(const float *input,
   *max_val = max_v;
 }

+template <typename T>
+inline void QuantizeWithScaleAndZeropoint(const float *input,
+                                          const index_t size,
+                                          float scale,
+                                          int32_t zero_point,
+                                          T *output) {
+  float recip_scale = 1 / scale;
+  for (int i = 0; i < size; ++i) {
+    output[i] = Saturate<T>(roundf(zero_point + recip_scale * input[i]));
+  }
+}
+
 template <typename T>
 inline void Quantize(const float *input,
                      const index_t size,
@@ -110,10 +123,7 @@ inline void Quantize(const float *input,
   AdjustRange(in_min_data, in_max_data, non_zero,
               scale, zero_point);

-  float recip_scale = 1 / *scale;
-  for (int i = 0; i < size; ++i) {
-    output[i] = Saturate<T>(roundf(*zero_point + recip_scale * input[i]));
-  }
+  QuantizeWithScaleAndZeropoint(input, size, *scale, *zero_point, output);
 }

 template
@@ -143,16 +153,24 @@ struct QuantizeFunctor {
     Tensor::MappingGuard output_guard(output);
     const float *input_data = input->data<float>();
     uint8_t *output_data = output->mutable_data<uint8_t>();
-    float scale;
-    int32_t zero_point;
-    Quantize(input_data,
-             input->size(),
-             non_zero,
-             output_data,
-             &scale,
-             &zero_point);
-    output->SetScale(scale);
-    output->SetZeroPoint(zero_point);
+    if (output->scale() > 0.f) {
+      QuantizeWithScaleAndZeropoint(input_data,
+                                    input->size(),
+                                    output->scale(),
+                                    output->zero_point(),
+                                    output_data);
+    } else {
+      float scale;
+      int32_t zero_point;
+      Quantize(input_data,
+               input->size(),
+               non_zero,
+               output_data,
+               &scale,
+               &zero_point);
+      output->SetScale(scale);
+      output->SetZeroPoint(zero_point);
+    }

     return MACE_SUCCESS;
   }

diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto
index 63115a86..4e4b6a07 100644
--- a/mace/proto/mace.proto
+++ b/mace/proto/mace.proto
@@ -34,6 +34,8 @@ message ConstTensor {
   optional string name = 5;
   optional int64 offset = 6;
   optional int64 data_size = 7;
+  optional float scale = 8;
+  optional int32 zero_point = 9;
   optional uint32 node_id = 100;
 }

@@ -104,12 +106,23 @@ message OutputInfo {
   optional DataType data_type = 5 [default = DT_FLOAT];
 }

+message QuantizeActivationInfo {
+  optional string tensor_name = 1;
+  optional float scale = 2;
+  optional int32 zero_point = 3;
+}
+
+message QuantizeInfo {
+  repeated QuantizeActivationInfo activation_info = 1;
+}
+
 message NetDef {
   optional string name = 1;
   repeated OperatorDef op = 2;
   optional string version = 3;
   repeated Argument arg = 4;
   repeated ConstTensor tensors = 5;
+  optional QuantizeInfo quantize_info = 6;

   // for mem optimization
   optional MemoryArena mem_arena = 10;

diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD
index acc717d2..8b6ff799 100644
--- a/mace/python/tools/BUILD
+++ b/mace/python/tools/BUILD
@@ -1,3 +1,11 @@
+py_library(
+    name = "quantization_lib",
+    srcs = [
+        "quantization/quantize_util.py",
+    ],
+    srcs_version = "PY2AND3",
+)
+
 py_library(
     name = "converter_lib",
     srcs = [
@@ -12,6 +20,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":quantization_lib",
        ":memory_optimizer",
        "//mace/proto:mace_py",
        "//third_party/caffe:caffe_py",

diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py
index 3fd856c7..9549b833 100644
--- a/mace/python/tools/converter.py
+++ b/mace/python/tools/converter.py
@@ -96,12 +96,12 @@ def main(unused_args):
         print ("runtime %s is not supported."
% FLAGS.runtime) sys.exit(-1) + option = cvt.ConverterOption() if FLAGS.graph_optimize_options: - option = cvt.ConverterOption( - FLAGS.graph_optimize_options.split(',')) - else: - option = cvt.ConverterOption() + option.transformer_option = FLAGS.graph_optimize_options.split(',') option.winograd = FLAGS.winograd + option.quantize = FLAGS.quantize + option.quantize_range_file = FLAGS.quantize_range_file input_node_names = FLAGS.input_node.split(',') input_node_shapes = FLAGS.input_shape.split(':') @@ -119,6 +119,8 @@ def main(unused_args): output_node.name = output_node_names[i] option.add_output_node(output_node) + option.build() + print("Transform model to one that can better run on device") if FLAGS.runtime == 'dsp': mace_check(FLAGS.platform == 'tensorflow', @@ -297,6 +299,18 @@ def parse_args(): type=str, default="", help="graph optimize options") + parser.add_argument( + "--quantize", + type=str2bool, + nargs='?', + const=False, + default=False, + help="quantize model") + parser.add_argument( + "--quantize_range_file", + type=str, + default="", + help="file path of quantize range for each tensor") return parser.parse_known_args() diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 267fafeb..bc907aee 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -97,7 +97,6 @@ MaceSupportedOps = [ 'Proposal', 'Quantize', 'ReduceMean', - 'Requantize', 'Reshape', 'ResizeBilinear', 'Slice', @@ -189,6 +188,9 @@ class TransformerRule(Enum): ADD_IN_OUT_TENSOR_INFO = 20 ADD_MACE_INPUT_AND_OUTPUT_NODES = 21 UPDATE_FLOAT_OP_DATA_TYPE = 22 + QUANTIZE_NODES = 23 + ADD_QUANTIZE_TENSOR_RANGE = 24 + QUANTIZE_WEIGHTS = 25 class ConverterInterface(object): @@ -228,40 +230,15 @@ class NodeInfo(object): class ConverterOption(object): """A class for specifying options passed to converter tool""" - def __init__(self, transformers=None): + def __init__(self): self._input_nodes = {} self._output_nodes = {} self._data_type = mace_pb2.DT_FLOAT self._device = DeviceType.CPU.value self._winograd = 0 - if transformers: - self._transformer_option = [TransformerRule[transformer] - for transformer in transformers] - else: - self._transformer_option = [ - TransformerRule.REMOVE_IDENTITY_OP, - TransformerRule.TRANSFORM_GLOBAL_POOLING, - TransformerRule.FOLD_RESHAPE, - TransformerRule.TRANSFORM_MATMUL_TO_FC, - TransformerRule.FOLD_BATCHNORM, - TransformerRule.FOLD_CONV_AND_BN, - TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, - TransformerRule.TRANSFORM_GPU_WINOGRAD, - TransformerRule.TRANSFORM_ADD_TO_BIASADD, - TransformerRule.FOLD_BIASADD, - TransformerRule.FLATTEN_ATROUS_CONV, - TransformerRule.FOLD_ACTIVATION, - TransformerRule.TRANSPOSE_FILTERS, - TransformerRule.TRANSPOSE_DATA_FORMAT, - TransformerRule.ADD_IN_OUT_TENSOR_INFO, - TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, - TransformerRule.RESHAPE_FC_WEIGHT, - TransformerRule.TRANSFORM_BUFFER_IMAGE, - TransformerRule.ADD_DEVICE, - TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE, - TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES, - TransformerRule.SORT_BY_EXECUTION, - ] + self._quantize = False + self._quantize_range_file = "" + self._transformer_option = None @property def input_nodes(self): @@ -283,6 +260,14 @@ class ConverterOption(object): def winograd(self): return self._winograd + @property + def quantize(self): + return self._quantize + + @property + def quantize_range_file(self): + return self._quantize_range_file + @property def 
transformer_option(self): return self._transformer_option @@ -315,6 +300,18 @@ class ConverterOption(object): def winograd(self, winograd): self._winograd = winograd + @quantize.setter + def quantize(self, quantize): + self._quantize = quantize + + @quantize_range_file.setter + def quantize_range_file(self, quantize_range_file): + self._quantize_range_file = quantize_range_file + + @transformer_option.setter + def transformer_option(self, transformer_option): + self._transformer_option = transformer_option + def disable_transpose_filters(self): if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option: self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS) @@ -323,6 +320,58 @@ class ConverterOption(object): if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option: self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS) + def build(self): + if self._transformer_option: + self._transformer_option = [TransformerRule[transformer] + for transformer in self._transformer_option] # noqa + else: + if not self._quantize: + self._transformer_option = [ + TransformerRule.REMOVE_IDENTITY_OP, + TransformerRule.TRANSFORM_GLOBAL_POOLING, + TransformerRule.FOLD_RESHAPE, + TransformerRule.TRANSFORM_MATMUL_TO_FC, + TransformerRule.FOLD_BATCHNORM, + TransformerRule.FOLD_CONV_AND_BN, + TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, + TransformerRule.TRANSFORM_GPU_WINOGRAD, + TransformerRule.TRANSFORM_ADD_TO_BIASADD, + TransformerRule.FOLD_BIASADD, + TransformerRule.FLATTEN_ATROUS_CONV, + TransformerRule.FOLD_ACTIVATION, + TransformerRule.TRANSPOSE_FILTERS, + TransformerRule.TRANSPOSE_DATA_FORMAT, + TransformerRule.ADD_IN_OUT_TENSOR_INFO, + TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, + TransformerRule.RESHAPE_FC_WEIGHT, + TransformerRule.TRANSFORM_BUFFER_IMAGE, + TransformerRule.ADD_DEVICE, + TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE, + TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES, + TransformerRule.SORT_BY_EXECUTION, + ] + else: + self._transformer_option = [ + TransformerRule.REMOVE_IDENTITY_OP, + TransformerRule.TRANSFORM_GLOBAL_POOLING, + TransformerRule.FOLD_RESHAPE, + TransformerRule.TRANSFORM_MATMUL_TO_FC, + TransformerRule.FOLD_BATCHNORM, + TransformerRule.FOLD_CONV_AND_BN, + TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, + TransformerRule.TRANSFORM_GPU_WINOGRAD, + TransformerRule.TRANSFORM_ADD_TO_BIASADD, + TransformerRule.FOLD_BIASADD, + TransformerRule.FLATTEN_ATROUS_CONV, + TransformerRule.FOLD_ACTIVATION, + TransformerRule.ADD_IN_OUT_TENSOR_INFO, + TransformerRule.QUANTIZE_NODES, + TransformerRule.ADD_QUANTIZE_TENSOR_RANGE, + TransformerRule.QUANTIZE_WEIGHTS, + TransformerRule.ADD_DEVICE, + TransformerRule.SORT_BY_EXECUTION, + ] + class ConverterUtil(object): @staticmethod @@ -338,6 +387,12 @@ class ConverterUtil(object): data_format_arg.name = MaceKeyword.mace_data_format_str data_format_arg.i = data_format.value + @staticmethod + def add_data_type_arg(op, data_type): + data_type_arg = op.arg.add() + data_type_arg.name = MaceKeyword.mace_op_data_type_str + data_type_arg.i = data_type + @staticmethod def data_format(op): arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index da8ecff1..30e4210a 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -31,7 +31,7 @@ from mace.python.tools.converter_tool.base_converter import TransformerRule from mace.python.tools.convert_util 
import calculate_image_shape from mace.python.tools.convert_util import mace_check from mace.python.tools.convert_util import OpenCLBufferType - +from mace.python.tools.quantization import quantize_util OPENCL_IMAGE_MAX_SIZE = 16384 @@ -73,6 +73,12 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight, TransformerRule.TRANSFORM_BUFFER_IMAGE: self.transform_buffer_image, + TransformerRule.QUANTIZE_NODES: + self.quantize_nodes, + TransformerRule.ADD_QUANTIZE_TENSOR_RANGE: + self.add_quantize_tensor_range, + TransformerRule.QUANTIZE_WEIGHTS: + self.quantize_weights, TransformerRule.ADD_DEVICE: self.add_device, TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE: @@ -93,6 +99,8 @@ class Transformer(base_converter.ConverterInterface): self._target_data_format = DataFormat.NHWC self._input_output_added = False self._opencl_max_image_size = [0, 0] + self._quantize_activation_info = {} + self._quantized_tensor = set() if self._option.device == DeviceType.CPU.value: self._target_data_format = DataFormat.NCHW @@ -854,6 +862,7 @@ class Transformer(base_converter.ConverterInterface): else: op.type = MaceOp.Identity.name + ConverterUtil.add_data_type_arg(op, mace_pb2.DT_FLOAT) ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) for output_node in self._option.output_nodes.values(): @@ -877,6 +886,7 @@ class Transformer(base_converter.ConverterInterface): ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) else: op.type = MaceOp.Identity.name + ConverterUtil.add_data_type_arg(op, mace_pb2.DT_FLOAT) self._input_output_added = True @@ -963,6 +973,7 @@ class Transformer(base_converter.ConverterInterface): arg = op_def.arg.add() arg.name = MaceKeyword.mace_mode arg.i = 0 + ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT) tensor_shape = list(self._consts[input_name].dims) if input_type == OpenCLBufferType.WINOGRAD_FILTER: @@ -1054,6 +1065,7 @@ class Transformer(base_converter.ConverterInterface): arg.name = MaceKeyword.mace_buffer_type arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value + ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT) ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) for output_node in self._option.output_nodes.values(): @@ -1072,6 +1084,7 @@ class Transformer(base_converter.ConverterInterface): arg.name = MaceKeyword.mace_buffer_type arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value + ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT) ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) self._input_output_added = True @@ -1276,6 +1289,7 @@ class Transformer(base_converter.ConverterInterface): output_shape = op_def.output_shape.add() output_shape.dims.extend(input_node.shape) + ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT) ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) for output_node in self._option.output_nodes.values(): @@ -1290,6 +1304,8 @@ class Transformer(base_converter.ConverterInterface): output_shape.dims.extend( self._producer[output_node.name].output_shape[0].dims) + ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT) + def sort_by_execution(self): print("Sort by execution") net = self._model @@ -1311,3 +1327,103 @@ class Transformer(base_converter.ConverterInterface): print("%s (%s): %s" % (op.name, op.type, [ out_shape.dims for out_shape in op.output_shape])) return False + + def quantize_nodes(self): + print("Add mace quantize and dequantize nodes") + + for op in self._model.op: + data_type_arg = ConverterUtil.get_arg( + op, 
MaceKeyword.mace_op_data_type_str)
+            mace_check(data_type_arg, "Data type does not exist for %s(%s)"
+                       % (op.name, op.type))
+            if data_type_arg.i == mace_pb2.DT_FLOAT:
+                data_type_arg.i = mace_pb2.DT_UINT8
+            else:
+                mace_check(False,
+                           "Quantization only supports float ops, "
+                           "but got %s(%s)"
+                           % (op.name, op.type))
+
+        for input_node in self._option.input_nodes.values():
+            new_input_name = MaceKeyword.mace_input_node_name \
+                             + '_' + input_node.name
+            op_def = self._model.op.add()
+            op_def.name = self.normalize_op_name(input_node.name)
+            op_def.type = MaceOp.Quantize.name
+            op_def.input.extend([new_input_name])
+            op_def.output.extend([input_node.name])
+            output_shape = op_def.output_shape.add()
+            output_shape.dims.extend(input_node.shape)
+
+            ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
+            ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
+
+        for output_node in self._option.output_nodes.values():
+            output_name = MaceKeyword.mace_output_node_name \
+                          + '_' + output_node.name
+            op_def = self._model.op.add()
+            op_def.name = self.normalize_op_name(output_name)
+            op_def.type = MaceOp.Dequantize.name
+            op_def.input.extend([output_node.name])
+            op_def.output.extend([output_name])
+            output_shape = op_def.output_shape.add()
+            output_shape.dims.extend(
+                self._producer[output_node.name].output_shape[0].dims)
+
+            ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
+
+        self._input_output_added = True
+
+    def add_quantize_tensor_range(self):
+        print("Add quantize tensor range")
+        net = self._model
+        range_file = self._option.quantize_range_file
+        with open(range_file) as f:
+            for line in f:
+                tensor_name, minmax = line.split("@@")
+                min_val, max_val = [float(i) for i in
+                                    minmax.strip().split(",")]
+                scale, zero = quantize_util.adjust_range(min_val, max_val,
+                                                         non_zero=False)
+                activation_info = net.quantize_info.activation_info.add()
+                activation_info.tensor_name = tensor_name
+                activation_info.scale = scale
+                activation_info.zero_point = zero
+                self._quantize_activation_info[tensor_name] = activation_info
+
+    def quantize_tensor(self, tensor):
+        """Assumes biasadd has already been folded with convolution and fc"""
+        if tensor.data_type == mace_pb2.DT_FLOAT:
+            ops = self._consumers.get(tensor.name, None)
+            if ops is not None and len(ops) == 1 \
+                    and ops[0].type in [MaceOp.Conv2D.name,
+                                        MaceOp.Deconv2D.name,
+                                        MaceOp.DepthwiseConv2d.name,
+                                        MaceOp.FullyConnected.name] \
+                    and len(ops[0].input) >= 3 \
+                    and ops[0].input[2] == tensor.name:
+                conv_op = ops[0]
+                scale_input = self._quantize_activation_info[
+                    conv_op.input[0]].scale
+                if conv_op.input[1] not in self._quantized_tensor:
+                    self.quantize_tensor(self._consts[conv_op.input[1]])
+                scale_filter = self._consts[conv_op.input[1]].scale
+                scale = scale_input * scale_filter
+
+                quantized_tensor = quantize_util.quantize_with_scale_and_zero(
+                    tensor.float_data, scale, 0)
+                tensor.data_type = mace_pb2.DT_INT32
+            else:
+                quantized_tensor = quantize_util.quantize(tensor.float_data)
+                tensor.data_type = mace_pb2.DT_UINT8
+
+            del tensor.float_data[:]
+            tensor.int32_data.extend(quantized_tensor.data)
+            tensor.scale = quantized_tensor.scale
+            tensor.zero_point = quantized_tensor.zero
+            self._quantized_tensor.update([tensor.name])
+
+    def quantize_weights(self):
+        print("Quantize weights")
+        net = self._model
+        for tensor in net.tensors:
+            self.quantize_tensor(tensor)

diff --git a/mace/python/tools/model.jinja2 b/mace/python/tools/model.jinja2
index efb1c359..dcfe5434 100644
--- a/mace/python/tools/model.jinja2
+++ b/mace/python/tools/model.jinja2
@@
-138,6 +138,19 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) { } {% endif %} +void AddQuantizeInfo(NetDef *net_def) { + MACE_LATENCY_LOGGER(1, "Add quantize info"); + (void) net_def; + + {% for i in range(net.quantize_info.activation_info|length) %} + mace::QuantizeActivationInfo *activation_info{{i}} = + net_def->mutable_quantize_info()->add_activation_info(); + activation_info{{i}}->set_tensor_name("{{net.quantize_info.activation_info[i].tensor_name}}"); + activation_info{{i}}->set_scale({{net.quantize_info.activation_info[i].scale}}); + activation_info{{i}}->set_zero_point({{net.quantize_info.activation_info[i].zero_point}}); + + {% endfor %} +} } // namespace @@ -166,6 +179,8 @@ const std::shared_ptr CreateNet() { CreateOutputInfo(net_def.get()); {% endif %} + AddQuantizeInfo(net_def.get()); + return net_def; } diff --git a/mace/python/tools/model_saver.py b/mace/python/tools/model_saver.py index 0b849c76..f9583b30 100644 --- a/mace/python/tools/model_saver.py +++ b/mace/python/tools/model_saver.py @@ -98,30 +98,6 @@ def obfuscate_name(net_def): op.output[i] = in_out_map[op.output[i]] -def normalize_op_name(op_name): - idx = op_name.rfind(':') - if idx == -1: - return op_name - else: - return op_name[:idx] - - -def rename_tensor(net_def): - tensor_map = {} - for t in net_def.tensors: - if t.name not in tensor_map: - tensor_map[t.name] = "_" + normalize_op_name(t.name).replace("/", - "_") - t.name = tensor_map[t.name] - for op in net_def.op: - for i in range(len(op.input)): - if op.input[i] in tensor_map: - op.input[i] = tensor_map[op.input[i]] - for i in range(len(op.output)): - if op.output[i] in tensor_map: - op.output[i] = tensor_map[op.output[i]] - - def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) @@ -301,8 +277,6 @@ def save_model(net_def, model_checksum, weight_checksum, template_dir, winograd_conv, data_type, model_graph_format): if obfuscate: obfuscate_name(net_def) - else: - rename_tensor(net_def) output_dir = output_dir + '/' # update tensor type diff --git a/mace/python/tools/quantization/__init__.py b/mace/python/tools/quantization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mace/python/tools/quantization/quantize_stat.py b/mace/python/tools/quantization/quantize_stat.py new file mode 100644 index 00000000..6669cc28 --- /dev/null +++ b/mace/python/tools/quantization/quantize_stat.py @@ -0,0 +1,51 @@ +import argparse +import numpy as np + + +class QuantizeStat(object): + def __init__(self): + pass + + @staticmethod + def run(log_file, percentile): + res = {} + tensor_ranges = {} + with open(log_file) as log: + for line in log: + if line.find("Tensor range @@") != -1: + tensor_name, minmax = line.split("@@")[1:] + min_val, max_val = [float(i) for i in + minmax.strip().split(",")] + if tensor_name not in tensor_ranges: + tensor_ranges[tensor_name] = ([], []) + tensor_ranges[tensor_name][0].append(min_val) + tensor_ranges[tensor_name][1].append(max_val) + + for tensor_name in tensor_ranges: + tensor_min = np.percentile(tensor_ranges[tensor_name][0], + percentile) + tensor_max = np.percentile(tensor_ranges[tensor_name][1], + 100 - percentile) + assert tensor_min < tensor_max + res[tensor_name] = (tensor_min, tensor_max) + + return res + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--log_file", + type=str, + default="", + help="path of log file that records tensor range") + parser.add_argument( + "--percentile", + type=int, + default=5, + help="range percentile") + FLAGS, 
unparsed = parser.parse_known_args()
+
+    res = QuantizeStat.run(FLAGS.log_file, FLAGS.percentile)
+    for tensor in res:
+        print("%s@@%f,%f" % (tensor, res[tensor][0], res[tensor][1]))

diff --git a/mace/python/tools/quantization/quantize_util.py b/mace/python/tools/quantization/quantize_util.py
new file mode 100644
index 00000000..8776cdde
--- /dev/null
+++ b/mace/python/tools/quantization/quantize_util.py
@@ -0,0 +1,108 @@
+import numpy as np
+import math
+
+
+class QuantizedData(object):
+    def __init__(self):
+        self._data = None
+        self._scale = 0
+        self._zero = 0
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def scale(self):
+        return self._scale
+
+    @property
+    def zero(self):
+        return self._zero
+
+    @data.setter
+    def data(self, data):
+        self._data = data
+
+    @scale.setter
+    def scale(self, scale):
+        self._scale = scale
+
+    @zero.setter
+    def zero(self, zero):
+        self._zero = zero
+
+
+def adjust_range(in_min, in_max, non_zero):
+    out_max = max(0.0, in_max)
+    out_min = min(0.0, in_min)
+    if non_zero:
+        out_min = min(out_min, in_min - (out_max - in_min) / 254.0)
+    scale = (out_max - out_min) / 255.0
+    eps = 1e-6
+    if out_min < -eps and out_max > eps:
+        zero = -out_min / scale
+        zero_int = int(round(zero))
+        if abs(zero - zero_int) > eps:
+            if zero < zero_int or non_zero:
+                zero_int = int(math.ceil(zero))
+                scale = out_max / (255.0 - zero_int)
+            else:
+                scale = -out_min / zero_int
+    elif out_min > -eps:
+        zero_int = 0
+    else:
+        zero_int = 255
+
+    return scale, zero_int
+
+
+def cal_multiplier_and_shift(scale):
+    """
+    In order to use gemmlowp, we need to use gemmlowp-like transform
+    :param scale:
+    :return: multiplier, shift
+    """
+    assert scale > 0, "scale should > 0, but get %s" % scale
+    assert scale < 1, "scale should < 1, but get %s" % scale
+    multiplier = scale
+    s = 0
+    # make range [1/2, 1)
+    while multiplier < 0.5:
+        multiplier *= 2.0
+        s += 1
+    # convert scale to fixed-point
+    q = int(round(multiplier * (1 << 31)))
+    assert q <= (1 << 31)
+    if q == (1 << 31):
+        q //= 2
+        s -= 1
+    assert s >= 0
+    return q, s
+
+
+def quantize_with_scale_and_zero(data, scale, zero):
+    output = np.round(zero + np.array(data) / scale).astype(int)
+    quantized_data = QuantizedData()
+    quantized_data.data = output
+    quantized_data.scale = scale
+    quantized_data.zero = zero
+    return quantized_data
+
+
+def quantize(data):
+    np_data = np.array(data).astype(float)
+    in_min = np_data.min()
+    in_max = np_data.max()
+    scale, zero = adjust_range(in_min, in_max, non_zero=True)
+    output = np.clip(np.round(zero + np_data / scale).astype(int), 0, 255)
+
+    quantized_data = QuantizedData()
+    quantized_data.data = output
+    quantized_data.scale = scale
+    quantized_data.zero = zero
+    return quantized_data
+
+
+def dequantize(quantized_data):
+    return quantized_data.scale * (quantized_data.data - quantized_data.zero)

diff --git a/mace/python/tools/quantization/quantize_util_test.py b/mace/python/tools/quantization/quantize_util_test.py
new file mode 100644
index 00000000..2c9c0a65
--- /dev/null
+++ b/mace/python/tools/quantization/quantize_util_test.py
@@ -0,0 +1,16 @@
+import unittest
+import numpy as np
+import quantize_util
+
+
+class TestQuantize(unittest.TestCase):
+
+    def test_quantize_dequantize(self):
+        test_input = np.random.rand(20, 30) * 5
+        quantized_data = quantize_util.quantize(test_input)
+        dequantized_output = quantize_util.dequantize(quantized_data)
+        np.testing.assert_array_almost_equal(test_input, dequantized_output, 2)
+
+
+if __name__ == '__main__':
+    unittest.main()

diff --git
a/mace/python/tools/tensor_source.jinja2 b/mace/python/tools/tensor_source.jinja2 index ef59c8bc..66feee0e 100644 --- a/mace/python/tools/tensor_source.jinja2 +++ b/mace/python/tools/tensor_source.jinja2 @@ -32,6 +32,8 @@ void CreateTensor{{tensor_info.id}}(mace::ConstTensor *const_tensor) { {% endfor %} const_tensor->set_data_type(static_cast({{ tensor_info.data_type }})); const_tensor->set_node_id({{ tensor.node_id }}); + const_tensor->set_scale({{ tensor.scale }}); + const_tensor->set_zero_point({{ tensor.zero_point }}); } } // namespace {{tag}} diff --git a/mace/tools/quantization/BUILD b/mace/tools/quantization/BUILD new file mode 100644 index 00000000..345bde5d --- /dev/null +++ b/mace/tools/quantization/BUILD @@ -0,0 +1,18 @@ +# Quantize stat build + +cc_binary( + name = "quantize_stat", + srcs = ["quantize_stat.cc"], + copts = [ + "-Werror", + "-Wextra", + ], + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + "//external:gflags_nothreads", + "//mace/codegen:generated_mace_engine_factory", + "//mace/codegen:generated_models", + "//mace/libmace", + ], +) diff --git a/mace/tools/quantization/quantize_stat.cc b/mace/tools/quantization/quantize_stat.cc new file mode 100644 index 00000000..a05f42f7 --- /dev/null +++ b/mace/tools/quantization/quantize_stat.cc @@ -0,0 +1,264 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
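+
+// How the range statistics flow end-to-end: tools/converter.py runs this
+// binary with MACE_LOG_TENSOR_RANGE=1, which makes SerialNet::Run() log one
+// line per float CPU output tensor in the form
+//   Tensor range @@<tensor_name>@@<min>,<max>
+// mace/python/tools/quantization/quantize_stat.py then aggregates those
+// lines into the range file that the converter consumes via
+// --quantize_range_file.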
+
+/**
+ * Usage:
+ * quantize_stat --model=mobi_mace.pb \
+ *          --input=input_node \
+ *          --output=output_node \
+ *          --input_shape=1,224,224,3 \
+ *          --output_shape=1,224,224,2 \
+ *          --input_dir=input_data_dir \
+ *          --output_file=mace.out \
+ *          --model_data_file=model_data.data
+ */
+#include <dirent.h>
+#include <cctype>
+#include <cstdlib>
+#include <fstream>
+#include <map>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "gflags/gflags.h"
+#include "mace/public/mace.h"
+#include "mace/public/mace_runtime.h"
+#include "mace/utils/env_time.h"
+#include "mace/utils/logging.h"
+#include "mace/utils/utils.h"
+
+#ifdef MODEL_GRAPH_FORMAT_CODE
+#include "mace/codegen/engine/mace_engine_factory.h"
+#endif
+
+namespace mace {
+namespace tools {
+namespace quantization {
+
+namespace str_util {
+
+std::vector<std::string> Split(const std::string &str, char delims) {
+  std::vector<std::string> result;
+  if (str.empty()) {
+    result.push_back("");
+    return result;
+  }
+  std::string tmp = str;
+  while (!tmp.empty()) {
+    size_t next_offset = tmp.find(delims);
+    result.push_back(tmp.substr(0, next_offset));
+    if (next_offset == std::string::npos) {
+      break;
+    } else {
+      tmp = tmp.substr(next_offset + 1);
+    }
+  }
+  return result;
+}
+
+}  // namespace str_util
+
+void ParseShape(const std::string &str, std::vector<int64_t> *shape) {
+  std::string tmp = str;
+  while (!tmp.empty()) {
+    int dim = atoi(tmp.data());
+    shape->push_back(dim);
+    size_t next_offset = tmp.find(",");
+    if (next_offset == std::string::npos) {
+      break;
+    } else {
+      tmp = tmp.substr(next_offset + 1);
+    }
+  }
+}
+
+std::string FormatName(const std::string input) {
+  std::string res = input;
+  for (size_t i = 0; i < input.size(); ++i) {
+    if (!isalnum(res[i])) res[i] = '_';
+  }
+  return res;
+}
+
+DEFINE_string(model_name,
+              "",
+              "model name in yaml");
+DEFINE_string(input_node,
+              "input_node0,input_node1",
+              "input nodes, separated by comma");
+DEFINE_string(input_shape,
+              "1,224,224,3:1,1,1,10",
+              "input shapes, separated by colon and comma");
+DEFINE_string(output_node,
+              "output_node0,output_node1",
+              "output nodes, separated by comma");
+DEFINE_string(output_shape,
+              "1,224,224,2:1,1,1,10",
+              "output shapes, separated by colon and comma");
+DEFINE_string(input_dir,
+              "",
+              "input directory name");
+DEFINE_string(model_data_file,
+              "",
+              "model data file name, used when EMBED_MODEL_DATA set to 0 or 2");
+DEFINE_string(model_file,
+              "",
+              "model file name, used when load mace model in pb");
+DEFINE_int32(omp_num_threads, -1, "num of openmp threads");
+
+bool RunModel(const std::string &model_name,
+              const std::vector<std::string> &input_names,
+              const std::vector<std::vector<int64_t>> &input_shapes,
+              const std::vector<std::string> &output_names,
+              const std::vector<std::vector<int64_t>> &output_shapes) {
+  MACE_RETURN_IF_ERROR(mace::SetOpenMPThreadPolicy(
+      FLAGS_omp_num_threads, CPUAffinityPolicy::AFFINITY_NONE));
+
+  std::vector<unsigned char> model_pb_data;
+  if (FLAGS_model_file != "") {
+    if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
+      LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
+    }
+  }
+
+  std::shared_ptr<mace::MaceEngine> engine;
+
+  // Create Engine
+#ifdef MODEL_GRAPH_FORMAT_CODE
+  MACE_RETURN_IF_ERROR(
+      CreateMaceEngineFromCode(model_name,
+                               FLAGS_model_data_file,
+                               input_names,
+                               output_names,
+                               DeviceType::CPU,
+                               &engine));
+#else
+  (void) (model_name);
+  MACE_RETURN_IF_ERROR(
+      CreateMaceEngineFromProto(model_pb_data,
+                                FLAGS_model_data_file,
+                                input_names,
+                                output_names,
+                                DeviceType::CPU,
+                                &engine));
+#endif
+
+  const size_t input_count = input_names.size();
+  const size_t output_count = output_names.size();
+
+  std::map<std::string, mace::MaceTensor> inputs;
+  std::map<std::string, mace::MaceTensor> outputs;
+  std::map<std::string, int64_t> inputs_size;
+  for (size_t i = 0; i < input_count; ++i) {
+    int64_t input_size =
+        std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1,
+                        std::multiplies<int64_t>());
+    inputs_size[input_names[i]] = input_size;
+    auto buffer_in = std::shared_ptr<float>(new float[input_size],
+                                            std::default_delete<float[]>());
+    inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in);
+  }
+
+  for (size_t i = 0; i < output_count; ++i) {
+    int64_t output_size =
+        std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1,
+                        std::multiplies<int64_t>());
+    auto buffer_out = std::shared_ptr<float>(new float[output_size],
+                                             std::default_delete<float[]>());
+    outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out);
+  }
+
+  DIR *dir_parent;
+  struct dirent *entry;
+  dir_parent = opendir(FLAGS_input_dir.c_str());
+  if (dir_parent) {
+    while ((entry = readdir(dir_parent))) {
+      std::string file_name = std::string(entry->d_name);
+      std::string prefix = FormatName(input_names[0]);
+      if (file_name.find(prefix) == 0) {
+        std::string suffix = file_name.substr(prefix.size());
+
+        for (size_t i = 0; i < input_count; ++i) {
+          file_name = FLAGS_input_dir + "/" + FormatName(input_names[i])
+              + suffix;
+          std::ifstream in_file(file_name, std::ios::in | std::ios::binary);
+          VLOG(2) << "Read " << file_name;
+          if (in_file.is_open()) {
+            in_file.read(reinterpret_cast<char *>(
+                             inputs[input_names[i]].data().get()),
+                         inputs_size[input_names[i]] * sizeof(float));
+            in_file.close();
+          } else {
+            LOG(INFO) << "Open input file failed";
+            return false;
+          }
+        }
+        MACE_RETURN_IF_ERROR(engine->Run(inputs, &outputs));
+      }
+    }
+
+    closedir(dir_parent);
+  } else {
+    LOG(ERROR) << "Directory " << FLAGS_input_dir << " does not exist.";
+  }
+  return true;
+}
+
+int Main(int argc, char **argv) {
+  std::string usage = "quantize stat model\nusage: " + std::string(argv[0])
+      + " [flags]";
+  gflags::SetUsageMessage(usage);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  LOG(INFO) << "model name: " << FLAGS_model_name;
+  LOG(INFO) << "mace version: " << MaceVersion();
+  LOG(INFO) << "input node: " << FLAGS_input_node;
+  LOG(INFO) << "input shape: " << FLAGS_input_shape;
+  LOG(INFO) << "output node: " << FLAGS_output_node;
+  LOG(INFO) << "output shape: " << FLAGS_output_shape;
+  LOG(INFO) << "input_dir: " << FLAGS_input_dir;
+  LOG(INFO) << "model_data_file: " << FLAGS_model_data_file;
+  LOG(INFO) << "model_file: " << FLAGS_model_file;
+  LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads;
+
+  std::vector<std::string> input_names =
+      str_util::Split(FLAGS_input_node, ',');
+  std::vector<std::string> output_names =
+      str_util::Split(FLAGS_output_node, ',');
+  std::vector<std::string> input_shapes =
+      str_util::Split(FLAGS_input_shape, ':');
+  std::vector<std::string> output_shapes =
+      str_util::Split(FLAGS_output_shape, ':');
+
+  const size_t input_count = input_shapes.size();
+  const size_t output_count = output_shapes.size();
+  std::vector<std::vector<int64_t>> input_shape_vec(input_count);
+  std::vector<std::vector<int64_t>> output_shape_vec(output_count);
+  for (size_t i = 0; i < input_count; ++i) {
+    ParseShape(input_shapes[i], &input_shape_vec[i]);
+  }
+  for (size_t i = 0; i < output_count; ++i) {
+    ParseShape(output_shapes[i], &output_shape_vec[i]);
+  }
+
+  return RunModel(FLAGS_model_name, input_names, input_shape_vec,
+                  output_names, output_shape_vec);
+}
+
+}  // namespace quantization
+}  // namespace tools
+}  // namespace mace
+
+int main(int argc, char **argv) {
+  mace::tools::quantization::Main(argc, argv);
+}

diff --git a/mace/utils/utils.h b/mace/utils/utils.h
index c153e4ba..c9e561b5 100644
--- a/mace/utils/utils.h
+++ b/mace/utils/utils.h
@@
-19,6 +19,7 @@ #include #include #include +#include #include #include @@ -162,5 +163,10 @@ std::vector MapKeys(const std::map &data) { return keys; } +inline bool EnvEnabled(std::string env_name) { + char *env = getenv(env_name.c_str()); + return !(!env || env[0] == 0 || env[0] == '0'); +} + } // namespace mace #endif // MACE_UTILS_UTILS_H_ diff --git a/tools/converter.py b/tools/converter.py index 0b3fca51..efd81168 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -70,6 +70,7 @@ MACE_RUN_STATIC_NAME = "mace_run_static" MACE_RUN_DYNAMIC_NAME = "mace_run_dynamic" MACE_RUN_STATIC_TARGET = "//mace/tools/validation:" + MACE_RUN_STATIC_NAME MACE_RUN_DYNAMIC_TARGET = "//mace/tools/validation:" + MACE_RUN_DYNAMIC_NAME +QUANTIZE_STAT_TARGET = "//mace/tools/quantization:quantize_stat" EXAMPLE_STATIC_NAME = "example_static" EXAMPLE_DYNAMIC_NAME = "example_dynamic" EXAMPLE_STATIC_TARGET = "//mace/examples/cli:" + EXAMPLE_STATIC_NAME @@ -185,6 +186,8 @@ class YAMLKeyword(object): nnlib_graph_mode = 'nnlib_graph_mode' obfuscate = 'obfuscate' winograd = 'winograd' + quantize = 'quantize' + quantize_range_file = 'quantize_range_file' validation_inputs_data = 'validation_inputs_data' graph_optimize_options = 'graph_optimize_options' # internal use for now @@ -459,7 +462,8 @@ def format_model_config(flags): for key in [YAMLKeyword.limit_opencl_kernel_time, YAMLKeyword.nnlib_graph_mode, YAMLKeyword.obfuscate, - YAMLKeyword.winograd]: + YAMLKeyword.winograd, + YAMLKeyword.quantize]: value = model_config.get(key, "") if value == "": model_config[key] = 0 @@ -705,6 +709,8 @@ def convert_model(configs): model_config[YAMLKeyword.nnlib_graph_mode], embed_model_data, model_config[YAMLKeyword.winograd], + model_config[YAMLKeyword.quantize], + model_config.get(YAMLKeyword.quantize_range_file, ""), model_config[YAMLKeyword.obfuscate], configs[YAMLKeyword.model_graph_format], data_type, @@ -871,6 +877,37 @@ def build_mace_run(configs, target_abi, enable_openmp, address_sanitizer, mace_lib_type == MACELibType.dynamic) +def build_quantize_stat(configs): + library_name = configs[YAMLKeyword.library_name] + + build_tmp_binary_dir = get_build_binary_dir(library_name, ABIType.host) + if os.path.exists(build_tmp_binary_dir): + sh.rm("-rf", build_tmp_binary_dir) + os.makedirs(build_tmp_binary_dir) + + quantize_stat_target = QUANTIZE_STAT_TARGET + build_arg = "" + print (configs[YAMLKeyword.model_graph_format]) + if configs[YAMLKeyword.model_graph_format] == ModelFormat.code: + mace_check(os.path.exists(ENGINE_CODEGEN_DIR), + ModuleName.RUN, + "You should convert model first.") + build_arg = "--per_file_copt=mace/tools/quantization/quantize_stat.cc@-DMODEL_GRAPH_FORMAT_CODE" # noqa + + sh_commands.bazel_build( + quantize_stat_target, + abi=ABIType.host, + enable_openmp=True, + extra_args=build_arg + ) + + quantize_stat_filepath = build_tmp_binary_dir + "/quantize_stat" + if os.path.exists(quantize_stat_filepath): + sh.rm("-rf", quantize_stat_filepath) + sh.cp("-f", "bazel-bin/mace/tools/quantization/quantize_stat", + build_tmp_binary_dir) + + def build_example(configs, target_abi, enable_openmp, mace_lib_type): library_name = configs[YAMLKeyword.library_name] hexagon_mode = get_hexagon_mode(configs) @@ -1196,6 +1233,59 @@ def run_specific_target(flags, configs, target_abi, opencl_parameter_bin_path) +def run_quantize_stat(flags, configs): + library_name = configs[YAMLKeyword.library_name] + build_tmp_binary_dir = get_build_binary_dir(library_name, ABIType.host) + + for model_name in configs[YAMLKeyword.models]: + 
check_model_converted(library_name, model_name, + configs[YAMLKeyword.model_graph_format], + configs[YAMLKeyword.model_data_format], + ABIType.host) + MaceLogger.header( + StringFormatter.block( + "Run model %s on %s" % (model_name, ABIType.host))) + + model_config = configs[YAMLKeyword.models][model_name] + subgraphs = model_config[YAMLKeyword.subgraphs] + + _, _, mace_model_dir = \ + get_build_model_dirs(library_name, model_name, ABIType.host, + None, None, + model_config[YAMLKeyword.model_file_path]) + + mace_model_path = "" + if configs[YAMLKeyword.model_graph_format] == ModelFormat.file: + mace_model_path = "%s/%s.pb" % (mace_model_dir, model_name) + + p = subprocess.Popen( + [ + "env", + "MACE_CPP_MIN_VLOG_LEVEL=%s" % flags.vlog_level, + "MACE_LOG_TENSOR_RANGE=1", + "%s/%s" % (build_tmp_binary_dir, "quantize_stat"), + "--model_name=%s" % model_name, + "--input_node=%s" % ",".join( + subgraphs[0][YAMLKeyword.input_tensors]), + "--output_node=%s" % ",".join( + subgraphs[0][YAMLKeyword.output_tensors]), + "--input_shape=%s" % ":".join( + subgraphs[0][YAMLKeyword.input_shapes]), + "--output_shape=%s" % ":".join( + subgraphs[0][YAMLKeyword.output_shapes]), + "--input_dir=%s" % flags.input_dir, + "--model_data_file=%s/%s.data" % (mace_model_dir, model_name), + "--omp_num_threads=%s" % flags.omp_num_threads, + "--model_file=%s" % mace_model_path, + ], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + out, err = p.communicate() + stdout = err + out + print stdout + print("Running finished!\n") + + def print_package_summary(package_path): title = "Library" header = ["key", "value"] @@ -1216,6 +1306,11 @@ def run_mace(flags): clear_build_dirs(configs[YAMLKeyword.library_name]) + if flags.quantize_stat: + build_quantize_stat(configs) + run_quantize_stat(flags, configs) + return + target_socs = configs[YAMLKeyword.target_socs] if not target_socs or ALL_SOC_TAG in target_socs: target_socs = sh_commands.adb_get_all_socs() @@ -1582,6 +1677,15 @@ def parse_args(): "--example", action="store_true", help="whether to run example.") + run.add_argument( + "--quantize_stat", + action="store_true", + help="whether to stat quantization range.") + run.add_argument( + "--input_dir", + type=str, + default="", + help="quantize stat input dir.") benchmark = subparsers.add_parser( 'benchmark', parents=[all_type_parent_parser, run_bm_parent_parser], diff --git a/tools/sh_commands.py b/tools/sh_commands.py index 0340f3d5..1beb8fe8 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -490,6 +490,8 @@ def gen_model_code(model_codegen_dir, dsp_mode, embed_model_data, winograd, + quantize, + quantize_range_file, obfuscate, model_graph_format, data_type, @@ -516,6 +518,8 @@ def gen_model_code(model_codegen_dir, "--dsp_mode=%s" % dsp_mode, "--embed_model_data=%s" % embed_model_data, "--winograd=%s" % winograd, + "--quantize=%s" % quantize, + "--quantize_range_file=%s" % quantize_range_file, "--obfuscate=%s" % obfuscate, "--output_dir=%s" % model_codegen_dir, "--model_graph_format=%s" % model_graph_format, -- GitLab
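
Note on the quantization scheme used above: the patch represents a float x as
a uint8 q with x ≈ scale * (q - zero_point), i.e. q = round(zero_point +
x / scale), where adjust_range() widens [min, max] so that it always covers
0.0 and 0.0 maps to an exact integer. A minimal numpy sketch of the round
trip (simplified from quantize_util.py; it omits the scale re-adjustment
that adjust_range() performs after rounding the zero point):

    import numpy as np

    def quantize_dequantize(x, in_min, in_max):
        # Widen the range to cover 0.0, as adjust_range() does.
        out_min, out_max = min(0.0, in_min), max(0.0, in_max)
        scale = (out_max - out_min) / 255.0           # uint8 has 256 levels
        zero = int(round(-out_min / scale))           # uint8 code of 0.0
        q = np.clip(np.round(zero + x / scale).astype(int), 0, 255)
        return scale * (q - zero)                     # error is roughly <= scale

    x = np.random.rand(16).astype(np.float32) * 4 - 1
    err = np.max(np.abs(x - quantize_dequantize(x, float(x.min()), float(x.max()))))
    print(err)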