未验证 提交 611da7fc 编写于 作者: C Chang Xu 提交者: GitHub

Update PostQuantTraining zero size (#49868)

上级 d9d47dc6
...@@ -789,7 +789,7 @@ class PostTrainingQuantization: ...@@ -789,7 +789,7 @@ class PostTrainingQuantization:
_logger.info("MSE searching stage ...") _logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
var_tensor = var_tensor.flatten() var_tensor = var_tensor.flatten()
...@@ -843,7 +843,7 @@ class PostTrainingQuantization: ...@@ -843,7 +843,7 @@ class PostTrainingQuantization:
_logger.info("EMD searching stage ...") _logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
var_tensor = var_tensor.flatten() var_tensor = var_tensor.flatten()
...@@ -899,7 +899,7 @@ class PostTrainingQuantization: ...@@ -899,7 +899,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = float(np.max(np.abs(var_tensor)))
...@@ -940,7 +940,7 @@ class PostTrainingQuantization: ...@@ -940,7 +940,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = float(np.max(np.abs(var_tensor)))
...@@ -975,7 +975,7 @@ class PostTrainingQuantization: ...@@ -975,7 +975,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
min_value = float(np.min(var_tensor)) min_value = float(np.min(var_tensor))
...@@ -992,7 +992,7 @@ class PostTrainingQuantization: ...@@ -992,7 +992,7 @@ class PostTrainingQuantization:
def _sample_histogram(self): def _sample_histogram(self):
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if (not var_tensor.any()) or ( if (var_tensor.size == 0) or (
var_name not in self._sampling_act_histogram var_name not in self._sampling_act_histogram
): ):
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
...@@ -1031,7 +1031,7 @@ class PostTrainingQuantization: ...@@ -1031,7 +1031,7 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = float(np.max(np.abs(var_tensor)))
...@@ -1094,7 +1094,7 @@ class PostTrainingQuantization: ...@@ -1094,7 +1094,7 @@ class PostTrainingQuantization:
''' '''
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any(): if var_tensor.size == 0:
self._zero_size_var_names.add(var_name) self._zero_size_var_names.add(var_name)
continue continue
var_tensor = np.abs(var_tensor) var_tensor = np.abs(var_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册