Unverified commit 068119c9, authored by cc, committed by GitHub

[Quant] add dygraph ptq and refine dygraph qat (#838)

* add dygraph ptq and refine dygraph qat

* add unit test for ptq

* update according to comments
Parent 09fba1bf
......@@ -12,5 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .quanter import QAT
__all__ = ['QAT']
from . import qat
from . import ptq
from .qat import *
from .ptq import *
__all__ = []
__all__ += qat.__all__
__all__ += ptq.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import paddle
import paddle.fluid.contrib.slim.quantization as Q
from paddle.fluid.contrib.slim.quantization import AbsmaxQuantizer
from paddle.fluid.contrib.slim.quantization import HistQuantizer
from paddle.fluid.contrib.slim.quantization import KLQuantizer
from paddle.fluid.contrib.slim.quantization import PerChannelAbsmaxQuantizer
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
'PTQ',
'AbsmaxQuantizer',
'HistQuantizer',
'KLQuantizer',
'PerChannelAbsmaxQuantizer',
]
class PTQ(object):
"""
Static post training quantization.
"""
def __init__(self,
activation_quantizer=Q.KLQuantizer(),
weight_quantizer=Q.PerChannelAbsmaxQuantizer()):
"""
Args:
activation_quantizer(Quantizer): The quantizer method for activation.
Default: KLQuantizer.
weight_quantizer(Quantizer): The quantizer method for weight.
Default: PerChannelAbsmaxQuantizer.
"""
assert isinstance(activation_quantizer, tuple(Q.SUPPORT_ACT_QUANTIZERS))
assert isinstance(weight_quantizer, tuple(Q.SUPPORT_WT_QUANTIZERS))
quant_config = Q.PTQConfig(
activation_quantizer=activation_quantizer,
weight_quantizer=weight_quantizer)
self.ptq = Q.ImperativePTQ(quant_config=quant_config)
def quantize(self, model, inplace=False):
"""
Quantize the input model.
Args:
model(paddle.nn.Layer): The model to be quantized.
inplace(bool): Whether apply quantization to the input model.
Default: False.
Returns:
quantized_model(paddle.nn.Layer): The quantized model.
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
return self.ptq.quantize(model=model, inplace=inplace)
def save_quantized_model(self, model, path, input_spec=None):
"""
Save the quantized inference model.
Args:
model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of
the saved model. Default: None.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
training = model.training
if training:
model.eval()
self.ptq.save_quantized_model(
model=model, path=path, input_spec=input_spec)
if training:
model.train()
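For orientation, below is a minimal usage sketch of the PTQ API defined above. It is not part of this commit: the `from paddleslim import PTQ` import mirrors the unit test further down, the model and calibration loader are placeholders, and the alternative quantizers may need to be imported from paddleslim.dygraph.quant if they are not re-exported at the package root.
# Hypothetical usage sketch of the dygraph PTQ API (not part of this diff).
import paddle
from paddleslim import PTQ
model = build_trained_model()            # placeholder: any trained paddle.nn.Layer
ptq = PTQ()                              # defaults: KLQuantizer / PerChannelAbsmaxQuantizer
# ptq = PTQ(activation_quantizer=HistQuantizer())  # optional: another supported activation quantizer
quant_model = ptq.quantize(model)        # inplace=False (default) quantizes a deep copy
# Calibration: run a few forward passes so the quantizers can collect statistics.
quant_model.eval()
for img in calib_batches():              # placeholder iterable of input arrays
    quant_model(paddle.to_tensor(img))
ptq.save_quantized_model(
    quant_model,
    path="./ptq_model/model",
    input_spec=[paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype="float32")])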
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -126,7 +126,7 @@ class PACT(paddle.nn.Layer):
super(PACT, self).__init__()
alpha_attr = paddle.ParamAttr(
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
initializer=paddle.nn.initializer.Constant(value=100),
learning_rate=1000.0)
self.alpha = self.create_parameter(
......@@ -207,11 +207,48 @@ class QAT(object):
weight_quantize_layer=self.weight_quantize,
act_quantize_layer=self.act_quantize)
def quantize(self, model):
def quantize(self, model, inplace=False):
"""
Quantize the input model.
Args:
model(paddle.nn.Layer): The model to be quantized.
inplace(bool): Whether apply quantization to the input model.
Default: False.
Returns:
quantized_model(paddle.nn.Layer): The quantized model.
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
self._model = copy.deepcopy(model)
self.imperative_qat.quantize(model)
if inplace:
self.imperative_qat.quantize(model)
quant_model = model
else:
quant_model = copy.deepcopy(model)
self.imperative_qat.quantize(quant_model)
return quant_model
def save_quantized_model(self, model, path, input_spec=None):
"""
Save the quantized inference model.
Args:
model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of
the saved model. Default: None.
Returns:
None
"""
if self.weight_preprocess is not None or self.act_preprocess is not None:
training = model.training
model = self._remove_preprocess(model)
......
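For context, here is a short sketch of how the reworked QAT interface is meant to be called with the new non-inplace default. It mirrors the TestQAT case below and is not part of this diff; ImperativeLenet and the training step are placeholders taken from the tests.
# Hypothetical QAT usage with the new inplace flag (not part of this diff).
import paddle
from paddleslim import QAT
fp32_model = ImperativeLenet()                   # placeholder model from the tests below
quanter = QAT()
quant_model = quanter.quantize(fp32_model)       # default inplace=False returns a quantized copy
# quanter.quantize(fp32_model, inplace=True)     # alternatively, modify fp32_model in place
# ... train quant_model as usual; fake-quant layers are already inserted ...
quanter.save_quantized_model(
    quant_model,
    "./qat_model/model",
    input_spec=[paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype="float32")])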
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys
sys.path.append("../../")
import unittest
import logging
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.log_helper import get_logger
import paddle.vision.transforms as T
from paddleslim import PTQ
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ImperativeLenet(nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
self.features = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1),
paddle.nn.AvgPool2D(
kernel_size=2, stride=2),
paddle.nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0),
paddle.nn.AvgPool2D(
kernel_size=2, stride=2))
self.fc = paddle.nn.Sequential(
paddle.nn.Linear(
in_features=400, out_features=120),
paddle.nn.Linear(
in_features=120, out_features=84),
paddle.nn.Linear(
in_features=84, out_features=num_classes), )
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
class TestPTQ(unittest.TestCase):
"""
Test dygraph post training quantization.
"""
def calibrate(self, model, test_reader, batch_num=10):
model.eval()
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
out = model(img)
if batch_id + 1 >= batch_num:
break
def model_test(self, model, test_reader):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info("Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
def model_train(self, model, train_reader):
adam = paddle.optimizer.Adam(
learning_rate=0.0001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id, avg_loss.numpy(), acc.numpy()))
def test_ptq(self):
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
_logger.info("create the fp32 model")
fp32_lenet = ImperativeLenet()
_logger.info("prepare data")
batch_size = 64
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=batch_size,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=batch_size, return_list=True)
_logger.info("train the fp32 model")
self.model_train(fp32_lenet, train_reader)
_logger.info("test fp32 model")
fp32_top1, fp32_top5 = self.model_test(fp32_lenet, test_reader)
_logger.info("quantize the fp32 model")
quanter = PTQ()
quant_lenet = quanter.quantize(fp32_lenet)
_logger.info("calibrate")
self.calibrate(quant_lenet, test_reader)
_logger.info("save and test the quantized model")
save_path = "./tmp/model"
input_spec = paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
quanter.save_quantized_model(
quant_lenet, save_path, input_spec=[input_spec])
quant_top1, quant_top5 = self.model_test(quant_lenet, test_reader)
_logger.info("FP32 acc: top1: {}, top5: {}".format(fp32_top1,
fp32_top5))
_logger.info("Int acc: top1: {}, top5: {}".format(quant_top1,
quant_top5))
diff = 0.002
self.assertTrue(
fp32_top1 - quant_top1 < diff,
msg="The acc of quant model is too lower than fp32 model")
if __name__ == '__main__':
unittest.main()
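After save_quantized_model, the exported inference model can be reloaded for a quick sanity check with paddle.jit.load, which is standard Paddle API rather than part of this commit; a minimal sketch assuming the "./tmp/model" prefix used in the test above:
# Sketch: reload the quantized inference model saved by the test above.
import paddle
loaded = paddle.jit.load("./tmp/model")          # path prefix passed to save_quantized_model
loaded.eval()
pred = loaded(paddle.randn([1, 1, 28, 28], dtype="float32"))
print(pred.shape)                                # expected: [1, 10] for the LeNet head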
......@@ -66,20 +66,30 @@ class ImperativeLenet(nn.Layer):
return x
class TestImperativeQatDefaultConfig(unittest.TestCase):
class TestQAT(unittest.TestCase):
"""
QAT = quantization-aware training
This test case uses the default quantization config; weight_quantize_type
is channel_wise_abs_max.
"""
def set_seed(self):
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
def prepare(self):
self.quanter = QAT()
def test_qat_acc(self):
lenet = ImperativeLenet()
quanter = QAT()
quanter.quantize(lenet)
self.prepare()
self.set_seed()
fp32_lenet = ImperativeLenet()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
......@@ -141,130 +151,46 @@ class TestImperativeQatDefaultConfig(unittest.TestCase):
"Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
_logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
train(lenet)
top1_1, top5_1 = test(lenet)
train(fp32_lenet)
top1_1, top5_1 = test(fp32_lenet)
lenet.__init__()
train(lenet)
top1_2, top5_2 = test(lenet)
fp32_lenet.__init__()
quant_lenet = self.quanter.quantize(fp32_lenet)
train(quant_lenet)
top1_2, top5_2 = test(quant_lenet)
self.quanter.save_quantized_model(
quant_lenet,
'./tmp/qat',
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
# values before quantization and after quantization should be close
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("\n")
diff = 0.002
self.assertTrue(
top1_1 - top1_2 < diff,
msg="The acc of quant model is too lower than fp32 model")
class TestImperativeQatPACT(unittest.TestCase):
class TestQATWithPACT(TestQAT):
"""
QAT = quantization-aware training
This test case is for testing user defined quantization.
"""
def test_qat_acc(self):
lenet = ImperativeLenet()
quant_config = {
'activation_preprocess_type': 'PACT',
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
quanter = QAT(config=quant_config)
quanter.quantize(lenet)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset, drop_last=True, places=place, batch_size=64)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64)
def train(model):
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
def test(model):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info(
"Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
train(lenet)
top1_1, top5_1 = test(lenet)
quanter.save_quantized_model(
lenet,
'./dygraph_qat',
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
lenet.__init__()
train(lenet)
top1_2, top5_2 = test(lenet)
# values before quantization and after quantization should be close
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
# test for saving model in train mode
lenet.train()
quanter.save_quantized_model(
lenet,
'./dygraph_qat',
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
def prepare(self):
quant_config = {'activation_preprocess_type': 'PACT', }
self.quanter = QAT(config=quant_config)
if __name__ == '__main__':
......
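To round out the PACT path, a hedged sketch of the user-defined config that prepare() above passes to QAT; the 'quantizable_layer_type' key is taken from the removed test body and is optional if the defaults already cover Conv2D and Linear.
# Sketch: enabling the PACT activation preprocess via a user-defined QAT config.
from paddleslim import QAT
quant_config = {
    'activation_preprocess_type': 'PACT',             # learnable clipping before fake-quant
    'quantizable_layer_type': ['Conv2D', 'Linear'],   # optional; mirrors the removed test
}
quanter = QAT(config=quant_config)
quant_model = quanter.quantize(ImperativeLenet())     # ImperativeLenet is defined in this test file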