diff --git a/demo/auto_compression/README.md b/demo/auto_compression/README.md index 4b8c68766b2c69157fa41e5fe3473a6cda7bd3a7..38f07794e6603b0bdec70729c0c47e1e9608cf2e 100644 --- a/demo/auto_compression/README.md +++ b/demo/auto_compression/README.md @@ -33,7 +33,7 @@ import paddle from PIL import Image from paddle.vision.datasets import DatasetFolder from paddle.vision.transforms import transforms -from paddleslim.auto_compression import AutoCompression +from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization paddle.enable_static() # 定义DataSet class ImageNetDataset(DatasetFolder): @@ -65,7 +65,7 @@ ac = AutoCompression( model_filename="inference.pdmodel", params_filename="inference.pdiparams", save_dir="output", - config=None, + config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)}, train_dataloader=train_loader, eval_dataloader=train_loader) # eval_function to verify accuracy ac.compress() diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index b39d67faef70d361080b3b9583fd3d45dc391572..b429354583e1ba4d78830c6257580d14ee56f2be 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -137,8 +137,14 @@ class AutoCompression: # load config if isinstance(config, str): config = load_config(config) - self.strategy_config = extract_strategy_config(config) - self.train_config = extract_train_config(config) + self.strategy_config = extract_strategy_config(config) + self.train_config = extract_train_config(config) + elif isinstance(config, dict): + if 'TrainConfig' in config: + self.train_config = config.pop('TrainConfig') + else: + self.train_config = None + self.strategy_config = config # prepare dataloader self.feed_vars = get_feed_vars(self.model_dir, model_filename,