提交 a0c1f222 编写于 作者: 李寅

"Quantization enhancement:

Bring gain of top-1 accuracy (69.54% -> 70.20%)
上级 a74002cb
class QuantizeStat(object):
    """Compute per-tensor quantization ranges from a range-statistics log.

    Each tensor's observed (min, max) samples are reduced to a single
    (min, max) pair via percentile trimming, optionally tightened further
    by the greedy outlier-trimming "enhance" heuristic.
    """

    def __init__(self):
        pass

    @staticmethod
    def run(log_file, percentile, enhance, enhance_ratio):
        """Return ``{tensor_name: (min, max)}`` quantization ranges.

        Args:
            log_file: path to a text log holding one observed range
                sample per line for each tensor.
            percentile: percentile (0-100) used to trim the sampled mins
                (at ``percentile``) and maxs (at ``100 - percentile``).
            enhance: when True and a tensor has more than one sample,
                greedily shrink the range past outlier samples.
            enhance_ratio: acceptance threshold on
                (fraction of range removed) / (fraction of samples
                sacrificed); larger values trim less aggressively.

        Raises:
            AssertionError: if a tensor's trimmed min is not strictly
                below its trimmed max.
        """
        res = {}
        tensor_ranges = {}
        with open(log_file) as log:
            for line in log:
                # NOTE(review): the log-parsing lines were not visible in
                # the reviewed diff; reconstructed assuming lines of the
                # form "...@@<tensor_name>@@<min>,<max>", mirroring the
                # "%s@@%f,%f" format this tool itself prints. Confirm
                # against the producer of the range log.
                if "@@" not in line:
                    continue
                tensor_name, minmax = line.strip().split("@@")[-2:]
                min_val, max_val = [float(v) for v in minmax.split(",")]
                if tensor_name not in tensor_ranges:
                    tensor_ranges[tensor_name] = ([], [])
                tensor_ranges[tensor_name][0].append(min_val)
                tensor_ranges[tensor_name][1].append(max_val)

        for tensor_name in tensor_ranges:
            samples = len(tensor_ranges[tensor_name][0])
            tensor_min = np.percentile(tensor_ranges[tensor_name][0],
                                       percentile)
            tensor_max = np.percentile(tensor_ranges[tensor_name][1],
                                       100 - percentile)
            assert tensor_min < tensor_max
            if not enhance or samples <= 1:
                res[tensor_name] = (tensor_min, tensor_max)
            else:
                # Greedy outlier trimming: walk the sorted mins upward
                # and the sorted maxs downward, accepting a tighter bound
                # whenever the fraction of range removed per fraction of
                # samples sacrificed exceeds enhance_ratio.
                tensor_mins = np.sort(tensor_ranges[tensor_name][0])
                tensor_maxs = np.sort(tensor_ranges[tensor_name][1])[::-1]
                cur_min_idx = 0
                cur_max_idx = 0
                cur_min = tensor_min
                cur_max = tensor_max
                # bug fix: xrange is Python-2-only; range works on both.
                for i in range(samples):
                    # Keep at least 0.1 of headroom so the range never
                    # collapses to (or past) a point.
                    if tensor_mins[i] + 0.1 > cur_max:
                        break
                    delta_range = (tensor_mins[i] - cur_min) / (cur_max - cur_min)  # noqa
                    delta_quantile = float(i - cur_min_idx) / (samples - cur_min_idx)  # noqa
                    if delta_quantile > 0 and delta_range / delta_quantile > enhance_ratio:  # noqa
                        cur_min_idx = i
                        cur_min = tensor_mins[i]

                    if cur_min + 0.1 > tensor_maxs[i]:
                        break
                    delta_range = (cur_max - tensor_maxs[i]) / (cur_max - cur_min)  # noqa
                    delta_quantile = float(i - cur_max_idx) / (samples - cur_max_idx)  # noqa
                    if delta_quantile > 0 and delta_range / delta_quantile > enhance_ratio:  # noqa
                        cur_max_idx = i
                        cur_max = tensor_maxs[i]

                res[tensor_name] = (cur_min, cur_max)

        return res
...@@ -42,10 +71,20 @@ if __name__ == '__main__': ...@@ -42,10 +71,20 @@ if __name__ == '__main__':
# CLI driver (inside the file's `if __name__ == '__main__':` section; the
# parser construction and --log_file option are above the visible region).
parser.add_argument(
    "--percentile",
    type=int,
    default=0,
    help="range percentile")
parser.add_argument(
    "--enhance",
    action="store_true",
    # bug fix: help text was a copy-paste of --percentile's
    # ("range percentile"); describe what the flag actually enables.
    help="enable greedy outlier trimming of quantization ranges")
parser.add_argument(
    "--enhance_ratio",
    type=int,
    default=10,
    help="enhance ratio")
FLAGS, unparsed = parser.parse_known_args()
res = QuantizeStat.run(FLAGS.log_file, FLAGS.percentile, FLAGS.enhance,
                       FLAGS.enhance_ratio)
# Emit one "name@@min,max" line per tensor for downstream consumers.
for tensor in res:
    print("%s@@%f,%f" % (tensor, res[tensor][0], res[tensor][1]))
...@@ -564,9 +564,11 @@ def get_model_files(model_file_path, ...@@ -564,9 +564,11 @@ def get_model_files(model_file_path,
model_sha256_checksum, model_sha256_checksum,
model_output_dir, model_output_dir,
weight_file_path="", weight_file_path="",
weight_sha256_checksum=""): weight_sha256_checksum="",
quantize_range_file_path=""):
model_file = model_file_path model_file = model_file_path
weight_file = weight_file_path weight_file = weight_file_path
quantize_range_file = quantize_range_file_path
if model_file_path.startswith("http://") or \ if model_file_path.startswith("http://") or \
model_file_path.startswith("https://"): model_file_path.startswith("https://"):
...@@ -598,7 +600,15 @@ def get_model_files(model_file_path, ...@@ -598,7 +600,15 @@ def get_model_files(model_file_path,
MaceLogger.error(ModuleName.MODEL_CONVERTER, MaceLogger.error(ModuleName.MODEL_CONVERTER,
"weight file sha256checksum not match") "weight file sha256checksum not match")
return model_file, weight_file if quantize_range_file_path.startswith("http://") or \
quantize_range_file_path.startswith("https://"):
quantize_range_file = \
model_output_dir + "/" + md5sum(quantize_range_file_path) \
+ ".range"
if not download_file(quantize_range_file_path, quantize_range_file):
MaceLogger.error(ModuleName.MODEL_CONVERTER,
"Model range file download failed.")
return model_file, weight_file, quantize_range_file
def convert_model(configs, cl_mem_type): def convert_model(configs, cl_mem_type):
...@@ -649,12 +659,14 @@ def convert_model(configs, cl_mem_type): ...@@ -649,12 +659,14 @@ def convert_model(configs, cl_mem_type):
else: else:
model_config[YAMLKeyword.cl_mem_type] = "image" model_config[YAMLKeyword.cl_mem_type] = "image"
model_file_path, weight_file_path = get_model_files( model_file_path, weight_file_path, quantize_range_file_path = \
model_config[YAMLKeyword.model_file_path], get_model_files(
model_config[YAMLKeyword.model_sha256_checksum], model_config[YAMLKeyword.model_file_path],
BUILD_DOWNLOADS_DIR, model_config[YAMLKeyword.model_sha256_checksum],
model_config[YAMLKeyword.weight_file_path], BUILD_DOWNLOADS_DIR,
model_config[YAMLKeyword.weight_sha256_checksum]) model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum],
model_config.get(YAMLKeyword.quantize_range_file, ""))
data_type = model_config[YAMLKeyword.data_type] data_type = model_config[YAMLKeyword.data_type]
# TODO(liuqi): support multiple subgraphs # TODO(liuqi): support multiple subgraphs
...@@ -683,7 +695,7 @@ def convert_model(configs, cl_mem_type): ...@@ -683,7 +695,7 @@ def convert_model(configs, cl_mem_type):
embed_model_data, embed_model_data,
model_config[YAMLKeyword.winograd], model_config[YAMLKeyword.winograd],
model_config[YAMLKeyword.quantize], model_config[YAMLKeyword.quantize],
model_config.get(YAMLKeyword.quantize_range_file, ""), quantize_range_file_path,
model_config[YAMLKeyword.change_concat_ranges], model_config[YAMLKeyword.change_concat_ranges],
model_config[YAMLKeyword.obfuscate], model_config[YAMLKeyword.obfuscate],
configs[YAMLKeyword.model_graph_format], configs[YAMLKeyword.model_graph_format],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册