未验证 提交 22720a15 编写于 作者: C cc 提交者: GitHub

Fix post quant save bug, test=develop (#25370)

上级 ea7e5325
......@@ -562,8 +562,9 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel()
save_path = os.path.join(self._cache_dir,
var_name + "_" + str(iter) + ".npy")
save_path = os.path.join(
self._cache_dir,
var_name.replace("/", ".") + "_" + str(iter) + ".npy")
np.save(save_path, var_tensor)
else:
for var_name in self._quantized_act_var_name:
......@@ -598,7 +599,7 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name:
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name + '_[0-9]+.npy', f)]
if re.match(var_name.replace("/", ".") + '_[0-9]+.npy', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册