未验证 提交 611da7fc 编写于 作者: C Chang Xu 提交者: GitHub

Update PostQuantTraining zero size (#49868)

上级 d9d47dc6
......@@ -789,7 +789,7 @@ class PostTrainingQuantization:
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
......@@ -843,7 +843,7 @@ class PostTrainingQuantization:
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
......@@ -899,7 +899,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
......@@ -940,7 +940,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
......@@ -975,7 +975,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
min_value = float(np.min(var_tensor))
......@@ -992,7 +992,7 @@ class PostTrainingQuantization:
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if (not var_tensor.any()) or (
if (var_tensor.size == 0) or (
var_name not in self._sampling_act_histogram
):
self._zero_size_var_names.add(var_name)
......@@ -1031,7 +1031,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
......@@ -1094,7 +1094,7 @@ class PostTrainingQuantization:
'''
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
if var_tensor.size == 0:
self._zero_size_var_names.add(var_name)
continue
var_tensor = np.abs(var_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册