未验证 提交 dbdaa389 编写于 作者: C ceci3 提交者: GitHub

update readme (#1217)

上级 aa336f5f
...@@ -33,7 +33,7 @@ import paddle ...@@ -33,7 +33,7 @@ import paddle
from PIL import Image from PIL import Image
from paddle.vision.datasets import DatasetFolder from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms from paddle.vision.transforms import transforms
from paddleslim.auto_compression import AutoCompression from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
paddle.enable_static() paddle.enable_static()
# 定义DataSet # 定义DataSet
class ImageNetDataset(DatasetFolder): class ImageNetDataset(DatasetFolder):
...@@ -65,7 +65,7 @@ ac = AutoCompression( ...@@ -65,7 +65,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel", model_filename="inference.pdmodel",
params_filename="inference.pdiparams", params_filename="inference.pdiparams",
save_dir="output", save_dir="output",
config=None, config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)},
train_dataloader=train_loader, train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress() ac.compress()
......
...@@ -137,8 +137,14 @@ class AutoCompression: ...@@ -137,8 +137,14 @@ class AutoCompression:
# load config # load config
if isinstance(config, str): if isinstance(config, str):
config = load_config(config) config = load_config(config)
self.strategy_config = extract_strategy_config(config) self.strategy_config = extract_strategy_config(config)
self.train_config = extract_train_config(config) self.train_config = extract_train_config(config)
elif isinstance(config, dict):
if 'TrainConfig' in config:
self.train_config = config.pop('TrainConfig')
else:
self.train_config = None
self.strategy_config = config
# prepare dataloader # prepare dataloader
self.feed_vars = get_feed_vars(self.model_dir, model_filename, self.feed_vars = get_feed_vars(self.model_dir, model_filename,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册