未验证 提交 dbdaa389 编写于 作者: C ceci3 提交者: GitHub

update readme (#1217)

上级 aa336f5f
......@@ -33,7 +33,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
paddle.enable_static()
# 定义DataSet
class ImageNetDataset(DatasetFolder):
......@@ -65,7 +65,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="output",
config=None,
config={'Quantization': Quantization(), "HyperParameterOptimization": HyperParameterOptimization(max_quant_count=5)},
train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress()
......
......@@ -137,8 +137,14 @@ class AutoCompression:
# load config
if isinstance(config, str):
config = load_config(config)
self.strategy_config = extract_strategy_config(config)
self.train_config = extract_train_config(config)
self.strategy_config = extract_strategy_config(config)
self.train_config = extract_train_config(config)
elif isinstance(config, dict):
if 'TrainConfig' in config:
self.train_config = config.pop('TrainConfig')
else:
self.train_config = None
self.strategy_config = config
# prepare dataloader
self.feed_vars = get_feed_vars(self.model_dir, model_filename,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册