Commit 0654a658 authored by 李寅, committed by 赵奇可

Fix dequantize output type

Parent 95f63291
@@ -178,12 +178,14 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
   if (type == DeviceType::CPU && net_def.has_quantize_info()) {
     for (const auto
         &activation_info: net_def.quantize_info().activation_info()) {
-      MACE_CHECK(HasTensor(activation_info.tensor_name()),
-                 "Quantize info exist for non-existed tensor",
-                 activation_info.tensor_name());
-      Tensor *tensor = GetTensor(activation_info.tensor_name());
-      tensor->SetScale(activation_info.scale());
-      tensor->SetZeroPoint(activation_info.zero_point());
+      if (HasTensor(activation_info.tensor_name())) {
+        Tensor *tensor = GetTensor(activation_info.tensor_name());
+        tensor->SetScale(activation_info.scale());
+        tensor->SetZeroPoint(activation_info.zero_point());
+      } else {
+        LOG(WARNING) << "Quantize info exists for non-existed tensor: "
+                     << activation_info.tensor_name();
+      }
     }
   }
...
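This hunk makes Workspace tolerate quantize info for tensors that no longer exist (e.g. optimized away), downgrading the hard MACE_CHECK to a warning. The scale/zero_point pair it stores parameterizes the standard affine (asymmetric) quantization mapping; a minimal NumPy sketch of that mapping and its inverse follows (illustrative only, not MACE's kernel code). Note the inverse is necessarily float-valued, which is what the DT_FLOAT fix later in this commit encodes on the converter side.

import numpy as np

def quantize(real, scale, zero_point):
    # Affine quantization: q = round(r / scale) + zero_point, clamped to uint8.
    q = np.round(real / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
    # Inverse mapping; the result is float regardless of the quantized dtype.
    return scale * (q.astype(np.float32) - zero_point)

x = np.array([-1.0, 0.0, 2.5], dtype=np.float32)
scale, zero_point = 0.05, 20          # example parameters, not from the commit
print(dequantize(quantize(x, scale, zero_point), scale, zero_point))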
@@ -123,7 +123,7 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
   }
 };

-static const int kInputDeltaIntBits = 5;
+static const int kInputDeltaIntBits = 6;
 static const int kSumExpIntBits = 12;

 template<>
...
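kInputDeltaIntBits sets the Q-format used for the quantized softmax's input deltas. Assuming the common signed fixed-point convention (1 sign bit, n integer bits, 31 - n fractional bits in a 32-bit word; an assumption about the kernel, not something shown in the diff), raising the constant from 5 to 6 doubles the representable range at the cost of one bit of resolution. A rough sketch of that trade-off:

def q_format_stats(int_bits, total_bits=32):
    # Signed fixed point: 1 sign bit, int_bits integer bits,
    # total_bits - 1 - int_bits fractional bits.
    frac_bits = total_bits - 1 - int_bits
    max_value = 2.0 ** int_bits - 2.0 ** -frac_bits
    resolution = 2.0 ** -frac_bits
    return max_value, resolution

for bits in (5, 6):
    max_v, res = q_format_stats(bits)
    print("int_bits=%d  range=+/-%.6f  step=%g" % (bits, max_v, res))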
@@ -368,6 +368,8 @@ class ConverterOption(object):
                 TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
                 # Transform finalization
                 TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
+                # for quantization entropy calibration use
+                TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
                 TransformerRule.SORT_BY_EXECUTION,
             ]
         if self._quantize:
...
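ADD_QUANTIZE_TENSOR_RANGE is registered just before SORT_BY_EXECUTION, so anything it adds still passes through the final topological sort. Transformer rules run in list order; a toy sketch of that pipeline pattern (the function names below are illustrative stand-ins, not MACE's Transformer methods):

# Each rule is one pass over the model; list order is execution order.
def add_quantize_tensor_range(model):
    # Hypothetical stand-in: a harmless no-op when no range file is
    # configured, mirroring the early return added in a later hunk.
    model.setdefault("quantize_ranges", {})
    return model

def sort_by_execution(model):
    model["sorted"] = True  # final pass: topological sort
    return model

model = {}
for rule in [add_quantize_tensor_range, sort_by_execution]:
    model = rule(model)
print(model)  # {'quantize_ranges': {}, 'sorted': True}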
@@ -1638,6 +1638,7 @@ class Transformer(base_converter.ConverterInterface):
             output_shape = op_def.output_shape.add()
             output_shape.dims.extend(
                 self._producer[output_node.name].output_shape[0].dims)
+            op_def.output_type.extend([mace_pb2.DT_FLOAT])
             ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
@@ -1647,6 +1648,9 @@ class Transformer(base_converter.ConverterInterface):
         print("Add quantize tensor range")
         net = self._model
         range_file = self._option.quantize_range_file
+        if not range_file:
+            return
+
         with open(range_file) as f:
             for line in f:
                 tensor_name, minmax = line.split("@@")
...
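The early return makes ADD_QUANTIZE_TENSOR_RANGE a no-op when no quantize range file is supplied, which is what lets it sit unconditionally in the default rule list added above. The parser implies a format of one tensor per line with "@@" separating the name from its recorded range. A hypothetical entry and parse, assuming the range half is comma-separated min,max (the diff truncates before that detail):

line = "conv1/Relu@@-0.12,5.87\n"   # hypothetical line, format assumed
tensor_name, minmax = line.split("@@")
min_val, max_val = [float(v) for v in minmax.split(",")]
print(tensor_name, min_val, max_val)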
@@ -481,13 +481,15 @@ def format_model_config(flags):
                 DeviceType.CPU: 0.999,
                 DeviceType.GPU: 0.995,
                 DeviceType.HEXAGON: 0.930,
+                DeviceType.CPU + "_QUANTIZE": 0.980,
             }
             for k, v in six.iteritems(validation_threshold):
                 if k.upper() == 'DSP':
                     k = DeviceType.HEXAGON
                 if k.upper() not in (DeviceType.CPU,
                                      DeviceType.GPU,
                                      DeviceType.HEXAGON,
-                                     DeviceType.HEXAGON):
+                                     DeviceType.CPU + "_QUANTIZE"):
                     raise argparse.ArgumentTypeError(
                         'Unsupported validation threshold runtime: %s' % k)
                 threshold_dict[k.upper()] = v
...
@@ -1251,6 +1253,10 @@ def run_specific_target(flags, configs, target_abi,
                     model_config[YAMLKeyword.weight_file_path],
                     model_config[YAMLKeyword.weight_sha256_checksum])
+                validate_type = device_type
+                if model_config[YAMLKeyword.quantize] == 1:
+                    validate_type = device_type + "_QUANTIZE"
+
                 sh_commands.validate_model(
                     abi=target_abi,
                     serialno=serial_num,
...
@@ -1266,7 +1272,7 @@ def run_specific_target(flags, configs, target_abi,
                     phone_data_dir=PHONE_DATA_DIR,
                     input_data_types=subgraphs[0][YAMLKeyword.input_data_types],  # noqa
                     caffe_env=flags.caffe_env,
-                    validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type])  # noqa
+                    validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][validate_type])  # noqa
             if flags.report and flags.round > 0:
                 tuned = is_tuned and device_type == DeviceType.GPU
                 report_run_statistics(
...
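Taken together, the last three hunks route quantized runs to a looser similarity threshold: the key is plain string concatenation, so with DeviceType.CPU being the uppercase string "CPU" (as the k.upper() membership test implies), a quantized CPU model validates against "CPU_QUANTIZE" at 0.980 instead of "CPU" at 0.999. A minimal sketch of the lookup:

validation_threshold = {
    "CPU": 0.999,
    "GPU": 0.995,
    "HEXAGON": 0.930,
    "CPU_QUANTIZE": 0.980,  # looser bound: quantization adds numeric error
}

def pick_threshold(device_type, quantize):
    validate_type = device_type + "_QUANTIZE" if quantize else device_type
    return validation_threshold[validate_type]

print(pick_threshold("CPU", quantize=True))   # 0.98
print(pick_threshold("CPU", quantize=False))  # 0.999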