Unverified commit 74511818 authored by huangxu96, committed by GitHub

Added dygraph quantization. (#435)

* Added dygraph quantization.

* Added save_quant_model function.

* Overloaded quant_aware so that dynamic and static models share the same interface.

* Used singledispatch to overload quant_aware for dygraph quantization.

* Added unittest for dygraph quant_aware.

* Used the 2.0 API to support Paddle 2.0.

* Used pre-commit to fix the coding style.

* Added unittest for a user-provided quant_config.

* Added singledispatch to requirements.txt.
Co-authored-by: Bai Yifan <me@ethanbai.com>
Parent 4b3a7c99
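For background: singledispatch (the PyPI backport of Python 3's functools.singledispatch, hence the new requirements.txt entry below) selects a function implementation based on the type of its first argument. A minimal sketch of the pattern, using a hypothetical describe function in place of quant_aware:

    from singledispatch import singledispatch  # functools.singledispatch on Python 3

    @singledispatch
    def describe(obj):  # default implementation, analogous to the static-graph path
        return "default: " + type(obj).__name__

    @describe.register(list)  # overload chosen when the first argument is a list
    def _(obj):
        return "a list with {} items".format(len(obj))

    print(describe(42))       # default: int
    print(describe([1, 2]))   # a list with 2 items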
@@ -16,6 +16,7 @@ import os
import copy
import json
import logging
from singledispatch import singledispatch
import paddle
from paddle.fluid.framework import IrGraph
@@ -29,6 +30,8 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
# For Imperative graph quantization
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from ..common import get_logger

_logger = get_logger(__name__, level=logging.INFO)
@@ -79,7 +82,9 @@ _quant_config_default = {
    # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
    'for_tensorrt': False,
    # if True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
    'is_full_quantize': False,
    # for dygraph quantization: layers whose type is in 'quantizable_layer_type' will be quantized
    'quantizable_layer_type': ['Conv2D', 'Linear']
}
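# Illustration (a sketch, not part of this diff): with the new key above, a
# user-supplied config restricting quantization to Conv2D layers might be:
#
#     user_config = {
#         'weight_quantize_type': 'abs_max',
#         'activation_quantize_type': 'moving_average_abs_max',
#         'quantizable_layer_type': ['Conv2D'],
#     }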
@@ -171,9 +176,12 @@ def _parse_configs(user_config):
    assert isinstance(configs['moving_rate'], float), \
        "moving_rate must be a float; it is the decay coefficient of the moving average. Default is 0.9."
    assert isinstance(configs['quantizable_layer_type'], list), \
        "quantizable_layer_type must be a list"
    return configs
@singledispatch
def quant_aware(program,
                place,
                config=None,
@@ -300,6 +308,62 @@ def quant_aware(program,
    return quant_program
@quant_aware.register(paddle.nn.Layer)
def _(model: paddle.nn.Layer,
      config=None,
      weight_quantize_func=None,
      act_quantize_func=None,
      weight_preprocess_func=None,
      act_preprocess_func=None):
"""
This is function overload for dygraph model quant aware training.
Args:
model(nn.Layer)
config(dict, optional): configs for quantization. if None, will use default config.
Default: None.
weight_quantize_func(function): Function that defines how to quantize weight. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization function and dequantization function, that is, the function's input
is non-quantized weight and function returns dequantized weight. If None, will use
quantization op defined by 'weight_quantize_type'.
Default is None.
act_quantize_func(function): Function that defines how to quantize activation. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization and dequantization process, that is, the function's input
is non-quantized activation and function returns dequantized activation. If None, will use
quantization op defined by 'activation_quantize_type'.
Default is None.
weight_preprocess_func(function): Function that defines how to preprocess weight before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized weight and function returns processed weight to be quantized. If None, the weight will
be quantized directly.
Default is None.
act_preprocess_func(function): Function that defines how to preprocess activation before quantization. Using this
can quickly test if user's preprocess method works or not. The function's input
is non-quantized activation and function returns processed activation to be quantized. If None, the activation will
be quantized directly.
Default is None.
Returns:
model(nn.Layer) | nn.layer: model with fake quantized layers
"""
    if config is None:
        config = _quant_config_default
    else:
        assert isinstance(config, dict), "config must be dict"
        config = _parse_configs(config)

    imperative_qat = ImperativeQuantAware(
        weight_quantize_type=config['weight_quantize_type'],
        activation_quantize_type=config['activation_quantize_type'],
        quantizable_layer_type=config['quantizable_layer_type'])
    imperative_qat.quantize(model)
    return model
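# A minimal usage sketch of the overload above (a sketch only; the model name
# is a placeholder, and the flow mirrors the unittest added in this commit):
#
#     import paddle
#     from paddleslim.quant import quant_aware
#
#     net = SomeDygraphModel()      # any paddle.nn.Layer subclass
#     quant_net = quant_aware(net)  # dispatches to the nn.Layer overload
#     # train quant_net as usual; fake-quant layers are already inserted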
def quant_post_static(
        executor,
        model_dir,
......
#paddlepaddle == 1.6.0rc0
tqdm
pyzmq
singledispatch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys
sys.path.append("../")
import unittest
import logging
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
from paddleslim.quant import quant_aware
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ImperativeLenet(nn.Layer):
    def __init__(self, num_classes=10, classifier_activation='softmax'):
        super(ImperativeLenet, self).__init__()
        self.features = Sequential(
            Conv2D(
                num_channels=1,
                num_filters=6,
                filter_size=3,
                stride=1,
                padding=1),
            Pool2D(
                pool_size=2, pool_type='max', pool_stride=2),
            Conv2D(
                num_channels=6,
                num_filters=16,
                filter_size=5,
                stride=1,
                padding=0),
            Pool2D(
                pool_size=2, pool_type='max', pool_stride=2))
        self.fc = Sequential(
            Linear(
                input_dim=400, output_dim=120),
            Linear(
                input_dim=120, output_dim=84),
            Linear(
                input_dim=84, output_dim=num_classes,
                act=classifier_activation))

    def forward(self, inputs):
        x = self.features(inputs)
        x = fluid.layers.flatten(x, 1)
        x = self.fc(x)
        return x
class TestImperativeQatDefaultConfig(unittest.TestCase):
    """
    QAT = quantization-aware training.
    This test case uses the default quantization config; weight_quantize_type
    is channel_wise_abs_max.
    """

    def test_qat_acc(self):
        with fluid.dygraph.guard():
            lenet = ImperativeLenet()
            quant_lenet = quant_aware(lenet)

            train_reader = paddle.batch(
                paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
            test_reader = paddle.batch(
                paddle.dataset.mnist.test(), batch_size=32)

            def train(model):
                adam = AdamOptimizer(
                    learning_rate=0.001, parameter_list=model.parameters())
                epoch_num = 1
                for epoch in range(epoch_num):
                    model.train()
                    for batch_id, data in enumerate(train_reader()):
                        x_data = np.array(
                            [x[0].reshape(1, 28, 28)
                             for x in data]).astype('float32')
                        y_data = np.array(
                            [x[1] for x in data]).astype('int64').reshape(-1, 1)

                        img = fluid.dygraph.to_variable(x_data)
                        label = fluid.dygraph.to_variable(y_data)
                        out = model(img)
                        acc = fluid.layers.accuracy(out, label)
                        loss = fluid.layers.cross_entropy(out, label)
                        avg_loss = fluid.layers.mean(loss)
                        avg_loss.backward()
                        adam.minimize(avg_loss)
                        model.clear_gradients()
                        if batch_id % 100 == 0:
                            _logger.info(
                                "Train | At epoch {} step {}: loss = {:}, acc = {:}".
                                format(epoch, batch_id,
                                       avg_loss.numpy(), acc.numpy()))

            def test(model):
                model.eval()
                avg_acc = [[], []]
                for batch_id, data in enumerate(test_reader()):
                    x_data = np.array([x[0].reshape(1, 28, 28)
                                       for x in data]).astype('float32')
                    y_data = np.array(
                        [x[1] for x in data]).astype('int64').reshape(-1, 1)

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)
                    out = model(img)
                    acc_top1 = fluid.layers.accuracy(
                        input=out, label=label, k=1)
                    acc_top5 = fluid.layers.accuracy(
                        input=out, label=label, k=5)
                    avg_acc[0].append(acc_top1.numpy())
                    avg_acc[1].append(acc_top5.numpy())
                    if batch_id % 100 == 0:
                        _logger.info(
                            "Test | step {}: acc1 = {:}, acc5 = {:}".format(
                                batch_id, acc_top1.numpy(), acc_top5.numpy()))

                _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
                    np.mean(avg_acc[0]), np.mean(avg_acc[1])))
                return np.mean(avg_acc[0]), np.mean(avg_acc[1])

            train(lenet)
            top1_1, top5_1 = test(lenet)

            quant_lenet.__init__()
            train(quant_lenet)
            top1_2, top5_2 = test(quant_lenet)

            # accuracy before and after quantization should be close
            _logger.info("Before quantization: top1: {}, top5: {}".format(
                top1_1, top5_1))
            _logger.info("After quantization: top1: {}, top5: {}".format(
                top1_2, top5_2))
class TestImperativeQatUserDefineConfig(unittest.TestCase):
    """
    QAT = quantization-aware training.
    This test case tests a user-defined quantization config.
    """

    def test_qat_acc(self):
        with fluid.dygraph.guard():
            lenet = ImperativeLenet()
            quant_config = {
                'weight_quantize_type': 'abs_max',
                'activation_quantize_type': 'moving_average_abs_max',
                'quantizable_layer_type': ['Conv2D', 'Linear']
            }
            quant_lenet = quant_aware(lenet, quant_config)

            train_reader = paddle.batch(
                paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
            test_reader = paddle.batch(
                paddle.dataset.mnist.test(), batch_size=32)

            def train(model):
                adam = AdamOptimizer(
                    learning_rate=0.001, parameter_list=model.parameters())
                epoch_num = 1
                for epoch in range(epoch_num):
                    model.train()
                    for batch_id, data in enumerate(train_reader()):
                        x_data = np.array(
                            [x[0].reshape(1, 28, 28)
                             for x in data]).astype('float32')
                        y_data = np.array(
                            [x[1] for x in data]).astype('int64').reshape(-1, 1)

                        img = fluid.dygraph.to_variable(x_data)
                        label = fluid.dygraph.to_variable(y_data)
                        out = model(img)
                        acc = fluid.layers.accuracy(out, label)
                        loss = fluid.layers.cross_entropy(out, label)
                        avg_loss = fluid.layers.mean(loss)
                        avg_loss.backward()
                        adam.minimize(avg_loss)
                        model.clear_gradients()
                        if batch_id % 100 == 0:
                            _logger.info(
                                "Train | At epoch {} step {}: loss = {:}, acc = {:}".
                                format(epoch, batch_id,
                                       avg_loss.numpy(), acc.numpy()))

            def test(model):
                model.eval()
                avg_acc = [[], []]
                for batch_id, data in enumerate(test_reader()):
                    x_data = np.array([x[0].reshape(1, 28, 28)
                                       for x in data]).astype('float32')
                    y_data = np.array(
                        [x[1] for x in data]).astype('int64').reshape(-1, 1)

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)
                    out = model(img)
                    acc_top1 = fluid.layers.accuracy(
                        input=out, label=label, k=1)
                    acc_top5 = fluid.layers.accuracy(
                        input=out, label=label, k=5)
                    avg_acc[0].append(acc_top1.numpy())
                    avg_acc[1].append(acc_top5.numpy())
                    if batch_id % 100 == 0:
                        _logger.info(
                            "Test | step {}: acc1 = {:}, acc5 = {:}".format(
                                batch_id, acc_top1.numpy(), acc_top5.numpy()))

                _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
                    np.mean(avg_acc[0]), np.mean(avg_acc[1])))
                return np.mean(avg_acc[0]), np.mean(avg_acc[1])

            train(lenet)
            top1_1, top5_1 = test(lenet)

            quant_lenet.__init__()
            train(quant_lenet)
            top1_2, top5_2 = test(quant_lenet)

            # accuracy before and after quantization should be close
            _logger.info("Before quantization: top1: {}, top5: {}".format(
                top1_1, top5_1))
            _logger.info("After quantization: top1: {}, top5: {}".format(
                top1_2, top5_2))


if __name__ == '__main__':
    unittest.main()