校准数据集超过200张图片时卡死电脑的问题
Created by: zyxcambridge
314行
def _calculate_scale_factor(self):
    '''
    Calculate the scale factor of quantized variables.

    Weights get one abs-max scale per slice along axis 0 (channel-wise
    abs_max). Each activation gets a single scale: the value returned by
    ``self._get_kl_scaling_factor`` when ``self._algo == "KL"``, otherwise
    the max absolute value over all of its sampled data. Results are
    written to ``self._quantized_var_scale_factor``; nothing is returned.

    When ``self._is_use_cache_file`` is true, activation samples are read
    from ``<var_name>_<digits>.npy`` files in ``self._cache_dir`` and each
    file is deleted as soon as it has been loaded. Otherwise the samples
    are taken from ``self._sampling_data[var_name]``, which is replaced
    in place by its concatenation.

    Raises:
        ValueError: if an activation variable has no sampled data.
    '''
    # apply channel_wise_abs_max quantization for weights
    for var_name in self._quantized_weight_var_name:
        data = self._sampling_data[var_name]
        self._quantized_var_scale_factor[var_name] = [
            np.max(np.abs(data[i])) for i in range(data.shape[0])
        ]

    use_kl = self._algo == "KL"

    # apply kl / abs_max quantization for activation (in-memory samples)
    if not self._is_use_cache_file:
        for var_name in self._quantized_act_var_name:
            self._sampling_data[var_name] = np.concatenate(
                self._sampling_data[var_name])
            abs_data = np.abs(self._sampling_data[var_name])
            if use_kl:
                scale = self._get_kl_scaling_factor(abs_data)
            else:
                scale = np.max(abs_data)
            self._quantized_var_scale_factor[var_name] = scale
        return

    # apply kl / abs_max quantization for activation (samples cached on disk)
    for var_name in self._quantized_act_var_name:
        # Escape the variable name and anchor the suffix so that only this
        # variable's own cache files are matched (and later deleted).
        pattern = re.compile(re.escape(var_name) + r'_[0-9]+\.npy$')
        file_paths = [
            os.path.join(self._cache_dir, f)
            for f in os.listdir(self._cache_dir)
            if pattern.match(f)
        ]
        if not file_paths:
            raise ValueError(
                "no cached sampling data found for variable %r in %r" %
                (var_name, self._cache_dir))

        if use_kl:
            # NOTE(review): the helper takes the whole sample set as one
            # array, so this path still holds every sample in memory at
            # once. Lowering its peak memory would require changing
            # _get_kl_scaling_factor, which is outside this block.
            chunks = []
            for file_path in file_paths:
                chunks.append(np.load(file_path))
                os.remove(file_path)
            scale = self._get_kl_scaling_factor(
                np.abs(np.concatenate(chunks)))
        else:
            # The global abs-max equals the max of the per-file abs-max
            # values, so only one cached array is resident at a time.
            scale = None
            for file_path in file_paths:
                chunk = np.load(file_path)
                os.remove(file_path)
                if chunk.size == 0:
                    continue
                chunk_max = np.max(np.abs(chunk))
                scale = chunk_max if scale is None \
                    else np.maximum(scale, chunk_max)
            if scale is None:
                raise ValueError(
                    "cached sampling data for variable %r is empty" %
                    var_name)
        self._quantized_var_scale_factor[var_name] = scale
校准图片超过200张后,对 yolov3 做离线量化时,第三层的 featuremap(也就是 sampling_data)总数据量会达到100多个G,会直接卡死电脑。