Unverified commit 9b358cab, authored by itminner, committed by GitHub

add quant post hyper params search (#939)

* add rv1109 quant aware training

* quant post hyper param search

* quant post hyper param search
Co-authored-by: itminner <397809320@example.com>
Co-authored-by: whs <wanghaoshuang@baidu.com>
Co-authored-by: ceci3 <ceci3@users.noreply.github.com>
Parent 3fde095b
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
import argparse
import functools
sys.path[0] = os.path.join(
os.path.dirname(__file__), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import quant_post_hpo
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./inference_model/MobileNet/", "directory of the float inference model to quantize")
add_arg('save_path', str, "./quant_model/MobileNet/", "directory to save the quantized model")
add_arg('model_filename', str, None, "name of the model file")
add_arg('params_filename', str, None, "name of the params file")
add_arg('max_model_quant_count', int, 30, "maximum number of quantization trials")
def quantize(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
assert os.path.exists(args.model_path), "args.model_path doesn't exist"
assert os.path.isdir(args.model_path), "args.model_path must be a dir"
def reader_generator(imagenet_reader):
def gen():
for i, data in enumerate(imagenet_reader()):
image, label = data
image = np.expand_dims(image, axis=0)
yield image
return gen
exe = paddle.static.Executor(place)
quant_post_hpo(
exe,
place,
args.model_path,
args.save_path,
train_sample_generator=reader_generator(reader.train()),
eval_sample_generator=reader_generator(reader.val()),
model_filename=args.model_filename,
params_filename=args.params_filename,
save_model_filename='__model__',
save_params_filename='__params__',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type='channel_wise_abs_max',
runcount_limit=args.max_model_quant_count)
def main():
args = parser.parse_args()
print_arguments(args)
quantize(args)
if __name__ == '__main__':
paddle.enable_static()
main()
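For reference, a minimal sketch of calling quant_post_hpo directly from Python, assuming a 224x224 RGB ImageNet-style model under the demo's default path; the random-data generator below is only a placeholder for imagenet_reader and would not produce a meaningful calibration in practice:

import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo

def placeholder_reader(num_samples=32):
    # stand-in for imagenet_reader: yields one image at a time with a batch
    # dimension of 1, matching what reader_generator() above produces
    def gen():
        for _ in range(num_samples):
            yield np.random.rand(1, 3, 224, 224).astype('float32')
    return gen

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_post_hpo(
    exe, place,
    "./inference_model/MobileNet/",   # float inference model dir (demo default)
    "./quant_model/MobileNet/",       # where the quantized model is written
    train_sample_generator=placeholder_reader(),
    eval_sample_generator=placeholder_reader(),
    runcount_limit=30)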
@@ -31,6 +31,7 @@ try:
], "training-aware and post-training quant is not supported in 2.0 alpha version paddle"
from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
from .quanter import quant_post, quant_post_only_weight
from .quant_post_hpo import quant_post_hpo
except Exception as e:
_logger.warning(e)
_logger.warning(
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""quant post with hyper params search"""
import os
import cv2
import sys
import math
import time
import numpy as np
import shutil
import paddle
import paddle.fluid as fluid
import logging
import argparse
import functools
from scipy.stats import wasserstein_distance
# smac
from ConfigSpace.hyperparameters import CategoricalHyperparameter, \
UniformFloatHyperparameter, UniformIntegerHyperparameter
from smac.configspace import ConfigurationSpace
from smac.facade.smac_hpo_facade import SMAC4HPO
from smac.scenario.scenario import Scenario
from paddleslim.common import get_logger
from paddleslim.quant import quant_post
class QuantConfig:
"""quant config"""
def __init__(self,
executor,
place,
float_infer_model_path,
quantize_model_path,
train_sample_generator=None,
eval_sample_generator=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
"""QuantConfig init"""
self.executor = executor
self.place = place
self.float_infer_model_path = float_infer_model_path
self.quantize_model_path = quantize_model_path
self.train_sample_generator = train_sample_generator
self.eval_sample_generator = eval_sample_generator
self.model_filename = model_filename
self.params_filename = params_filename
self.save_model_filename = save_model_filename
self.save_params_filename = save_params_filename
self.scope = scope
self.quantizable_op_type = quantizable_op_type
self.is_full_quantize = is_full_quantize
self.weight_bits = weight_bits
self.activation_bits = activation_bits
self.weight_quantize_type = weight_quantize_type
self.optimize_model = optimize_model
self.is_use_cache_file = is_use_cache_file
self.cache_dir = cache_dir
g_quant_config = None
g_min_emd_loss = float('inf')
g_quant_model_cache_path = "quant_model_tmp"
def make_feed_dict(feed_target_names, data):
"""construct feed dictionary"""
feed_dict = {}
if len(feed_target_names) == 1:
feed_dict[feed_target_names[0]] = data
else:
for i in range(len(feed_target_names)):
feed_dict[feed_target_names[i]] = data[i]
return feed_dict
def standardization(data):
"""standardization numpy array"""
mu = np.mean(data, axis=0)
sigma = np.std(data, axis=0)
return (data - mu) / sigma
def cal_emd_loss(out_float_list, out_quant_list, out_len):
"""Calculate the earth mover's distance (EMD) between float and quantized model outputs."""
emd_sum = 0
if out_len >= 3:
for index in range(len(out_float_list)):
emd_sum += wasserstein_distance(out_float_list[index],
out_quant_list[index])
else:
out_float = np.concatenate(out_float_list)
out_quant = np.concatenate(out_quant_list)
emd_sum += wasserstein_distance(out_float, out_quant)
emd_sum /= float(len(out_float_list))
return emd_sum
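# For intuition: scipy's wasserstein_distance is the 1-D earth mover's distance
# between two empirical distributions, e.g.
#     wasserstein_distance([0., 1., 3.], [5., 6., 8.]) == 5.0
# so identical float and quantized outputs give a loss of 0.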
def have_invalid_num(np_arr):
"""Check whether the numpy array contains NaN or Inf values."""
for val in np_arr:
if math.isnan(val) or math.isinf(val):
return True
return False
def convert_model_out_2_nparr(model_out):
"""convert model output to numpy array"""
if not isinstance(model_out, list):
model_out = [model_out]
out_list = []
for out in model_out:
out_list.append(np.array(out))
out_nparr = np.concatenate(out_list)
out_nparr = np.squeeze(out_nparr.flatten())
return out_nparr
def eval_quant_model():
"""Eval quant model accuracy.
Post quantization does not change the parameter value. Therefore, the closer the output distribution of the quantization model and the float model, the better the accuracy is maintained,
which has been verified in classification, detection, and nlp tasks. So the reward here is the earth mover distance between the output of the quantization model and the float model.
This distance measurement method is also verified on various tasks, and the stability is better than other distance measurement methods such as mse.
"""
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[infer_prog_float, feed_target_names_float, fetch_targets_float] = \
fluid.io.load_inference_model(dirname=g_quant_config.float_infer_model_path, \
model_filename=g_quant_config.model_filename, \
params_filename=g_quant_config.params_filename, \
executor=g_quant_config.executor)
with paddle.static.scope_guard(quant_scope):
[infer_prog_quant, feed_target_names_quant, fetch_targets_quant] = \
fluid.io.load_inference_model(dirname=g_quant_model_cache_path, \
model_filename=g_quant_config.save_model_filename, \
params_filename=g_quant_config.save_params_filename, \
executor=g_quant_config.executor)
out_float_list = []
out_quant_list = []
emd_sum = 0
out_len_sum = 0
valid_data_num = 0
max_eval_data_num = 200
for i, data in enumerate(g_quant_config.eval_sample_generator()):
with paddle.static.scope_guard(float_scope):
out_float = g_quant_config.executor.run(infer_prog_float, \
fetch_list=fetch_targets_float, feed=make_feed_dict(feed_target_names_float, data))
with paddle.static.scope_guard(quant_scope):
out_quant = g_quant_config.executor.run(infer_prog_quant, \
fetch_list=fetch_targets_quant, feed=make_feed_dict(feed_target_names_quant, data))
out_float = convert_model_out_2_nparr(out_float)
out_quant = convert_model_out_2_nparr(out_quant)
if len(out_float.shape) <= 0 or len(out_quant.shape) <= 0:
continue
min_len = min(out_float.shape[0], out_quant.shape[0])
out_float = out_float[:min_len]
out_quant = out_quant[:min_len]
out_len_sum += min_len
if have_invalid_num(out_float) or have_invalid_num(out_quant):
continue
try:
out_float = standardization(out_float)
out_quant = standardization(out_quant)
except:
continue
out_float_list.append(out_float)
out_quant_list.append(out_quant)
valid_data_num += 1
if valid_data_num >= max_eval_data_num:
break
emd_sum = cal_emd_loss(out_float_list, out_quant_list,
out_len_sum / float(valid_data_num))
print("output diff:", emd_sum)
return float(emd_sum)
def quantize(cfg):
"""model quantize job"""
algo = cfg["algo"]
hist_percent = cfg["hist_percent"]
bias_correct = cfg["bias_correct"]
batch_size = cfg["batch_size"]
batch_num = cfg["batch_num"]
quant_post( \
executor=g_quant_config.executor, \
scope=g_quant_config.scope, \
model_dir=g_quant_config.float_infer_model_path, \
quantize_model_path=g_quant_model_cache_path, \
sample_generator=g_quant_config.train_sample_generator, \
model_filename=g_quant_config.model_filename, \
params_filename=g_quant_config.params_filename, \
save_model_filename=g_quant_config.save_model_filename, \
save_params_filename=g_quant_config.save_params_filename, \
quantizable_op_type=g_quant_config.quantizable_op_type, \
activation_quantize_type="moving_average_abs_max", \
weight_quantize_type=g_quant_config.weight_quantize_type, \
algo=algo, \
hist_percent=hist_percent, \
bias_correction=bias_correct, \
batch_size=batch_size, \
batch_nums=batch_num)
global g_min_emd_loss
emd_loss = eval_quant_model()
if emd_loss < g_min_emd_loss:
g_min_emd_loss = emd_loss
if os.path.exists(g_quant_config.quantize_model_path):
shutil.rmtree(g_quant_config.quantize_model_path)
os.system("cp -r {0} {1}".format(g_quant_model_cache_path,
g_quant_config.quantize_model_path))
return emd_loss
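# quantize() above is used as the SMAC tae_runner: SMAC evaluates it once per
# trial and minimizes the returned EMD loss, and the best quantized model found
# so far is copied into quantize_model_path.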
def quant_post_hpo(executor,
place,
model_dir,
quantize_model_path,
train_sample_generator=None,
eval_sample_generator=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training",
runcount_limit=30):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
scale factor of quantized variables, and inserts fake quantization
and dequantization operators to obtain the quantized model.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
the executor run on which device.
model_dir(str): The path of fp32 model that will be quantized, and
the model and params that saved by ``paddle.static.io.save_inference_model``
are under the path.
quantize_model_path(str): The path to save quantized model using api
``paddle.static.io.save_inference_model``.
train_sample_generator(Python Generator): The sample generator provides
calibration data for the DataLoader and returns one sample at a time.
eval_sample_generator(Python Generator): The sample generator provides
evaluation data for the DataLoader and returns one sample at a time.
model_filename(str, optional): The name of model file. If parameters
are saved in separate files, set it as 'None'. Default: 'None'.
params_filename(str, optional): The name of params file.
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: '__model__'.
save_params_filename(str): The name of the file to save all related parameters.
If it is set to None, parameters are saved in separate files. Default: '__params__'.
scope(paddle.static.Scope, optional): The scope to run program, use it to load
and save variables. If scope is None, will use paddle.static.global_scope().
quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default: ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantize(bool): If True, apply quantization to all supported quantizable op types.
If False, only the op types in quantizable_op_type are quantized. Default: False.
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activations.
weight_quantize_type(str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. Compared to 'abs_max',
the model accuracy is usually higher when using 'channel_wise_abs_max'.
optimize_model(bool, optional): If True, apply some passes to optimize the model
before quantization. So far, the executor place must be CPU, and the
optimization supports fusing batch_norm into convolutions.
is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated.
runcount_limit(int): maximum number of quantization trials to run during the hyperparameter search. Default: 30.
Returns:
None
"""
global g_quant_config
g_quant_config = QuantConfig(
executor, place, model_dir, quantize_model_path, train_sample_generator,
eval_sample_generator, model_filename, params_filename,
save_model_filename, save_params_filename, scope, quantizable_op_type,
is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
optimize_model, is_use_cache_file, cache_dir)
cs = ConfigurationSpace()
algo = CategoricalHyperparameter(
"algo", ["KL", "hist", "avg", "mse"], default_value="KL")
bias_correct = CategoricalHyperparameter(
"bias_correct", [True, False], default_value=False)
weight_quantize_method = CategoricalHyperparameter("weight_quantize_method", \
[weight_quantize_type], default_value=weight_quantize_type)
hist_percent = UniformFloatHyperparameter(
"hist_percent", 0.98, 0.999, default_value=0.99)
batch_size = UniformIntegerHyperparameter(
"batch_size", 10, 30, default_value=10)
batch_num = UniformIntegerHyperparameter(
"batch_num", 10, 30, default_value=10)
cs.add_hyperparameters([algo, bias_correct, weight_quantize_method, \
hist_percent, batch_size, batch_num])
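# Note: weight_quantize_method has only one candidate (the weight_quantize_type
# passed in), so it stays fixed across trials and is not read inside quantize().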
scenario = Scenario({
"run_obj": "quality",  # we optimize quality (the EMD loss), not runtime
"runcount-limit": runcount_limit,  # maximum number of function evaluations
"cs": cs,  # configuration space
"deterministic": "True",
"limit_resources": "False",
"memory_limit": 4096  # adapt this to a reasonable value for your hardware
})
# To optimize, we pass the function to the SMAC-object
smac = SMAC4HPO(
scenario=scenario, rng=np.random.RandomState(42), tae_runner=quantize)
# Example call of the function with default values
# It returns: Status, Cost, Runtime, Additional Infos
def_value = smac.get_tae_runner().run(cs.get_default_configuration(), 1)[1]
print("Value for default configuration: %.8f" % def_value)
# Start optimization
try:
incumbent = smac.optimize()
finally:
incumbent = smac.solver.incumbent
inc_value = smac.get_tae_runner().run(incumbent, 1)[1]
print("Optimized Value: %.8f" % inc_value)
print("quantize completed")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
sys.path.append(".")
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
import paddle
import paddle.dataset.mnist as reader
import unittest
from paddleslim.quant import quant_post_hpo
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import numpy as np
class TestQuantPostHpoCase1(StaticCase):
def test_accuracy(self):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
def transform(x):
return np.reshape(x, [1, 28, 28])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
batch_size=64,
return_list=False)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def sample_generator_creator():
def __reader__():
for data in test_dataset:
image, label = data
image = np.expand_dims(image, axis=0)
label = np.expand_dims(label, axis=0)
yield image, label
return __reader__
def train(program):
iter = 0
for data in train_loader():
cost, top1, top5 = exe.run(
program,
feed=data,
fetch_list=[avg_cost, acc_top1, acc_top5])
iter += 1
if iter % 100 == 0:
print(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
def test(program, outputs=[avg_cost, acc_top1, acc_top5]):
iter = 0
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(program,
feed=data,
fetch_list=outputs)
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.fluid.io.save_inference_model(
dirname='./test_quant_post_hpo',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
main_program=val_prog,
executor=exe,
model_filename='model',
params_filename='params')
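# Run the post-training quantization hyperparameter search on the model saved
# above; runcount_limit=2 keeps the number of quantization trials small so the
# unit test finishes quickly.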
quant_post_hpo(
exe,
place,
"./test_quant_post_hpo",
"./test_quant_post_hpo_inference",
train_sample_generator=sample_generator_creator(),
eval_sample_generator=sample_generator_creator(),
model_filename="model",
params_filename="params",
save_model_filename='__model__',
save_params_filename='__params__',
runcount_limit=2)
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_hpo_inference',
executor=exe,
model_filename='__model__',
params_filename='__params__')
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
if __name__ == '__main__':
unittest.main()