未验证 提交 18ed3324 编写于 作者: C ceci3 提交者: GitHub

[cherry pick] fix smac output_dir (#1248)

上级 75f477ff
......@@ -33,7 +33,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
from paddleslim.auto_compression import AutoCompression
paddle.enable_static()
# 定义DataSet
class ImageNetDataset(DatasetFolder):
......@@ -65,7 +65,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="output",
config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)},
config={'Quantization': {}, "HyperParameterOptimization": {'max_quant_count': 5}},
train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress()
......
......@@ -17,6 +17,7 @@ import os
import sys
import math
import time
from time import gmtime, strftime
import numpy as np
import shutil
import paddle
......@@ -40,7 +41,7 @@ from paddleslim.quant import quant_post
_logger = get_logger(__name__, level=logging.INFO)
SMAC_TMP_FILE_PATTERN = "smac3-output*"
SMAC_TMP_FILE_PATTERN = "smac3-output_"
def remove(path):
......@@ -496,6 +497,9 @@ def quant_post_hpo(
cs.add_hyperparameters(hyper_params)
s_datetime = strftime("%Y-%m-%d-%H:%M:%S", gmtime())
smac_output_dir = SMAC_TMP_FILE_PATTERN + s_datetime
scenario = Scenario({
"run_obj": "quality", # we optimize quality (alternative runtime)
"runcount-limit":
......@@ -503,7 +507,9 @@ def quant_post_hpo(
"cs": cs, # configuration space
"deterministic": "True",
"limit_resources": "False",
"memory_limit": 4096 # adapt this to reasonable value for your hardware
"memory_limit":
4096, # adapt this to reasonable value for your hardware
"output_dir": smac_output_dir # output_dir
})
# To optimize, we pass the function to the SMAC-object
smac = SMAC4HPO(
......@@ -523,5 +529,5 @@ def quant_post_hpo(
inc_value = smac.get_tae_runner().run(incumbent, 1)[1]
_logger.info("Optimized Value: %.8f" % inc_value)
shutil.rmtree(g_quant_model_cache_path)
remove(SMAC_TMP_FILE_PATTERN)
remove(smac_output_dir)
_logger.info("Quantization completed.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册