未验证 提交 22720a15 编写于 作者: C cc 提交者: GitHub

Fix post quant save bug, test=develop (#25370)

上级 ea7e5325
...@@ -562,8 +562,9 @@ class PostTrainingQuantization(object): ...@@ -562,8 +562,9 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name) var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel() var_tensor = var_tensor.ravel()
save_path = os.path.join(self._cache_dir, save_path = os.path.join(
var_name + "_" + str(iter) + ".npy") self._cache_dir,
var_name.replace("/", ".") + "_" + str(iter) + ".npy")
np.save(save_path, var_tensor) np.save(save_path, var_tensor)
else: else:
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
...@@ -598,7 +599,7 @@ class PostTrainingQuantization(object): ...@@ -598,7 +599,7 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
sampling_data = [] sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \ filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name + '_[0-9]+.npy', f)] if re.match(var_name.replace("/", ".") + '_[0-9]+.npy', f)]
for filename in filenames: for filename in filenames:
file_path = os.path.join(self._cache_dir, filename) file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path)) sampling_data.append(np.load(file_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册