From 88b23d32de1ddc2f080c8749935476720b963cf4 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 5 Jul 2022 11:09:55 +0800 Subject: [PATCH] fix smac output_dir (#1237) --- example/auto_compression/README.md | 4 ++-- paddleslim/quant/post_quant_hpo.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/example/auto_compression/README.md b/example/auto_compression/README.md index bce9764e..b761402f 100644 --- a/example/auto_compression/README.md +++ b/example/auto_compression/README.md @@ -33,7 +33,7 @@ import paddle from PIL import Image from paddle.vision.datasets import DatasetFolder from paddle.vision.transforms import transforms -from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization +from paddleslim.auto_compression import AutoCompression paddle.enable_static() # 定义DataSet class ImageNetDataset(DatasetFolder): @@ -65,7 +65,7 @@ ac = AutoCompression( model_filename="inference.pdmodel", params_filename="inference.pdiparams", save_dir="output", - config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)}, + config={'Quantization': {}, "HyperParameterOptimization": {'max_quant_count': 5}}, train_dataloader=train_loader, eval_dataloader=train_loader) # eval_function to verify accuracy ac.compress() diff --git a/paddleslim/quant/post_quant_hpo.py b/paddleslim/quant/post_quant_hpo.py index e92742d0..9f9275f8 100755 --- a/paddleslim/quant/post_quant_hpo.py +++ b/paddleslim/quant/post_quant_hpo.py @@ -17,6 +17,7 @@ import os import sys import math import time +from time import gmtime, strftime import numpy as np import shutil import paddle @@ -40,7 +41,7 @@ from paddleslim.quant import quant_post _logger = get_logger(__name__, level=logging.INFO) -SMAC_TMP_FILE_PATTERN = "smac3-output*" +SMAC_TMP_FILE_PATTERN = "smac3-output_" def remove(path): @@ -496,6 +497,9 @@ def quant_post_hpo( cs.add_hyperparameters(hyper_params) + s_datetime = strftime("%Y-%m-%d-%H:%M:%S", gmtime()) + smac_output_dir = SMAC_TMP_FILE_PATTERN + s_datetime + scenario = Scenario({ "run_obj": "quality", # we optimize quality (alternative runtime) "runcount-limit": @@ -503,7 +507,9 @@ def quant_post_hpo( "cs": cs, # configuration space "deterministic": "True", "limit_resources": "False", - "memory_limit": 4096 # adapt this to reasonable value for your hardware + "memory_limit": + 4096, # adapt this to reasonable value for your hardware + "output_dir": smac_output_dir # output_dir }) # To optimize, we pass the function to the SMAC-object smac = SMAC4HPO( @@ -523,5 +529,5 @@ def quant_post_hpo( inc_value = smac.get_tae_runner().run(incumbent, 1)[1] _logger.info("Optimized Value: %.8f" % inc_value) shutil.rmtree(g_quant_model_cache_path) - remove(SMAC_TMP_FILE_PATTERN) + remove(smac_output_dir) _logger.info("Quantization completed.") -- GitLab