Unverified commit c363c91c authored by Bai Yifan, committed by GitHub

Add dygraph quant aware training (#530)

Parent: 3718a146
......@@ -19,3 +19,6 @@ __all__ += l2norm_pruner.__all__
__all__ += fpgm_pruner.__all__
__all__ += pruner.__all__
__all__ += filter_pruner.__all__
from .quant import *
__all__ += quant.__all__
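With this wildcard re-export in place, QAT becomes importable from the dygraph package itself as well as from its quant subpackage. A quick sanity check (a sketch, assuming a PaddleSlim build that includes this commit):

from paddleslim.dygraph import QAT             # via the re-export above
from paddleslim.dygraph.quant import QAT as QAT2
assert QAT is QAT2  # both import paths resolve to the same class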
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .quanter import QAT
__all__ = ['QAT']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES = [
'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
BUILT_IN_PREPROCESS_TYPES = ['PACT']
VALID_DTYPES = ['int8']
DYGRAPH_QUANTIZABLE_TYPE = ['Conv2D', 'Linear']
__all__ = ['QAT']
_quant_config_default = {
# weight preprocess type, default is None and no preprocessing is performed.
'weight_preprocess_type': None,
# activation preprocess type, default is None and no preprocessing is performed.
'activation_preprocess_type': None,
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# data type after quantization; currently only 'int8' is supported (see VALID_DTYPES). default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
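For reference, a minimal sketch of a user config: only the keys being overridden need to appear, since _parse_configs below merges it over _quant_config_default (the values here are illustrative, not recommendations):

# Hypothetical override: enable PACT activation clipping and per-channel weight quantization.
user_config = {
    'activation_preprocess_type': 'PACT',
    'weight_quantize_type': 'channel_wise_abs_max',
    'quantizable_layer_type': ['Conv2D', 'Linear'],
}
# QAT(config=user_config) validates these keys via _parse_configs.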
def _parse_configs(user_config):
"""
Check whether the user's configs are valid, and merge them over the defaults.
Args:
user_config(dict): user's config.
Returns:
configs(dict): the final configs that will be used.
"""
configs = copy.deepcopy(_quant_config_default)
configs.update(user_config)
# check if configs is valid
weight_types = WEIGHT_QUANTIZATION_TYPES
activation_types = ACTIVATION_QUANTIZATION_TYPES
assert configs['weight_preprocess_type'] in BUILT_IN_PREPROCESS_TYPES or configs['weight_preprocess_type'] is None, \
"Unknown weight_preprocess_type: {}. only supports {} ".format(configs['weight_preprocess_type'],
BUILT_IN_PREPROCESS_TYPES)
assert configs['activation_preprocess_type'] in BUILT_IN_PREPROCESS_TYPES or configs['activation_preprocess_type'] is None, \
"Unknown activation_preprocess_type: {}. only supports {}".format(configs['activation_preprocess_type'],
BUILT_IN_PREPROCESS_TYPES)
assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
"Unknown weight_quantize_type: {}. only supports {} ".format(configs['weight_quantize_type'],
WEIGHT_QUANTIZATION_TYPES)
assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
"Unknown activation_quantize_type: {}. only supports {}".format(configs['activation_quantize_type'],
ACTIVATION_QUANTIZATION_TYPES)
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value."
assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
"weight_bits should be between 1 and 16."
assert isinstance(configs['activation_bits'], int), \
"activation_bits must be int value."
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16."
assert isinstance(configs['dtype'], str), \
"dtype must be a str."
assert (configs['dtype'] in VALID_DTYPES), \
"dtype can only be " + " ".join(VALID_DTYPES)
assert isinstance(configs['window_size'], int), \
"window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quantizable_layer_type'], list), \
"quantizable_layer_type must be a list"
for op_type in configs['quantizable_layer_type']:
assert (op_type in DYGRAPH_QUANTIZABLE_TYPE), "{} is not supported, \
supported op types are {}".format(
op_type, DYGRAPH_QUANTIZABLE_TYPE)
return configs
class PACT(paddle.nn.Layer):
def __init__(self):
super(PACT, self).__init__()
alpha_attr = paddle.ParamAttr(
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
learning_rate=10.0)
self.alpha = self.create_parameter(
shape=[1], attr=alpha_attr, dtype='float32')
def forward(self, x):
out_left = paddle.nn.functional.relu(x - self.alpha)
out_right = paddle.nn.functional.relu(-self.alpha - x)
x = x - out_left + out_right
return x
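The forward pass above is an algebraic form of PACT clipping: x - relu(x - alpha) + relu(-alpha - x) equals clip(x, -alpha, alpha), with the learnable alpha initialized to 20. A small numeric check (a sketch, not part of the library):

import paddle
x = paddle.to_tensor([-30.0, -5.0, 0.0, 5.0, 30.0])
alpha = paddle.to_tensor(20.0)
out = x - paddle.nn.functional.relu(x - alpha) \
        + paddle.nn.functional.relu(-alpha - x)
# out -> [-20., -5., 0., 5., 20.], i.e. the same as paddle.clip(x, -20.0, 20.0)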
class QAT(object):
"""
Quant Aware Training (QAT): adds fake-quant logic to the given quantizable layers, i.e. inserts quant/dequant computation for both activation inputs and weight inputs.
"""
def __init__(self,
config=None,
weight_preprocess=None,
act_preprocess=None,
weight_quantize=None,
act_quantize=None):
"""
Args:
config(dict, optional): Configs for quantization. If None, the default config is used.
Default: None.
weight_quantize(class, optional): A class that defines how to quantize weight. Using this,
one can quickly test whether a custom quantization method works. The class should
implement both quantization and dequantization: its input is the non-quantized
weight and it returns the dequantized weight. If None, the quantization op
defined by 'weight_quantize_type' is used.
Default is None.
act_quantize(class, optional): A class that defines how to quantize activation. Using this,
one can quickly test whether a custom quantization method works. The class should
implement both quantization and dequantization: its input is the non-quantized
activation and it returns the dequantized activation. If None, the quantization op
defined by 'activation_quantize_type' is used.
Default is None.
weight_preprocess(class, optional): A class that defines how to preprocess weight before
quantization. Using this, one can quickly test whether a custom preprocess method
works. Its input is the non-quantized weight and it returns the processed weight
to be quantized. If None, the preprocess method defined by 'weight_preprocess_type'
is used.
Default is None.
act_preprocess(class, optional): A class that defines how to preprocess activation before
quantization. Using this, one can quickly test whether a custom preprocess method
works. Its input is the non-quantized activation and it returns the processed
activation to be quantized. If None, the preprocess method defined by
'activation_preprocess_type' is used.
Default is None.
"""
if config is None:
config = _quant_config_default
else:
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
self.config = config
self.weight_preprocess = PACT if self.config[
'weight_preprocess_type'] == 'PACT' else None
self.act_preprocess = PACT if self.config[
'activation_preprocess_type'] == 'PACT' else None
self.weight_preprocess = weight_preprocess if weight_preprocess is not None else self.weight_preprocess
self.act_preprocess = act_preprocess if act_preprocess is not None else self.act_preprocess
self.weight_quantize = weight_quantize
self.act_quantize = act_quantize
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
activation_quantize_type=self.config['activation_quantize_type'],
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'],
weight_preprocess_layer=self.weight_preprocess,
act_preprocess_layer=self.act_preprocess,
weight_quantize_layer=self.weight_quantize,
act_quantize_layer=self.act_quantize)
def quantize(self, model):
self.imperative_qat.quantize(model)
def save_quantized_model(self, model, path, input_spec=None):
if self.weight_preprocess is not None or self.act_preprocess is not None:
model = self._remove_preprocess(model)
self.imperative_qat.save_quantized_model(
layer=model, path=path, input_spec=input_spec)
def _remove_preprocess(self, model):
state_dict = model.state_dict()
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
activation_quantize_type=self.config['activation_quantize_type'],
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'])
with paddle.utils.unique_name.guard():
model.__init__()
self.imperative_qat.quantize(model)
# restore the weights captured before the model was re-initialized
model.set_state_dict(state_dict)
return model
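Taken together, the intended workflow is: wrap a config in QAT, call quantize() on the model, fine-tune, then export. A minimal sketch (the toy network and the './qat_model' path are placeholders; the call pattern follows the tests in this commit). Note that _remove_preprocess re-calls model.__init__() with no arguments, so when a preprocess type such as PACT is used, the model's __init__ must be callable without arguments:

import paddle
from paddleslim.dygraph.quant import QAT

class ToyNet(paddle.nn.Layer):
    # A no-arg __init__ matters: removing the preprocess layers at save time
    # re-initializes the model via model.__init__().
    def __init__(self):
        super(ToyNet, self).__init__()
        self.conv = paddle.nn.Conv2D(1, 6, 3, padding=1)
        self.flatten = paddle.nn.Flatten()
        self.fc = paddle.nn.Linear(6 * 28 * 28, 10)

    def forward(self, x):
        return self.fc(self.flatten(self.conv(x)))

model = ToyNet()
quanter = QAT(config={'activation_preprocess_type': 'PACT'})
quanter.quantize(model)          # inserts fake-quant wrappers in place
# ... fine-tune `model` with a normal dygraph training loop here ...
quanter.save_quantized_model(
    model, './qat_model',
    input_spec=[paddle.static.InputSpec(
        shape=[None, 1, 28, 28], dtype='float32')])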
......@@ -16,7 +16,6 @@ import os
import copy
import json
import logging
from singledispatch import singledispatch
import paddle
from paddle.fluid.framework import IrGraph
......@@ -30,8 +29,6 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
# For Imperative graph quantization
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
......@@ -82,9 +79,7 @@ _quant_config_default = {
# if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
'for_tensorrt': False,
# if True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
'is_full_quantize': False,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear']
'is_full_quantize': False
}
......@@ -176,12 +171,9 @@ def _parse_configs(user_config):
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quantizable_layer_type'], list), \
"quantizable_layer_type must be a list"
return configs
@singledispatch
def quant_aware(program,
place,
config=None,
......@@ -308,62 +300,6 @@ def quant_aware(program,
return quant_program
@quant_aware.register(paddle.nn.Layer)
def _(model: paddle.nn.Layer,
config=None,
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
act_preprocess_func=None):
"""
This is function overload for dygraph model quant aware training.
Args:
model(nn.Layer)
config(dict, optional): configs for quantization. if None, will use default config.
Default: None.
weight_quantize_func(function): Function that defines how to quantize weight. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization function and dequantization function, that is, the function's input
is non-quantized weight and function returns dequantized weight. If None, will use
quantization op defined by 'weight_quantize_type'.
Default is None.
act_quantize_func(function): Function that defines how to quantize activation. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization and dequantization process, that is, the function's input
is non-quantized activation and function returns dequantized activation. If None, will use
quantization op defined by 'activation_quantize_type'.
Default is None.
weight_preprocess_func(function): Function that defines how to preprocess weight before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized weight and function returns processed weight to be quantized. If None, the weight will
be quantized directly.
Default is None.
act_preprocess_func(function): Function that defines how to preprocess activation before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized activation and function returns processed activation to be quantized. If None, the activation will
be quantized directly.
Default is None.
Returns:
model(nn.Layer) | nn.layer: model with fake quantized layers
"""
if config is None:
config = _quant_config_default
else:
assert isinstance(config, dict), "config must be dict"
config = _parse_configs(config)
imperative_qat = ImperativeQuantAware(
weight_quantize_type=config['weight_quantize_type'],
activation_quantize_type=config['activation_quantize_type'],
quantizable_layer_type=config['quantizable_layer_type'])
imperative_qat.quantize(model)
return model
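Since this overload is removed, the dygraph entry point moves to the QAT class; the mapping from the old call to the new one (as used in the updated tests below) is roughly:

# Old (removed in this commit):
#     quant_model = quant_aware(model, quant_config)
# New:
#     quanter = QAT(config=quant_config)
#     quanter.quantize(model)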
def quant_post_static(
executor,
model_dir,
......
......@@ -14,18 +14,16 @@
import numpy as np
import sys
sys.path.append("../")
sys.path.append("../../")
import unittest
import logging
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
import paddle.vision.transforms as T
from paddleslim.quant import quant_aware
from paddleslim.dygraph.quant import QAT
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
......@@ -35,31 +33,30 @@ class ImperativeLenet(nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
self.features = paddle.nn.Sequential(
Conv2D(
num_channels=1,
num_filters=6,
filter_size=3,
paddle.nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1),
Pool2D(
paddle.nn.Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
Conv2D(
num_channels=6,
num_filters=16,
filter_size=5,
paddle.nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0),
Pool2D(
paddle.nn.Pool2D(
pool_size=2, pool_type='max', pool_stride=2))
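# Note: paddle.nn.Pool2D is the pooling API available at the time of this
# commit; on later Paddle releases the equivalent layer is likely
# paddle.nn.MaxPool2D(kernel_size=2, stride=2).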
self.fc = paddle.nn.Sequential(
Linear(
input_dim=400, output_dim=120),
Linear(
input_dim=120, output_dim=84),
Linear(
input_dim=84, output_dim=num_classes,
act=classifier_activation))
paddle.nn.Linear(
in_features=400, out_features=120),
paddle.nn.Linear(
in_features=120, out_features=84),
paddle.nn.Linear(
in_features=84, out_features=num_classes), )
def forward(self, inputs):
x = self.features(inputs)
......@@ -78,32 +75,40 @@ class TestImperativeQatDefaultConfig(unittest.TestCase):
def test_qat_acc(self):
lenet = ImperativeLenet()
quant_lenet = quant_aware(lenet)
quanter = QAT()
quanter.quantize(lenet)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
def transform(x):
return np.reshape(x, [1, 28, 28])
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset, drop_last=True, places=place, batch_size=64)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=64,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64)
val_dataset, places=place, batch_size=64, return_list=True)
def train(model):
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
learning_rate=0.0001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
......@@ -122,7 +127,9 @@ class TestImperativeQatDefaultConfig(unittest.TestCase):
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
......@@ -141,44 +148,43 @@ class TestImperativeQatDefaultConfig(unittest.TestCase):
train(lenet)
top1_1, top5_1 = test(lenet)
quant_lenet.__init__()
train(quant_lenet)
top1_2, top5_2 = test(quant_lenet)
lenet.__init__()
train(lenet)
top1_2, top5_2 = test(lenet)
# values before quantization and after quantization should be close
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
class TestImperativeQatUserDefineConfig(unittest.TestCase):
class TestImperativeQatPACT(unittest.TestCase):
"""
QAT = quantization-aware training
This test case is for testing user defined quantization config.
This test case is for testing user defined quantization.
"""
def test_qat_acc(self):
lenet = ImperativeLenet()
quant_config = {
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'quantizable_layer_type': ['Conv2D', 'Linear']
'activation_preprocess_type': 'PACT',
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
quant_lenet = quant_aware(lenet, quant_config)
quanter = QAT(config=quant_config)
quanter.quantize(lenet)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
def transform(x):
return np.reshape(x, [1, 28, 28])
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset, drop_last=True, places=place, batch_size=64)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset, drop_last=True, places=place, batch_size=64)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64)
......@@ -191,6 +197,9 @@ class TestImperativeQatUserDefineConfig(unittest.TestCase):
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
......@@ -210,6 +219,8 @@ class TestImperativeQatUserDefineConfig(unittest.TestCase):
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
......@@ -227,16 +238,23 @@ class TestImperativeQatUserDefineConfig(unittest.TestCase):
train(lenet)
top1_1, top5_1 = test(lenet)
quant_lenet.__init__()
train(quant_lenet)
top1_2, top5_2 = test(quant_lenet)
quanter.save_quantized_model(
lenet,
'./dygraph_qat',
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
lenet.__init__()
train(lenet)
top1_2, top5_2 = test(lenet)
# values before quantization and after quantization should be close
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
top5_2))
_logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
top5_1))
if __name__ == '__main__':
......