提交 0654a658 编写于 作者: 李寅 提交者: 赵奇可

Fix dequantize output type

上级 95f63291
......@@ -178,12 +178,14 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
if (type == DeviceType::CPU && net_def.has_quantize_info()) {
for (const auto
&activation_info: net_def.quantize_info().activation_info()) {
MACE_CHECK(HasTensor(activation_info.tensor_name()),
"Quantize info exist for non-existed tensor",
activation_info.tensor_name());
Tensor *tensor = GetTensor(activation_info.tensor_name());
tensor->SetScale(activation_info.scale());
tensor->SetZeroPoint(activation_info.zero_point());
if (HasTensor(activation_info.tensor_name())) {
Tensor *tensor = GetTensor(activation_info.tensor_name());
tensor->SetScale(activation_info.scale());
tensor->SetZeroPoint(activation_info.zero_point());
} else {
LOG(WARNING) << "Quantize info exists for non-existed tensor: "
<< activation_info.tensor_name();
}
}
}
......
......@@ -123,7 +123,7 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
}
};
static const int kInputDeltaIntBits = 5;
static const int kInputDeltaIntBits = 6;
static const int kSumExpIntBits = 12;
template<>
......
......@@ -368,6 +368,8 @@ class ConverterOption(object):
TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
# Transform finalization
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
# for quantization entropy calibration use
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
TransformerRule.SORT_BY_EXECUTION,
]
if self._quantize:
......
......@@ -1638,6 +1638,7 @@ class Transformer(base_converter.ConverterInterface):
output_shape = op_def.output_shape.add()
output_shape.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
op_def.output_type.extend([mace_pb2.DT_FLOAT])
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
......@@ -1647,6 +1648,9 @@ class Transformer(base_converter.ConverterInterface):
print("Add quantize tensor range")
net = self._model
range_file = self._option.quantize_range_file
if not range_file:
return
with open(range_file) as f:
for line in f:
tensor_name, minmax = line.split("@@")
......
......@@ -481,13 +481,15 @@ def format_model_config(flags):
DeviceType.CPU: 0.999,
DeviceType.GPU: 0.995,
DeviceType.HEXAGON: 0.930,
DeviceType.CPU + "_QUANTIZE": 0.980,
}
for k, v in six.iteritems(validation_threshold):
if k.upper() == 'DSP':
k = DeviceType.HEXAGON
if k.upper() not in (DeviceType.CPU,
DeviceType.GPU,
DeviceType.HEXAGON):
DeviceType.HEXAGON,
DeviceType.CPU + "_QUANTIZE"):
raise argparse.ArgumentTypeError(
'Unsupported validation threshold runtime: %s' % k)
threshold_dict[k.upper()] = v
......@@ -1251,6 +1253,10 @@ def run_specific_target(flags, configs, target_abi,
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum])
validate_type = device_type
if model_config[YAMLKeyword.quantize] == 1:
validate_type = device_type + "_QUANTIZE"
sh_commands.validate_model(
abi=target_abi,
serialno=serial_num,
......@@ -1266,7 +1272,7 @@ def run_specific_target(flags, configs, target_abi,
phone_data_dir=PHONE_DATA_DIR,
input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa
caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][validate_type]) # noqa
if flags.report and flags.round > 0:
tuned = is_tuned and device_type == DeviceType.GPU
report_run_statistics(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册