From 22720a1535be0b699a84a826218dc6f45e7a353b Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 7 Jul 2020 11:14:06 +0800 Subject: [PATCH] Fix post quant save bug, test=develop (#25370) --- .../slim/quantization/post_training_quantization.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index a0b28243d2f..6dcbf9e1e27 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -562,8 +562,9 @@ class PostTrainingQuantization(object): for var_name in self._quantized_act_var_name: var_tensor = _load_variable_data(self._scope, var_name) var_tensor = var_tensor.ravel() - save_path = os.path.join(self._cache_dir, - var_name + "_" + str(iter) + ".npy") + save_path = os.path.join( + self._cache_dir, + var_name.replace("/", ".") + "_" + str(iter) + ".npy") np.save(save_path, var_tensor) else: for var_name in self._quantized_act_var_name: @@ -598,7 +599,7 @@ class PostTrainingQuantization(object): for var_name in self._quantized_act_var_name: sampling_data = [] filenames = [f for f in os.listdir(self._cache_dir) \ - if re.match(var_name + '_[0-9]+.npy', f)] + if re.match(var_name.replace("/", ".") + '_[0-9]+.npy', f)] for filename in filenames: file_path = os.path.join(self._cache_dir, filename) sampling_data.append(np.load(file_path)) -- GitLab