From dbdaa3893b3ee52fe00f025ca74d8b425f635c14 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 1 Jul 2022 10:42:37 +0800 Subject: [PATCH] update readme (#1217) --- demo/auto_compression/README.md | 4 ++-- paddleslim/auto_compression/compressor.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/demo/auto_compression/README.md b/demo/auto_compression/README.md index 4b8c6876..38f07794 100644 --- a/demo/auto_compression/README.md +++ b/demo/auto_compression/README.md @@ -33,7 +33,7 @@ import paddle from PIL import Image from paddle.vision.datasets import DatasetFolder from paddle.vision.transforms import transforms -from paddleslim.auto_compression import AutoCompression +from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization paddle.enable_static() # 定义DataSet class ImageNetDataset(DatasetFolder): @@ -65,7 +65,7 @@ ac = AutoCompression( model_filename="inference.pdmodel", params_filename="inference.pdiparams", save_dir="output", - config=None, + config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)}, train_dataloader=train_loader, eval_dataloader=train_loader) # eval_function to verify accuracy ac.compress() diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index b39d67fa..b4293545 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -137,8 +137,14 @@ class AutoCompression: # load config if isinstance(config, str): config = load_config(config) - self.strategy_config = extract_strategy_config(config) - self.train_config = extract_train_config(config) + self.strategy_config = extract_strategy_config(config) + self.train_config = extract_train_config(config) + elif isinstance(config, dict): + if 'TrainConfig' in config: + self.train_config = config.pop('TrainConfig') + else: + self.train_config = None + self.strategy_config = config # prepare dataloader self.feed_vars = get_feed_vars(self.model_dir, model_filename, -- GitLab