未验证 提交 ce482365 编写于 作者: W whs 提交者: GitHub

Refine temp directory of ACT (#1145)

1. fix temp work dir not found error
2. move temp work dir to final directory formatted by pid and datetime
3. remove temp work dir of ACT when job is done
4. remove temp work dir of quant_post_hpo when job is done
上级 285f3a71
...@@ -20,6 +20,7 @@ import inspect ...@@ -20,6 +20,7 @@ import inspect
import shutil import shutil
from collections import namedtuple from collections import namedtuple
from collections.abc import Iterable from collections.abc import Iterable
from time import gmtime, strftime
import platform import platform
import paddle import paddle
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
...@@ -67,7 +68,8 @@ class AutoCompression: ...@@ -67,7 +68,8 @@ class AutoCompression:
When all parameters are saved in a single file, set it When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files, as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'. set it as 'None'. Default : 'None'.
save_dir(str): The path to save compressed model. save_dir(str): The path to save compressed model. The models in this directory will be overwritten
after calling 'compress()' function.
train_data_loader(Python Generator, Paddle.io.DataLoader): The train_data_loader(Python Generator, Paddle.io.DataLoader): The
Generator or Dataloader provides train data, and it could Generator or Dataloader provides train data, and it could
return a batch every time. return a batch every time.
...@@ -108,11 +110,9 @@ class AutoCompression: ...@@ -108,11 +110,9 @@ class AutoCompression:
if params_filename == 'None': if params_filename == 'None':
params_filename = None params_filename = None
self.params_filename = params_filename self.params_filename = params_filename
base_path = os.path.basename(os.path.normpath(save_dir))
parent_path = os.path.abspath(os.path.join(save_dir, os.pardir))
base_path = base_path + '_temp'
self.save_dir = os.path.join(parent_path, base_path)
self.final_dir = save_dir self.final_dir = save_dir
if not os.path.exists(self.final_dir):
os.makedirs(self.final_dir)
self.strategy_config = strategy_config self.strategy_config = strategy_config
self.train_config = train_config self.train_config = train_config
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
...@@ -355,6 +355,13 @@ class AutoCompression: ...@@ -355,6 +355,13 @@ class AutoCompression:
return program_info return program_info
def compress(self): def compress(self):
# create a new temp directory in final dir
s_datetime = strftime("%Y-%m-%d-%H:%M:%S", gmtime())
tmp_base_name = "_".join(["tmp", str(os.getpid()), s_datetime])
self.tmp_dir = os.path.join(self.final_dir, tmp_base_name)
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
for strategy_idx, ( for strategy_idx, (
strategy, strategy,
config) in enumerate(zip(self._strategy, self._config)): config) in enumerate(zip(self._strategy, self._config)):
...@@ -371,7 +378,7 @@ class AutoCompression: ...@@ -371,7 +378,7 @@ class AutoCompression:
self.single_strategy_compress(quant_strategy[0], self.single_strategy_compress(quant_strategy[0],
quant_config[0], strategy_idx) quant_config[0], strategy_idx)
tmp_model_path = os.path.join( tmp_model_path = os.path.join(
self.save_dir, 'strategy_{}'.format(str(strategy_idx + 1))) self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
final_model_path = os.path.join(self.final_dir) final_model_path = os.path.join(self.final_dir)
if not os.path.exists(final_model_path): if not os.path.exists(final_model_path):
os.makedirs(final_model_path) os.makedirs(final_model_path)
...@@ -382,6 +389,7 @@ class AutoCompression: ...@@ -382,6 +389,7 @@ class AutoCompression:
if paddle.distributed.get_rank() == 0: if paddle.distributed.get_rank() == 0:
shutil.move(tmp_model_file, final_model_file) shutil.move(tmp_model_file, final_model_file)
shutil.move(tmp_params_file, final_params_file) shutil.move(tmp_params_file, final_params_file)
shutil.rmtree(self.tmp_dir)
_logger.info( _logger.info(
"==> Finished the ACT process and the final model is saved in:{}". "==> Finished the ACT process and the final model is saved in:{}".
format(final_model_path)) format(final_model_path))
...@@ -395,7 +403,7 @@ class AutoCompression: ...@@ -395,7 +403,7 @@ class AutoCompression:
self._exe, self._exe,
model_dir=self.model_dir, model_dir=self.model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.save_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader, data_loader=self.train_dataloader,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
...@@ -425,7 +433,7 @@ class AutoCompression: ...@@ -425,7 +433,7 @@ class AutoCompression:
self._places, self._places,
model_dir=self.model_dir, model_dir=self.model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.save_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
train_dataloader=self.train_dataloader, train_dataloader=self.train_dataloader,
eval_dataloader=self.eval_dataloader, eval_dataloader=self.eval_dataloader,
eval_function=self.eval_function, eval_function=self.eval_function,
...@@ -452,7 +460,7 @@ class AutoCompression: ...@@ -452,7 +460,7 @@ class AutoCompression:
model_dir = self.model_dir model_dir = self.model_dir
else: else:
model_dir = os.path.join( model_dir = os.path.join(
self.save_dir, 'strategy_{}'.format(str(strategy_idx))) self.tmp_dir, 'strategy_{}'.format(str(strategy_idx)))
[inference_program, feed_target_names, fetch_targets]= paddle.fluid.io.load_inference_model( \ [inference_program, feed_target_names, fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=model_dir, \ dirname=model_dir, \
...@@ -528,7 +536,7 @@ class AutoCompression: ...@@ -528,7 +536,7 @@ class AutoCompression:
if metric > best_metric: if metric > best_metric:
paddle.static.save( paddle.static.save(
program=test_program_info.program._program, program=test_program_info.program._program,
model_path=os.path.join(self.save_dir, model_path=os.path.join(self.tmp_dir,
'best_model')) 'best_model'))
best_metric = metric best_metric = metric
if self.metric_before_compressed is not None and float( if self.metric_before_compressed is not None and float(
...@@ -555,12 +563,12 @@ class AutoCompression: ...@@ -555,12 +563,12 @@ class AutoCompression:
test_program_info.program, test_program_info.program,
paddle.static.CompiledProgram) else test_program_info.program paddle.static.CompiledProgram) else test_program_info.program
if os.path.exists(os.path.join(self.save_dir, 'best_model.pdparams')): if os.path.exists(os.path.join(self.tmp_dir, 'best_model.pdparams')):
paddle.static.load(test_program, paddle.static.load(test_program,
os.path.join(self.save_dir, 'best_model')) os.path.join(self.tmp_dir, 'best_model'))
os.remove(os.path.join(self.save_dir, 'best_model.pdmodel')) os.remove(os.path.join(self.tmp_dir, 'best_model.pdmodel'))
os.remove(os.path.join(self.save_dir, 'best_model.pdopt')) os.remove(os.path.join(self.tmp_dir, 'best_model.pdopt'))
os.remove(os.path.join(self.save_dir, 'best_model.pdparams')) os.remove(os.path.join(self.tmp_dir, 'best_model.pdparams'))
if 'qat' in strategy: if 'qat' in strategy:
float_program, int8_program = convert(test_program_info.program._program, self._places, self._quant_config, \ float_program, int8_program = convert(test_program_info.program._program, self._places, self._quant_config, \
...@@ -568,7 +576,7 @@ class AutoCompression: ...@@ -568,7 +576,7 @@ class AutoCompression:
save_int8=True) save_int8=True)
test_program_info.program = float_program test_program_info.program = float_program
model_dir = os.path.join(self.save_dir, model_dir = os.path.join(self.tmp_dir,
'strategy_{}'.format(str(strategy_idx + 1))) 'strategy_{}'.format(str(strategy_idx + 1)))
if not os.path.exists(model_dir): if not os.path.exists(model_dir):
os.makedirs(model_dir) os.makedirs(model_dir)
......
...@@ -24,6 +24,8 @@ import paddle.fluid as fluid ...@@ -24,6 +24,8 @@ import paddle.fluid as fluid
import logging import logging
import argparse import argparse
import functools import functools
import shutil
import glob
from scipy.stats import wasserstein_distance from scipy.stats import wasserstein_distance
# smac # smac
...@@ -36,6 +38,20 @@ from smac.scenario.scenario import Scenario ...@@ -36,6 +38,20 @@ from smac.scenario.scenario import Scenario
from paddleslim.common import get_logger from paddleslim.common import get_logger
from paddleslim.quant import quant_post from paddleslim.quant import quant_post
SMAC_TMP_FILE_PATTERN = "smac3-output*"
def remove(path):
    """Remove every file and directory matched by a glob pattern.

    Args:
        path(str): Shell-style glob pattern (as interpreted by ``glob.glob``,
            not a regular expression) selecting the files and directories
            to delete. A pattern that matches nothing is a no-op.
    """
    for p in glob.glob(path):
        # os.path.isdir follows symlinks, but shutil.rmtree refuses to run on
        # a symbolic link. A link to a directory is therefore unlinked like a
        # regular file, which also leaves the link's target untouched.
        if os.path.isdir(p) and not os.path.islink(p):
            shutil.rmtree(p)
        else:
            os.remove(p)
class QuantConfig: class QuantConfig:
"""quant config""" """quant config"""
...@@ -472,7 +488,7 @@ def quant_post_hpo( ...@@ -472,7 +488,7 @@ def quant_post_hpo(
bias_correction=bias_correct, \ bias_correction=bias_correct, \
batch_size=batch_size, \ batch_size=batch_size, \
batch_nums=batch_num) batch_nums=batch_num)
shutil.rmtree(g_quant_model_cache_path)
return return
cs.add_hyperparameters(hyper_params) cs.add_hyperparameters(hyper_params)
...@@ -486,7 +502,6 @@ def quant_post_hpo( ...@@ -486,7 +502,6 @@ def quant_post_hpo(
"limit_resources": "False", "limit_resources": "False",
"memory_limit": 4096 # adapt this to reasonable value for your hardware "memory_limit": 4096 # adapt this to reasonable value for your hardware
}) })
# To optimize, we pass the function to the SMAC-object # To optimize, we pass the function to the SMAC-object
smac = SMAC4HPO( smac = SMAC4HPO(
scenario=scenario, rng=np.random.RandomState(42), tae_runner=quantize) scenario=scenario, rng=np.random.RandomState(42), tae_runner=quantize)
...@@ -504,4 +519,6 @@ def quant_post_hpo( ...@@ -504,4 +519,6 @@ def quant_post_hpo(
inc_value = smac.get_tae_runner().run(incumbent, 1)[1] inc_value = smac.get_tae_runner().run(incumbent, 1)[1]
print("Optimized Value: %.8f" % inc_value) print("Optimized Value: %.8f" % inc_value)
shutil.rmtree(g_quant_model_cache_path)
remove(SMAC_TMP_FILE_PATTERN)
print("quantize completed") print("quantize completed")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册