From 9121d1fb298128859e5b37db35e0d8aadc4da0ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Fri, 31 Aug 2018 11:25:28 +0800
Subject: [PATCH] Fix dequantize output type

---
 mace/core/workspace.cc                             | 14 ++++++++------
 mace/kernels/softmax.h                             |  2 +-
 mace/python/tools/converter_tool/base_converter.py |  2 ++
 mace/python/tools/converter_tool/transformer.py    |  4 ++++
 tools/converter.py                                 | 10 ++++++++--
 5 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index c7401fcd..170070cd 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -178,12 +178,14 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
   if (type == DeviceType::CPU && net_def.has_quantize_info()) {
     for (const auto
              &activation_info: net_def.quantize_info().activation_info()) {
-      MACE_CHECK(HasTensor(activation_info.tensor_name()),
-                 "Quantize info exist for non-existed tensor",
-                 activation_info.tensor_name());
-      Tensor *tensor = GetTensor(activation_info.tensor_name());
-      tensor->SetScale(activation_info.scale());
-      tensor->SetZeroPoint(activation_info.zero_point());
+      if (HasTensor(activation_info.tensor_name())) {
+        Tensor *tensor = GetTensor(activation_info.tensor_name());
+        tensor->SetScale(activation_info.scale());
+        tensor->SetZeroPoint(activation_info.zero_point());
+      } else {
+        LOG(WARNING) << "Quantize info exists for non-existent tensor: "
+                     << activation_info.tensor_name();
+      }
     }
   }
 
diff --git a/mace/kernels/softmax.h b/mace/kernels/softmax.h
index 62e089c5..5de3ade1 100644
--- a/mace/kernels/softmax.h
+++ b/mace/kernels/softmax.h
@@ -123,7 +123,7 @@ struct SoftmaxFunctor {
   }
 };
 
-static const int kInputDeltaIntBits = 5;
+static const int kInputDeltaIntBits = 6;
 static const int kSumExpIntBits = 12;
 
 template<>
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 2d61d3f9..04bf576c 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -368,6 +368,8 @@ class ConverterOption(object):
             TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
             # Transform finalization
             TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
+            # for quantization entropy calibration use
+            TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
             TransformerRule.SORT_BY_EXECUTION,
         ]
         if self._quantize:
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index ba882177..d4b9576c 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -1638,6 +1638,7 @@ class Transformer(base_converter.ConverterInterface):
             output_shape = op_def.output_shape.add()
             output_shape.dims.extend(
                 self._producer[output_node.name].output_shape[0].dims)
+            op_def.output_type.extend([mace_pb2.DT_FLOAT])
 
             ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
 
@@ -1647,6 +1648,9 @@ class Transformer(base_converter.ConverterInterface):
         print("Add quantize tensor range")
         net = self._model
         range_file = self._option.quantize_range_file
+        if not range_file:
+            return
+
         with open(range_file) as f:
             for line in f:
                 tensor_name, minmax = line.split("@@")
diff --git a/tools/converter.py b/tools/converter.py
index 4eb6405c..ebe87405 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -481,13 +481,15 @@ def format_model_config(flags):
                 DeviceType.CPU: 0.999,
                 DeviceType.GPU: 0.995,
                 DeviceType.HEXAGON: 0.930,
+                DeviceType.CPU + "_QUANTIZE": 0.980,
             }
             for k, v in six.iteritems(validation_threshold):
                 if k.upper() == 'DSP':
                     k = DeviceType.HEXAGON
                 if k.upper() not in (DeviceType.CPU, DeviceType.GPU,
-                                     DeviceType.HEXAGON):
+                                     DeviceType.HEXAGON,
+                                     DeviceType.CPU + "_QUANTIZE"):
                     raise argparse.ArgumentTypeError(
                         'Unsupported validation threshold runtime: %s' % k)
                 threshold_dict[k.upper()] = v
 
@@ -1251,6 +1253,10 @@ def run_specific_target(flags, configs, target_abi,
                     model_config[YAMLKeyword.weight_file_path],
                     model_config[YAMLKeyword.weight_sha256_checksum])
 
+                validate_type = device_type
+                if model_config[YAMLKeyword.quantize] == 1:
+                    validate_type = device_type + "_QUANTIZE"
+
                 sh_commands.validate_model(
                     abi=target_abi,
                     serialno=serial_num,
@@ -1266,7 +1272,7 @@ def run_specific_target(flags, configs, target_abi,
                     phone_data_dir=PHONE_DATA_DIR,
                     input_data_types=subgraphs[0][YAMLKeyword.input_data_types],  # noqa
                     caffe_env=flags.caffe_env,
-                    validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type])  # noqa
+                    validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][validate_type])  # noqa
                 if flags.report and flags.round > 0:
                     tuned = is_tuned and device_type == DeviceType.GPU
                     report_run_statistics(
-- 
GitLab
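
Background for reviewers: the transformer.py hunk above makes the Dequantize op
declare a DT_FLOAT output while its data-type argument stays DT_UINT8, which is
exactly the affine (scale/zero-point) mapping that SetScale/SetZeroPoint feed in
workspace.cc. A minimal standalone sketch of that mapping in NumPy, not MACE's
actual kernel code; the function names and the sample scale/zero_point values
here are illustrative only:

    import numpy as np

    def dequantize(q, scale, zero_point):
        # real = scale * (quantized - zero_point): uint8 in, float32 out,
        # which is why the Dequantize op's output_type must be DT_FLOAT.
        return scale * (q.astype(np.float32) - zero_point)

    def quantize(x, scale, zero_point):
        # Inverse mapping, rounded and saturated to the uint8 range.
        q = np.round(x / scale) + zero_point
        return np.clip(q, 0, 255).astype(np.uint8)

    # Round-trip check with an assumed scale/zero_point pair.
    scale, zero_point = 0.02, 128
    x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    assert np.allclose(
        dequantize(quantize(x, scale, zero_point), scale, zero_point),
        x, atol=scale)

Relatedly, the quantize_range_file consumed by the new ADD_QUANTIZE_TENSOR_RANGE
rule holds one record per line, a tensor name and its min/max range separated by
"@@", which is what the line.split("@@") in the second transformer.py hunk parses.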