Unverified commit a5c2b5df authored by Liufang Sang, committed by GitHub

add quant_post_only_weight (#200)

* add quant_post_only_weight

* update quant_post_weight

* change paddle version requirements

* add test for quant_post_only_weight
Parent 517605a8
@@ -20,8 +20,8 @@ from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
try:
-    fluid.require_version('1.7.0')
-    from .quanter import quant_aware, quant_post, convert
+    fluid.require_version('2.0.0')
+    from .quanter import quant_aware, quant_post, convert, quant_post_only_weight
except Exception as e:
    _logger.warning(
        "If you want to use training-aware and post-training quantization, "
......
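Note: because of this guard, the new API is only importable when the installed PaddlePaddle satisfies the version check. A minimal sanity check (a sketch, not part of the commit) would be:

    import paddle.fluid as fluid

    # Raises if the installed PaddlePaddle is older than 2.0.0; in that case
    # paddleslim.quant only logs the warning above and quant_post_only_weight
    # cannot be imported.
    fluid.require_version('2.0.0')
    from paddleslim.quant import quant_post_only_weight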
@@ -25,6 +25,7 @@ from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
@@ -364,3 +365,59 @@ def convert(program, place, config=None, scope=None, save_int8=False):
        return freezed_program, freezed_program_int8
    else:
        return freezed_program


def quant_post_only_weight(model_dir,
                           save_model_dir,
                           model_filename=None,
                           params_filename=None,
                           save_model_filename=None,
                           save_params_filename=None,
                           quantizable_op_type=["conv2d", "mul"],
                           weight_bits=8,
                           generate_test_model=False):
    '''
    In order to reduce the size of the model, this API quantizes the weights
    of some ops from float32 to int8 or int16. In the inference stage, the
    quantized weights are dequantized back to float32.

    Args:
        model_dir(str): The path of the fp32 model to be quantized; the model
            and params files are under this path.
        save_model_dir(str): The path used to save the quantized model.
        model_filename(str, optional): The name of the file used to load the
            inference program. If it is None, the default filename
            '__model__' is used. Default is None.
        params_filename(str, optional): The name of the file used to load all
            parameters. When all parameters were saved in a single binary
            file, set it to the real filename. If parameters were saved in
            separate files, set it to None. Default is None.
        save_model_filename(str, optional): The name of the file used to save
            the inference program. If it is None, the default filename
            '__model__' is used. Default is None.
        save_params_filename(str, optional): The name of the file used to save
            all parameters. If it is None, parameters are saved in separate
            files; otherwise, all parameters are saved in a single binary
            file.
        quantizable_op_type(list[str], optional): The list of op types to
            quantize; every entry should be one of
            ["conv2d", "depthwise_conv2d", "mul"]. Default is
            ["conv2d", "mul"].
        weight_bits(int, optional): The number of bits for the quantized
            weights; it should be 8 or 16. Default is 8.
        generate_test_model(bool, optional): If generate_test_model is True,
            a fake quantized model is also saved, in which the weights are
            quantized and then dequantized. PaddlePaddle can load this fake
            quantized model to test the accuracy on GPU or CPU.
            Default is False.
    '''
    weight_quant = WeightQuantization(
        model_dir=model_dir,
        model_filename=model_filename,
        params_filename=params_filename)

    weight_quant.quantize_weight_to_int(
        save_model_dir=save_model_dir,
        save_model_filename=save_model_filename,
        save_params_filename=save_params_filename,
        quantizable_op_type=quantizable_op_type,
        weight_bits=weight_bits,
        generate_test_model=generate_test_model)
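For reference, a minimal usage sketch of the new API (the directories and file names below are hypothetical, not part of this commit):

    from paddleslim.quant import quant_post_only_weight

    # Quantize only the weights of an exported fp32 inference model to int8;
    # at inference time the weights are dequantized back to float32.
    quant_post_only_weight(
        model_dir='./fp32_model',              # hypothetical input model directory
        save_model_dir='./int8_weight_model',  # hypothetical output directory
        model_filename='model',
        params_filename='params',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        weight_bits=8,
        generate_test_model=False)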
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_post_only_weight
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import numpy as np


class TestQuantPostOnlyWeightCase1(unittest.TestCase):
    def test_accuracy(self):
        image = fluid.layers.data(
            name='image', shape=[1, 28, 28], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        model = MobileNet()
        out = model.net(input=image, class_dim=10)
        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))
        optimizer.minimize(avg_cost)
        main_prog = fluid.default_main_program()
        val_prog = main_prog.clone(for_test=True)

        place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        feeder = fluid.DataFeeder([image, label], place, program=main_prog)
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=64)
        eval_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)
        def train(program):
            iter = 0
            for data in train_reader():
                cost, top1, top5 = exe.run(
                    program,
                    feed=feeder.feed(data),
                    fetch_list=[avg_cost, acc_top1, acc_top5])
                iter += 1
                if iter % 100 == 0:
                    print(
                        'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                        format(iter, cost, top1, top5))
        def test(program, outputs=[avg_cost, acc_top1, acc_top5]):
            iter = 0
            result = [[], [], []]
            for data in eval_reader():
                cost, top1, top5 = exe.run(program,
                                           feed=feeder.feed(data),
                                           fetch_list=outputs)
                iter += 1
                if iter % 100 == 0:
                    print(
                        'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                        format(iter, cost, top1, top5))
                result[0].append(cost)
                result[1].append(top1)
                result[2].append(top5)
            print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
            return np.mean(result[1]), np.mean(result[2])
        train(main_prog)
        top1_1, top5_1 = test(val_prog)
        fluid.io.save_inference_model(
            dirname='./test_quant_post',
            feeded_var_names=[image.name, label.name],
            target_vars=[avg_cost, acc_top1, acc_top5],
            main_program=val_prog,
            executor=exe,
            model_filename='model',
            params_filename='params')

        quant_post_only_weight(
            model_dir='./test_quant_post',
            save_model_dir='./test_quant_post_inference',
            model_filename='model',
            params_filename='params',
            generate_test_model=True)
        quant_post_prog, feed_target_names, fetch_targets = fluid.io.load_inference_model(
            dirname='./test_quant_post_inference/test_model', executor=exe)
        top1_2, top5_2 = test(quant_post_prog, fetch_targets)
        print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
        print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))


if __name__ == '__main__':
    unittest.main()