未验证 提交 cfd49acc 编写于 作者: X XGZhang 提交者: GitHub

fix a quantization bug (#34647)

上级 4f4662b0
...@@ -578,6 +578,7 @@ class PostTrainingQuantization(object): ...@@ -578,6 +578,7 @@ class PostTrainingQuantization(object):
var_tensor = _load_variable_data(self._scope, var_name) var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten() var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3 s = 0.3
if var_name not in self._best_mse_loss: if var_name not in self._best_mse_loss:
self._best_mse_loss[var_name] = float('inf') self._best_mse_loss[var_name] = float('inf')
......
...@@ -1312,6 +1312,7 @@ class QuantizationFreezePass(object): ...@@ -1312,6 +1312,7 @@ class QuantizationFreezePass(object):
assert self._is_float( assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % ( scale_v), 'The scale of parameter %s is not a float.' % (
original_var_name) original_var_name)
scale_v = 1e-8 if scale_v == 0.0 else scale_v
max_range *= param_range / scale_v max_range *= param_range / scale_v
else: else:
max_range *= act_range max_range *= act_range
...@@ -1413,6 +1414,7 @@ class QuantizationFreezePass(object): ...@@ -1413,6 +1414,7 @@ class QuantizationFreezePass(object):
x[:, i] = _clip(x[:, i], s) x[:, i] = _clip(x[:, i], s)
x[:, i] = np.round(x[:, i] / s * bnt) x[:, i] = np.round(x[:, i] / s * bnt)
else: else:
scale = 1e-8 if scale == 0.0 else scale
x = _clip(x, scale) x = _clip(x, scale)
x = np.round(x / scale * bnt) x = np.round(x / scale * bnt)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册