未验证 提交 52fba48d 编写于 作者: W whs 提交者: GitHub

Refine config API in ACT (#1178)

1. Make config of ACT support path, set, list and tuple
2. Refine the implementation of loading config
3. Add wrapper for dataloader to return data in dict format
上级 4528c8da
......@@ -129,7 +129,8 @@ def eval():
def main():
global global_config
_, _, global_config = load_slim_config(FLAGS.config_path)
all_config = load_slim_config(FLAGS.config_path)
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
dataset = reader_cfg['EvalDataset']
......
......@@ -128,8 +128,9 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def main():
global global_config
compress_config, train_config, global_config = load_slim_config(
FLAGS.config_path)
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
......@@ -167,11 +168,9 @@ def main():
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir,
strategy_config=compress_config,
train_config=train_config,
config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func)
ac.compress()
......
......@@ -100,7 +100,6 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
compress_config, train_config, _ = load_config(args.config_path)
data_dir = args.data_dir
if args.image_reader_type == 'paddle':
......@@ -121,8 +120,7 @@ if __name__ == '__main__':
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
config=args.config_path,
train_dataloader=train_dataloader,
eval_callback=eval_function,
eval_dataloader=reader_wrapper(
......
......@@ -242,7 +242,6 @@ if __name__ == '__main__':
print_arguments(args)
paddle.enable_static()
compress_config, train_config, _ = load_config(args.config_path)
if train_config is not None:
train_config.optimizer_builder[
'apply_decay_param_fun'] = apply_decay_param_fun
......@@ -256,8 +255,7 @@ if __name__ == '__main__':
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
config=args.config_path,
train_dataloader=train_dataloader,
eval_callback=eval_function if compress_config is None or
'HyperParameterOptimization' not in compress_config else
......
......@@ -114,7 +114,8 @@ python run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_auto.yaml' \
--strategy_config='configs/pp_humanseg/pp_humanseg_auto.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml' \
--deploy_hardware='SD710'
# 多卡启动
......@@ -124,7 +125,8 @@ python -m paddle.distributed.launch run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_auto.yaml' \
--strategy_config='configs/pp_humanseg/pp_humanseg_auto.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml' \
--deploy_hardware='SD710'
```
- 自行配置稀疏参数进行非结构化稀疏和蒸馏训练,配置参数含义详见[自动压缩超参文档](https://github.com/PaddlePaddle/PaddleSlim/blob/27dafe1c722476f1b16879f7045e9215b6f37559/demo/auto_compression/hyperparameter_tutorial.md)。具体命令如下所示:
......@@ -136,7 +138,8 @@ python run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_sparse.yaml'
--strategy_config='configs/pp_humanseg/pp_humanseg_sparse.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml'
# 多卡启动
export CUDA_VISIBLE_DEVICES=0,1
......@@ -145,7 +148,8 @@ python -m paddle.distributed.launch run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_sparse.yaml'
--strategy_config='configs/pp_humanseg/pp_humanseg_sparse.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml'
```
- 自行配置量化参数进行量化和蒸馏训练,配置参数含义详见[自动压缩超参文档](https://github.com/PaddlePaddle/PaddleSlim/blob/27dafe1c722476f1b16879f7045e9215b6f37559/demo/auto_compression/hyperparameter_tutorial.md)。具体命令如下所示:
......@@ -157,7 +161,8 @@ python run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_qat.yaml'
--strategy_config='configs/pp_humanseg/pp_humanseg_qat.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml'
# 多卡启动
export CUDA_VISIBLE_DEVICES=0,1
......@@ -166,7 +171,8 @@ python -m paddle.distributed.launch run.py \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/pp_humanseg/pp_humanseg_qat.yaml'
--strategy_config='configs/pp_humanseg/pp_humanseg_qat.yaml' \
--dataset_config='configs/dataset/humanseg_dataset.yaml'
```
压缩完成后会在`save_dir`中产出压缩好的预测模型,可直接预测部署。
......
Global:
reader_config: configs/dataset/pp_humanseg_dataset.yaml
TrainConfig:
epochs: 14
eval_iter: 400
......
Global:
reader_config: configs/dataset/pp_humanseg_lite.yaml
Distillation:
alpha: 1.0
loss: l2
......
Global:
reader_config: configs/dataset/pp_humanseg_lite.yaml
Distillation:
alpha: 1.0
loss: l2
......
Global:
reader_config: configs/dataset/cityscapes_1024x512_scale1.0.yml
TrainConfig:
epochs: 14
eval_iter: 90
......
Global:
reader_config: configs/dataset/cityscapes_1024x512_scale1.0.yml
Distillation:
alpha: 1.0
loss: l2
......
Global:
reader_config: configs/dataset/cityscapes_1024x512_scale1.0.yml
Distillation:
alpha: 1.0
loss: l2
......
......@@ -3,9 +3,9 @@ import argparse
import random
import paddle
import numpy as np
from paddleseg.cvlibs import Config
from paddleseg.cvlibs import Config as PaddleSegDataConfig
from paddleseg.utils import worker_init_fn
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression import AutoCompression
from paddleseg.core.infer import reverse_transform
from paddleseg.utils import metrics
......@@ -34,10 +34,15 @@ def parse_args():
default=None,
help="directory to save compressed model.")
parser.add_argument(
'--config_path',
'--strategy_config',
type=str,
default=None,
help="path of compression strategy config.")
parser.add_argument(
'--dataset_config',
type=str,
default=None,
help="path of dataset config.")
parser.add_argument(
'--deploy_hardware',
type=str,
......@@ -148,13 +153,15 @@ if __name__ == '__main__':
args = parse_args()
compress_config, train_config, global_config = load_config(args.config_path)
cfg = Config(global_config['reader_config'])
train_dataset = cfg.train_dataset
eval_dataset = cfg.val_dataset
# step1: load dataset config and create dataloader
data_cfg = PaddleSegDataConfig(args.dataset_config)
train_dataset = data_cfg.train_dataset
eval_dataset = data_cfg.val_dataset
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
train_dataset,
batch_size=data_cfg.batch_size,
shuffle=True,
drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -163,16 +170,16 @@ if __name__ == '__main__':
worker_init_fn=worker_init_fn)
train_dataloader = reader_wrapper(train_loader)
# set auto_compression
    # step2: create an instance of AutoCompression
ac = AutoCompression(
model_dir=args.model_dir,
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
config=args.strategy_config,
train_dataloader=train_dataloader,
eval_callback=eval_function,
deploy_hardware=args.deploy_hardware)
# step3: start the compression job
ac.compress()
......@@ -18,6 +18,7 @@ import platform
from ..common import get_logger
from .utils.predict import predict_compressed_model, with_variable_shape
from .strategy_config import *
from paddleslim.analysis import TableLatencyPredictor
_logger = get_logger(__name__, level=logging.INFO)
......
......@@ -22,6 +22,7 @@ import shutil
from time import gmtime, strftime
import platform
import paddle
import itertools
import paddle.distributed.fleet as fleet
from ..quant.quanter import convert, quant_post
from ..common.recover_program import recover_inference_program
......@@ -31,7 +32,9 @@ from ..analysis import TableLatencyPredictor
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
from .strategy_config import TrainConfig, ProgramInfo, merge_config
from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config
from .config_helpers import load_config, extract_strategy_config, extract_train_config
from .utils.predict import with_variable_shape
from .utils import get_feed_vars, wrap_dataloader
_logger = get_logger(__name__, level=logging.INFO)
......@@ -49,9 +52,8 @@ class AutoCompression:
params_filename,
save_dir,
train_dataloader,
config=None,
input_shapes=None,
train_config=None,
strategy_config=None,
target_speedup=None,
eval_callback=None,
eval_dataloader=None,
......@@ -127,16 +129,29 @@ class AutoCompression:
self.final_dir = save_dir
if not os.path.exists(self.final_dir):
os.makedirs(self.final_dir)
self.strategy_config = strategy_config
self.train_dataloader = train_dataloader
# load config
assert type(config) in [
dict, str, set, list, tuple
], f"The type of config should be in [dict, str, set, list, tuple] but got {type(config)}"
if isinstance(config, str):
config = load_config(config)
self.strategy_config = extract_strategy_config(config)
self.train_config = extract_train_config(config)
# prepare dataloader
self.feed_vars = get_feed_vars(model_dir, model_filename,
params_filename)
self.train_dataloader = wrap_dataloader(train_dataloader,
self.feed_vars)
self.eval_dataloader = wrap_dataloader(eval_dataloader, self.feed_vars)
if eval_dataloader is None:
eval_dataloader = self._get_eval_dataloader(self.train_dataloader)
self.target_speedup = target_speedup
self.eval_function = eval_callback
self.deploy_hardware = deploy_hardware
if eval_dataloader is None:
eval_dataloader = self._get_eval_dataloader(train_dataloader)
self.eval_dataloader = eval_dataloader
paddle.enable_static()
self._exe, self._places = self._prepare_envs()
self.model_type = self._get_model_type(self._exe, model_dir,
......@@ -158,6 +173,7 @@ class AutoCompression:
self.model_dir = infer_shape_model
self.model_filename = "infered_shape.pdmodel"
self.params_filename = "infered_shape.pdiparams"
if self.strategy_config is None:
strategy_config = prepare_strategy(
self._exe, self._places, self.model_dir, self.model_filename,
......
......@@ -12,39 +12,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from .strategy_config import *
import os
from paddleslim.auto_compression.strategy_config import *
__all__ = ['save_config', 'load_config']
def load_config(config_path):
def load_config(config):
"""Load configurations from yaml file into dict.
Fields validation is skipped for loading some custom information.
Args:
config(str): The path of configuration file.
Returns:
dict: A dict storing configuration information.
"""
convert yaml to dict config.
"""
f = open(config_path, 'r')
cfg = yaml.load(f, Loader=yaml.FullLoader)
f.close()
if config is None:
return None
assert isinstance(
config,
str), f"config should be str but got type(config)={type(config)}"
assert os.path.exists(config) and os.path.isfile(
config), f"{config} not found or it is not a file."
with open(config) as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
return cfg
global_config = {}
if 'Global' in cfg:
for g_key, g_value in cfg["Global"].items():
global_config[g_key] = g_value
cfg.pop('Global')
compress_config = {}
for key, value in cfg.items():
default_key = eval(key)(**value) if value is not None else eval(key)()
compress_config[key] = default_key
def extract_strategy_config(config):
"""Extract configuration items of strategies from file or dict.
    Field validation is enabled.
Args:
config(str, dict): The path of configuration file or a dict storing information about strategies.
Returns:
dict: The key is the name of strategy and the value is an instance of paddleslim.auto_compression.BaseStrategy.
"""
if config is None:
return None
if isinstance(config, str):
config = load_config(config)
if compress_config.get('TrainConfig') != None:
train_config = compress_config.pop('TrainConfig')
else:
train_config = None
compress_config = {}
if isinstance(config, dict):
for key, value in config.items():
if key in SUPPORTED_CONFIG:
compress_config[key] = eval(key)(**value) if isinstance(
value, dict) else eval(key)()
elif type(config) in [set, list, tuple]:
for key in config:
assert isinstance(key, str)
if key in SUPPORTED_CONFIG:
compress_config[key] = eval(key)()
if len(compress_config) == 0:
compress_config = None
return compress_config
return compress_config, train_config, global_config
def extract_train_config(config):
    """Extract the training configuration from a file or a dict.

    Field validation is enabled.

    Args:
        config(str, dict): The path of a configuration file or a dict
            storing information about training.

    Returns:
        An instance of paddleslim.auto_compression.TrainConfig.
    """
    if config is None:
        return None
    if isinstance(config, str):
        config = load_config(config)
    if isinstance(config, dict) and TRAIN_CONFIG_NAME in config:
        value = config[TRAIN_CONFIG_NAME]
        return TrainConfig(**value) if value is not None else TrainConfig()
    # Fall back to a default training config when none is given.
    return TrainConfig()
def save_config(config, config_path):
......
......@@ -45,12 +45,13 @@ def _create_lr_scheduler(train_config):
def _create_optimizer(train_config):
"""create optimizer"""
if 'optimizer_builder' not in train_config:
train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}
optimizer_builder = train_config['optimizer_builder']
assert isinstance(
optimizer_builder, dict
), f"Value of 'optimizer_builder' in train_config should be dict but got {type(optimizer_builder)}"
if 'grad_clip' in optimizer_builder:
g_clip_params = optimizer_builder['grad_clip']
g_clip_type = g_clip_params.pop('type')
......@@ -423,11 +424,13 @@ def build_prune_program(executor,
from ..prune import Pruner
pruner = Pruner(config["criterion"])
params = []
original_shapes = {}
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block().all_parameters():
if config['prune_params_name'] is not None and param.name in config[
'prune_params_name']:
params.append(param.name)
original_shapes[param.name] = param.shape
pruned_program, _, _ = pruner.prune(
train_program_info.program,
......@@ -435,6 +438,15 @@ def build_prune_program(executor,
params=params,
ratios=[config['pruned_ratio']] * len(params),
place=place)
_logger.info(
"####################channel pruning##########################")
for param in pruned_program.global_block().all_parameters():
if param.name in original_shapes:
_logger.info(
f"{param.name}, from {original_shapes[param.name]} to {param.shape}"
)
_logger.info(
"####################channel pruning end##########################")
train_program_info.program = pruned_program
elif strategy.startswith('asp'):
......
......@@ -15,13 +15,43 @@
from collections import namedtuple
__all__ = [
"Quantization", "Distillation", "MultiTeacherDistillation", \
"HyperParameterOptimization", "ChannelPrune", "UnstructurePrune", \
"TransformerPrune", "ASPPrune", "merge_config", "ProgramInfo", "TrainConfig",
"BaseStrategy",
"Quantization",
"Distillation",
"MultiTeacherDistillation",
"HyperParameterOptimization",
"ChannelPrune",
"UnstructurePrune",
"TransformerPrune",
"ASPPrune",
"merge_config",
"ProgramInfo",
"TrainConfig",
"SUPPORTED_CONFIG",
"TRAIN_CONFIG_NAME",
]
SUPPORTED_CONFIG = [
"Quantization",
"Distillation",
"MultiTeacherDistillation",
"HyperParameterOptimization",
"ChannelPrune",
"UnstructurePrune",
"TransformerPrune",
"ASPPrune",
"TrainConfig",
]
TRAIN_CONFIG_NAME = "TrainConfig"
class BaseStrategy:
    """Base class of all compression strategy configurations.

    Args:
        name(str): The name of the strategy, e.g. "Quantization".
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Helpful when strategy configs are logged or inspected.
        return f"{type(self).__name__}(name={self.name!r})"
class Quantization:
class Quantization(BaseStrategy):
def __init__(self,
quantize_op_types=[
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
......@@ -53,6 +83,7 @@ class Quantization:
for_tensorrt(bool): If True, 'quantize_op_types' will be TENSORRT_OP_TYPES. Default: False.
is_full_quantize(bool): If True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES. Default: False.
"""
super(Quantization, self).__init__("Quantization")
self.quantize_op_types = quantize_op_types
self.weight_bits = weight_bits
self.activation_bits = activation_bits
......@@ -68,7 +99,7 @@ class Quantization:
self.is_full_quantize = is_full_quantize
class Distillation:
class Distillation(BaseStrategy):
def __init__(self,
loss='l2',
node=[],
......@@ -86,6 +117,7 @@ class Distillation:
teacher_model_filename(str, optional): The name of teacher model file. If parameters are saved in separate files, set it as 'None'. Default: 'None'.
teacher_params_filename(str, optional): The name of teacher params file. When all parameters are saved in a single file, set it as filename. If parameters are saved in separate files, set it as 'None'. Default : 'None'.
"""
super(Distillation, self).__init__("Distillation")
self.loss = loss
self.node = node
self.alpha = alpha
......@@ -120,7 +152,7 @@ class MultiTeacherDistillation:
self.teacher_params_filename = teacher_params_filename
class HyperParameterOptimization:
class HyperParameterOptimization(BaseStrategy):
def __init__(self,
ptq_algo=["KL", "hist", "avg", "mse"],
bias_correct=[True, False],
......@@ -138,6 +170,7 @@ class HyperParameterOptimization:
batch_num(list(int)): The upper and lower bounds of batch number, the real batch number is uniform sampling in this bounds.
max_quant_count(int): Max number of model quantization. Default: 20.
"""
super(HyperParameterOptimization, self).__init__("HPO_PTQ")
self.ptq_algo = ptq_algo
self.bias_correct = bias_correct
self.weight_quantize_type = weight_quantize_type
......@@ -224,7 +257,9 @@ class TrainConfig:
epochs=None,
train_iter=None,
learning_rate=0.02,
optimizer_builder={'optimizer': 'SGD'},
optimizer_builder={'optimizer': {
'type': 'SGD'
}},
eval_iter=1000,
logging_iter=10,
origin_metric=None,
......
......@@ -14,5 +14,8 @@
from __future__ import absolute_import
from .predict import predict_compressed_model
from .dataloader import *
from . import dataloader
__all__ = ["predict_compressed_model"]
__all__ += dataloader.__all__
import os
import time
import numpy as np
import paddle
from collections.abc import Iterable
__all__ = ["wrap_dataloader", "get_feed_vars"]
def get_feed_vars(model_dir, model_filename, params_filename):
    """Return the names of the feed (input) variables of an inference model.

    Args:
        model_dir(str): Directory that stores the inference model.
        model_filename(str): Name of the model file.
        params_filename(str): Name of the parameters file.

    Returns:
        list[str]: The feed variable names of the model.
    """
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    # Only the feed target names are needed; the loaded program and the
    # fetch targets are deliberately discarded.
    [_, feed_target_names, _] = (
        paddle.static.load_inference_model(
            model_dir,
            exe,
            model_filename=model_filename,
            params_filename=params_filename))
    return feed_target_names
def wrap_dataloader(dataloader, names):
    """Wrap a dataloader so that it yields data in dict format.

    If the dataloader already yields dicts, it is returned unchanged.
    Otherwise a generator function is returned whose items are dicts
    mapping each entry of `names` to the corresponding field of the
    batch, converted to a numpy array.

    Args:
        dataloader(paddle.io.DataLoader): The dataloader to be wrapped.
        names(list[str]): Keys of the yielded dicts; usually the feed
            variable names of the model.

    Returns:
        The original dataloader, or a generator function yielding dicts.
    """
    if dataloader is None:
        return dataloader
    assert isinstance(dataloader, paddle.io.DataLoader)
    assert len(dataloader) > 0
    # Peek at the first batch to decide whether wrapping is needed.
    sample = next(dataloader())
    if isinstance(sample, dict):
        return dataloader
    if isinstance(sample, Iterable):
        assert len(sample) == len(
            names
        ), f"len(data) == len(names), but got len(data): {len(sample)} and len(names): {len(names)}"
    else:
        assert len(
            names
        ) == 1, f"The length of names should be 1 when data is not Iterable but got {len(names)}"

    def gen():
        for batch in dataloader():
            if not isinstance(batch, Iterable):
                batch = [batch]
            yield dict((name_, np.array(field_))
                       for name_, field_ in zip(names, batch))

    return gen
# For unittests
Quantization:
quantize_op_types:
- conv2d
- depthwise_conv2d
Distillation:
alpha: 1.0
loss: l2
TrainConfig:
epochs: 1
eval_iter: 1070
learning_rate: 2.0e-5
optimizer_builder:
optimizer:
type: AdamW
weight_decay: 0.01
origin_metric: 0.7403
import sys
import os
sys.path.append("../")
import unittest
import tempfile
import paddle
import unittest
import numpy as np
from static_case import StaticCase
from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression.config_helpers import load_config
class RandomEvalDataset(Dataset):
    """A dataset that yields randomly generated float32 images."""

    def __init__(self, num_samples, image_shape=[3, 32, 32], class_num=10):
        self.num_samples = num_samples
        self.image_shape = image_shape
        self.class_num = class_num

    def __getitem__(self, idx):
        # No labels are needed for these tests; only an image is returned.
        return np.random.random(self.image_shape).astype('float32')

    def __len__(self):
        return self.num_samples
class ACTBase(unittest.TestCase):
    """Shared fixture: exports a tiny static-graph inference model into a
    temporary directory and prepares a random evaluation dataset."""

    def __init__(self, *args, **kwargs):
        super(ACTBase, self).__init__(*args, **kwargs)
        paddle.enable_static()
        self.tmpdir = tempfile.TemporaryDirectory(prefix="test_")
        self.infer_model_dir = os.path.join(self.tmpdir.name, "infer")
        self.create_program()
        self.create_dataloader()

    def create_program(self):
        # Build a two-conv static program and save it as an inference model.
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='data', shape=[-1, 3, 32, 32], dtype='float32')
            hidden = paddle.static.nn.conv2d(
                input=image, num_filters=2, filter_size=3)
            prediction = paddle.static.nn.conv2d(
                input=hidden, num_filters=2, filter_size=3)
        executor = paddle.static.Executor(paddle.CPUPlace())
        executor.run(startup_prog)
        paddle.static.save_inference_model(
            self.infer_model_dir, [image], [prediction],
            executor,
            program=main_prog)
        print(f"saved infer model to [{self.infer_model_dir}]")

    def create_dataloader(self):
        # Random samples are enough to exercise the compression pipeline.
        self.eval_dataset = RandomEvalDataset(32)

    def __del__(self):
        self.tmpdir.cleanup()
class TestYamlQATDistTrain(ACTBase):
    """Drive AutoCompression with a YAML file path as `config`."""

    def __init__(self, *args, **kwargs):
        super(TestYamlQATDistTrain, self).__init__(*args, **kwargs)

    def test_compress(self):
        image = paddle.static.data(
            name='data', shape=[-1, 3, 32, 32], dtype='float32')
        loader = paddle.io.DataLoader(
            self.eval_dataset, feed_list=[image], batch_size=4)
        # The same loader serves for training and evaluation here;
        # an eval_function would be needed to verify accuracy.
        compressor = AutoCompression(
            model_dir=self.tmpdir.name,
            model_filename="infer.pdmodel",
            params_filename="infer.pdiparams",
            save_dir="output",
            config="./qat_dist_train.yaml",
            train_dataloader=loader,
            eval_dataloader=loader)
        compressor.compress()
class TestSetQATDist(ACTBase):
    """Drive AutoCompression with a set of strategy names as `config`."""

    def __init__(self, *args, **kwargs):
        super(TestSetQATDist, self).__init__(*args, **kwargs)

    def test_compress(self):
        image = paddle.static.data(
            name='data', shape=[-1, 3, 32, 32], dtype='float32')
        loader = paddle.io.DataLoader(
            self.eval_dataset, feed_list=[image], batch_size=4)
        # The same loader serves for training and evaluation here;
        # an eval_function would be needed to verify accuracy.
        compressor = AutoCompression(
            model_dir=self.tmpdir.name,
            model_filename="infer.pdmodel",
            params_filename="infer.pdiparams",
            save_dir="output",
            config={"QAT", "Distillation"},
            train_dataloader=loader,
            eval_dataloader=loader)
        compressor.compress()
class TestDictQATDist(ACTBase):
    """Drive AutoCompression with a dict (loaded from YAML) as `config`."""

    def __init__(self, *args, **kwargs):
        super(TestDictQATDist, self).__init__(*args, **kwargs)

    def test_compress(self):
        cfg = load_config("./qat_dist_train.yaml")
        image = paddle.static.data(
            name='data', shape=[-1, 3, 32, 32], dtype='float32')
        loader = paddle.io.DataLoader(
            self.eval_dataset, feed_list=[image], batch_size=4)
        # The same loader serves for training and evaluation here;
        # an eval_function would be needed to verify accuracy.
        compressor = AutoCompression(
            model_dir=self.tmpdir.name,
            model_filename="infer.pdmodel",
            params_filename="infer.pdiparams",
            save_dir="output",
            config=cfg,
            train_dataloader=loader,
            eval_dataloader=loader)
        compressor.compress()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册