diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst
index 7bd27d9640643a729b9f6b869f71c8480bf8cb56..54158d9c7158683eaa4a1b99e69b8613b16dc2d4 100644
--- a/docs/user_guide/advanced_usage.rst
+++ b/docs/user_guide/advanced_usage.rst
@@ -504,3 +504,25 @@ which will reduce the library size significantly. the final binary just link the
     }
 }  // namespace mace
 
+
+Reduce Model Size
+-------------------
+Model file size can be a bottleneck when deploying neural networks on mobile devices,
+so MACE provides several ways to reduce the model size with little or no loss of performance or accuracy.
+
+**1. Save model weights in half-precision floating point format**
+
+The default data type of a regular model is 32-bit float. Storing the weights in
+half precision (16-bit) halves the model size with negligible accuracy degradation.
+
+For CPU, ``data_type`` can be set to ``fp16_fp32`` in the deployment file to store the weights in half precision while inference still runs in float.
+
+For GPU, ``fp16_fp32`` is the default: GPU ops take half-precision inputs and outputs, while kernel computation is performed in float.
+
+**2. Save model weights in quantized fixed point format**
+
+Weights of convolutional (excluding depthwise) and fully connected layers account for a major part of the model size.
+These weights can be quantized to 8-bit to shrink them to a quarter of their original size, while accuracy usually drops by only 1%-3%.
+For example, the top-1 accuracy of MobileNetV1 after weight quantization is 68.2% on the ImageNet validation set.
+``quantize_large_weights`` can be set to 1 in the deployment file to store these weights in 8-bit while inference still runs in float.
+It can be used for both CPU and GPU.
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index f1740765eee32b43ae1af78011b9dbb5b8460c01..a70fe3afa4b2523b14a0e94865da2e01f8f9a404 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -104,9 +104,9 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
 
   if (model_data_size > 0) {
     bool is_quantize_model = IsQuantizedModel(net_def);
-    diffused_buffer_ = (device_type == DeviceType::CPU &&
-                        (HasHalfTensor(net_def) ||
-                         (!is_quantize_model && HasQuantizedTensor(net_def))));
+    diffused_buffer_ =
+        (device_type == DeviceType::CPU && HasHalfTensor(net_def)) ||
+        (!is_quantize_model && HasQuantizedTensor(net_def));
 #ifdef MACE_ENABLE_OPENCL
     diffused_buffer_ = diffused_buffer_ || (device_type == DeviceType::GPU &&
         device->gpu_runtime()->opencl_runtime()->GetDeviceMaxMemAllocSize() <=
@@ -125,8 +125,9 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
       }
 
       DataType dst_data_type = const_tensor.data_type();
-      if (device_type == DeviceType::CPU &&
-          const_tensor.data_type() == DataType::DT_HALF) {
+      if ((device_type == DeviceType::CPU &&
+           const_tensor.data_type() == DataType::DT_HALF) ||
+          (!is_quantize_model && const_tensor.quantized())) {
         dst_data_type = DataType::DT_FLOAT;
       }
 
@@ -147,8 +148,8 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
               ") should <= ",
               model_data_size);
 
-      if (device_type == DeviceType::CPU) {
-        if (const_tensor.data_type() == DataType::DT_HALF) {
+      if (device_type == DeviceType::CPU &&
+          const_tensor.data_type() == DataType::DT_HALF) {
           // uncompress the weights of fp16
           auto org_data = reinterpret_cast<const half *>(
               model_data + const_tensor.offset());
@@ -156,25 +157,19 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
           for (int i = 0; i < const_tensor.data_size(); ++i) {
             dst_data[i] = half_float::half_cast<float>(org_data[i]);
           }
-        } else if (!is_quantize_model && const_tensor.quantized()) {
-          // uncompress the weights of uint8
-          std::unique_ptr<Tensor> dequantized_tensor(new Tensor(true));
-          dequantized_tensor->Resize(dims);
-          auto quantized_data = reinterpret_cast<const uint8_t *>(
-              model_data + const_tensor.offset());
-          auto dequantized_data = tensor->mutable_data<float>();
-          QuantizeUtil<float, uint8_t>
-              quantize_util(&device->cpu_runtime()->thread_pool());
-          quantize_util.Dequantize(quantized_data,
-                                   tensor->size(),
-                                   const_tensor.scale(),
-                                   const_tensor.zero_point(),
-                                   dequantized_data);
-        } else {
-          tensor->CopyBytes(model_data + const_tensor.offset(),
-                            const_tensor.data_size() *
-                            GetEnumTypeSize(const_tensor.data_type()));
-        }
+      } else if (!is_quantize_model && const_tensor.quantized()) {
+        // uncompress the weights of uint8
+        Tensor::MappingGuard guard(tensor.get());
+        auto quantized_data = reinterpret_cast<const uint8_t *>(
+            model_data + const_tensor.offset());
+        auto dequantized_data = tensor->mutable_data<float>();
+        QuantizeUtil<float, uint8_t>
+            quantize_util(&device->cpu_runtime()->thread_pool());
+        quantize_util.Dequantize(quantized_data,
+                                 tensor->size(),
+                                 const_tensor.scale(),
+                                 const_tensor.zero_point(),
+                                 dequantized_data);
       } else {
         tensor->CopyBytes(model_data + const_tensor.offset(),
                           const_tensor.data_size() *
diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py
index 23456f55365edaa41a6a996069511a04446fea1b..d623605204405f6013ca18e39d40125cc33cecc9 100644
--- a/mace/python/tools/converter.py
+++ b/mace/python/tools/converter.py
@@ -142,6 +142,7 @@ def main(unused_args):
         option.transformer_option = FLAGS.graph_optimize_options.split(',')
     option.winograd = FLAGS.winograd
     option.quantize = FLAGS.quantize
+    option.quantize_large_weights = FLAGS.quantize_large_weights
     option.quantize_range_file = FLAGS.quantize_range_file
     option.change_concat_ranges = FLAGS.change_concat_ranges
     option.cl_mem_type = FLAGS.cl_mem_type
@@ -389,6 +390,13 @@ def parse_args():
         const=False,
         default=False,
         help="quantize model")
+    parser.add_argument(
+        "--quantize_large_weights",
+        type=str2bool,
+        nargs='?',
+        const=False,
+        default=False,
+        help="quantize large weights for compression")
     parser.add_argument(
         "--quantize_range_file",
         type=str,
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 5f2dbaaacc2db16ad54707c4876ca6efe50d453d..750a6cd42530d53efc0497f5b9a305881bc1ebf4 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -318,6 +318,7 @@ class TransformerRule(Enum):
     QUANTIZE_SPECIFIC_OPS_ONLY = 40
     FP16_MATMUL_WEIGHT = 41
     FP16_GATHER_WEIGHT = 42
+    QUANTIZE_LARGE_WEIGHTS = 43
 
 
 class ConverterInterface(object):
@@ -392,6 +393,7 @@ class ConverterOption(object):
         self._device = DeviceType.CPU.value
         self._winograd = 0
         self._quantize = False
+        self._quantize_large_weights = False
         self._quantize_range_file = ""
         self._change_concat_ranges = False
         self._transformer_option = None
@@ -425,6 +427,10 @@ class ConverterOption(object):
     def quantize(self):
         return self._quantize
 
+    @property
+    def quantize_large_weights(self):
+        return self._quantize_large_weights
+
     @property
     def change_concat_ranges(self):
         return self._change_concat_ranges
@@ -481,6 +487,10 @@ class ConverterOption(object):
     def quantize(self, quantize):
         self._quantize = quantize
 
+    @quantize_large_weights.setter
+    def quantize_large_weights(self, quantize_large_weights):
+        self._quantize_large_weights = quantize_large_weights
+
     @quantize_range_file.setter
     def quantize_range_file(self, quantize_range_file):
         self._quantize_range_file = quantize_range_file
@@ -556,6 +566,10 @@ class ConverterOption(object):
             # Need to be put after SORT_BY_EXECUTION
             TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
         ]
+        if self.quantize_large_weights:
+            self._transformer_option = self._transformer_option + [
+                TransformerRule.QUANTIZE_LARGE_WEIGHTS
+            ]
         if self._quantize:
             self._transformer_option = self._transformer_option + [
                 # need to be put after ADD_QUANTIZE_TENSOR_RANGE
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 4f5864ec78c598d575f433cdb7a6454f8833ab2b..a6765627e5f0844f6354ed4e6122f9adf99d7c1b 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -110,6 +110,8 @@ class Transformer(base_converter.ConverterInterface):
                 self.fp16_matmul_weight,
             TransformerRule.FP16_GATHER_WEIGHT:
                 self.fp16_gather_weight,
+            TransformerRule.QUANTIZE_LARGE_WEIGHTS:
+                self.quantize_large_weights,
         }
 
         self._option = option
@@ -1625,6 +1627,35 @@ class Transformer(base_converter.ConverterInterface):
 
         return False
 
+    def quantize_large_tensor(self, tensor):
+        if tensor.data_type == mace_pb2.DT_FLOAT:
+            ops = self._consumers.get(tensor.name, None)
+            if ops is not None and len(ops) == 1:
+                if ops[0].type in [MaceOp.Conv2D.name,
+                                   MaceOp.FullyConnected.name]:
+                    quantized_tensor = \
+                        quantize_util.quantize(tensor.float_data,
+                                               self._option.device,
+                                               False)
+                    tensor.data_type = mace_pb2.DT_UINT8
+
+                    del tensor.float_data[:]
+                    tensor.int32_data.extend(quantized_tensor.data)
+                    tensor.scale = quantized_tensor.scale
+                    tensor.zero_point = quantized_tensor.zero
+                    tensor.minval = quantized_tensor.minval
+                    tensor.maxval = quantized_tensor.maxval
+                    tensor.quantized = True
+                    self._quantized_tensor.update([tensor.name])
+
+    def quantize_large_weights(self):
+        print("Quantize large weights")
+        net = self._model
+        for tensor in net.tensors:
+            self.quantize_large_tensor(tensor)
+
+        return False
+
     def add_quantize_info(self, op, minval, maxval):
         scale, zero, minval, maxval = \
             quantize_util.adjust_range(minval, maxval, self._option.device,
diff --git a/tools/common.py b/tools/common.py
index 5137327bcddc56129de4bda3bcb8d1e647ca2503..a7a3cfdb882c662f25aa6006295b585ed655424c 100644
--- a/tools/common.py
+++ b/tools/common.py
@@ -132,6 +132,9 @@ class DeviceType(object):
     HTA = 'HTA'
     APU = 'APU'
 
+    # for validation threshold
+    QUANTIZE = 'QUANTIZE'
+
 
 class DataFormat(object):
     NONE = "NONE"
@@ -408,6 +411,7 @@ class YAMLKeyword(object):
     obfuscate = 'obfuscate'
     winograd = 'winograd'
     quantize = 'quantize'
+    quantize_large_weights = 'quantize_large_weights'
     quantize_range_file = 'quantize_range_file'
     change_concat_ranges = 'change_concat_ranges'
     validation_inputs_data = 'validation_inputs_data'
diff --git a/tools/converter.py b/tools/converter.py
index 4c1ca1c808f6078a5594923e688e99b8cbae4bac..74cb28b882d86f3464b17210729e6574a293b258 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -118,8 +118,7 @@ class DefaultValues(object):
 class ValidationThreshold(object):
     cpu_threshold = 0.999,
     gpu_threshold = 0.995,
-    hexagon_threshold = 0.930,
-    cpu_quantize_threshold = 0.980,
+    quantize_threshold = 0.980,
 
 
 CPP_KEYWORDS = [
@@ -501,12 +500,9 @@ def format_model_config(flags):
         threshold_dict = {
             DeviceType.CPU: ValidationThreshold.cpu_threshold,
             DeviceType.GPU: ValidationThreshold.gpu_threshold,
-            DeviceType.HEXAGON + "_QUANTIZE":
-                ValidationThreshold.hexagon_threshold,
-            DeviceType.HTA + "_QUANTIZE":
-                ValidationThreshold.hexagon_threshold,
-            DeviceType.CPU + "_QUANTIZE":
-                ValidationThreshold.cpu_quantize_threshold,
+            DeviceType.HEXAGON: ValidationThreshold.quantize_threshold,
+            DeviceType.HTA: ValidationThreshold.quantize_threshold,
+            DeviceType.QUANTIZE: ValidationThreshold.quantize_threshold,
         }
         for k, v in six.iteritems(validation_threshold):
             if k.upper() == 'DSP':
@@ -515,7 +511,7 @@ def format_model_config(flags):
                                  DeviceType.GPU,
                                  DeviceType.HEXAGON,
                                  DeviceType.HTA,
-                                 DeviceType.CPU + "_QUANTIZE"):
+                                 DeviceType.QUANTIZE):
                 raise argparse.ArgumentTypeError(
                     'Unsupported validation threshold runtime: %s' % k)
             threshold_dict[k.upper()] = v
@@ -566,11 +562,18 @@ def format_model_config(flags):
                     YAMLKeyword.obfuscate,
                     YAMLKeyword.winograd,
                     YAMLKeyword.quantize,
+                    YAMLKeyword.quantize_large_weights,
                     YAMLKeyword.change_concat_ranges]:
             value = model_config.get(key, "")
             if value == "":
                 model_config[key] = 0
 
+        mace_check(model_config[YAMLKeyword.quantize] == 0 or
+                   model_config[YAMLKeyword.quantize_large_weights] == 0,
+                   ModuleName.YAML_CONFIG,
+                   "quantize and quantize_large_weights should not be set to 1"
+                   " at the same time.")
+
         mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                    ModuleName.YAML_CONFIG,
                    "'winograd' parameters must be in "
@@ -773,6 +776,7 @@ def convert_model(configs, cl_mem_type):
             embed_model_data,
             model_config[YAMLKeyword.winograd],
             model_config[YAMLKeyword.quantize],
+            model_config[YAMLKeyword.quantize_large_weights],
             quantize_range_file_path,
             model_config[YAMLKeyword.change_concat_ranges],
             model_config[YAMLKeyword.obfuscate],
diff --git a/tools/device.py b/tools/device.py
index 152ee4a28d0a9a9809c8cc49b0ab1cfc22603151..34b4adbb0311c1f2a4616c40982e25c07de9d63c 100644
--- a/tools/device.py
+++ b/tools/device.py
@@ -730,8 +730,11 @@ class DeviceWrapper:
                     model_config[
                         YAMLKeyword.weight_sha256_checksum])
                 validate_type = device_type
-                if model_config[YAMLKeyword.quantize] == 1:
-                    validate_type = device_type + '_QUANTIZE'
+                if device_type in [DeviceType.CPU,
+                                   DeviceType.GPU] and \
+                        (model_config[YAMLKeyword.quantize] == 1 or
+                         model_config[YAMLKeyword.quantize_large_weights] == 1):  # noqa
+                    validate_type = DeviceType.QUANTIZE
 
                 dockerfile_path, docker_image_tag = \
                     get_dockerfile_info(
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index 1380a4921dbd8dff3a968d61f092effc1d153724..172b8d44fb9c17914e719bec9fc811ac9820b3aa 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -499,6 +499,7 @@ def gen_model_code(model_codegen_dir,
                    embed_model_data,
                    winograd,
                    quantize,
+                   quantize_large_weights,
                    quantize_range_file,
                    change_concat_ranges,
                    obfuscate,
@@ -537,6 +538,7 @@ def gen_model_code(model_codegen_dir,
         "--embed_model_data=%s" % embed_model_data,
         "--winograd=%s" % winograd,
         "--quantize=%s" % quantize,
+        "--quantize_large_weights=%s" % quantize_large_weights,
         "--quantize_range_file=%s" % quantize_range_file,
         "--change_concat_ranges=%s" % change_concat_ranges,
         "--obfuscate=%s" % obfuscate,
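
Both documentation options above are driven from the model deployment file. The fragment below is a minimal sketch of where they would go, assuming the usual MACE deployment YAML layout; the model names are placeholders and the other required keys are elided. Only ``data_type: fp16_fp32``, ``quantize_large_weights: 1``, and the rule that ``quantize`` and ``quantize_large_weights`` must not both be 1 come from the patch and documentation above:

    # Hypothetical deployment-file fragment; model names are placeholders and the
    # usual required keys (platform, model_file_path, checksums, subgraphs) are omitted.
    models:
      some_model_fp16:
        # ... platform, model_file_path, subgraphs, etc. ...
        runtime: cpu+gpu
        data_type: fp16_fp32           # store weights in half precision, run inference in float

      some_model_small_weights:
        # ... platform, model_file_path, subgraphs, etc. ...
        runtime: cpu+gpu
        quantize: 0                    # must not be 1 together with quantize_large_weights
        quantize_large_weights: 1      # Conv2D/FullyConnected weights stored as uint8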