未验证 提交 358b76fa 编写于 作者: C ceci3 提交者: GitHub

change threshold for ptq hpo (#1253)

上级 a256a9b6
......@@ -77,8 +77,8 @@ EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
MAGIC_SPARSE_RATIO = 0.75
### TODO: 0.02 threshold maybe not suitable, need to check
### NOTE: reduce magic data to choose quantization aware training.
MAGIC_MAX_EMD_DISTANCE = 0.0002 #0.02
MAGIC_MIN_EMD_DISTANCE = 0.0001 #0.01
MAGIC_MAX_EMD_DISTANCE = 0.00002 #0.02
MAGIC_MIN_EMD_DISTANCE = 0.00001 #0.01
DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
DEFAULT_STRATEGY = 'origin_int8'
......
......@@ -144,7 +144,12 @@ def standardization(data):
"""standardization numpy array"""
mu = np.mean(data, axis=0)
sigma = np.std(data, axis=0)
sigma = 1e-13 if sigma == 0. else sigma
if isinstance(sigma, list) or isinstance(sigma, np.ndarray):
for idx, sig in enumerate(sigma):
if sig == 0.:
sigma[idx] = 1e-13
else:
sigma = 1e-13 if sigma == 0. else sigma
return (data - mu) / sigma
......@@ -241,18 +246,15 @@ def eval_quant_model():
if have_invalid_num(out_float) or have_invalid_num(out_quant):
continue
try:
out_float = standardization(out_float)
out_quant = standardization(out_quant)
except:
continue
out_float_list.append(out_float)
out_quant_list.append(out_quant)
out_float_list.append(list(out_float))
out_quant_list.append(list(out_quant))
valid_data_num += 1
if valid_data_num >= max_eval_data_num:
break
out_float_list = standardization(out_float_list)
out_quant_list = standardization(out_quant_list)
emd_sum = cal_emd_lose(out_float_list, out_quant_list,
out_len_sum / float(valid_data_num))
_logger.info("output diff: {}".format(emd_sum))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册