未验证 提交 a0e82c2b 编写于 作者: H huangxu96 提交者: GitHub

[Cherry-pick]Implemented AddQuantDequantPass in imperative quantization. (#26692) (#30525)

* Implemented AddQuantDequantPass in imperative quantization.

* support 2.0 API such as Pool2D and ReLU
上级 3688d9e9
......@@ -86,7 +86,7 @@ class ImperativeQuantAware(object):
'moving_average_abs_max', the static quantization scale will be calculated
during training and used in inference.
moving_rate(float): the parameter for 'moving_average_abs_max' quantization.
quantizable_op_type(list[str]): List the type of layers that will be quantized.
quantizable_layer_type(list[str]): List the type of layers that will be quantized.
Default is ['Conv2D', 'Linear']. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
weight_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess
......@@ -229,7 +229,17 @@ class ImperativeQuantAware(object):
"'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
% (str(weight_quantize_type)))
self._quant_layers_map = {'Conv2D': Conv2D, 'Linear': Linear}
self._quant_layers_map = {
'Conv2D': Conv2D,
'Linear': Linear,
'Pool2D': Pool2D,
'ReLU': ReLU,
'LeakyReLU': LeakyReLU,
'ReLU6': ReLU6,
'Softmax': Softmax,
'Tanh': Tanh,
'Swish': Swish
}
self._quantizable_layer_type = tuple(
self._quant_layers_map[layer]
if layer in self._quant_layers_map else layer
......@@ -291,7 +301,12 @@ class ImperativeQuantAware(object):
layer.full_name()))
sys.exit(-1)
quantized_layer = quant_nn.__dict__[quantized_counterpart[index]](
layer_with_weight = ['QuantizedConv2D', 'QuantizedLinear']
if quantized_counterpart[index] not in layer_with_weight:
quant_layer_class_name = 'QuantizedNoweightLayer'
else:
quant_layer_class_name = quantized_counterpart[index]
quantized_layer = quant_nn.__dict__[quant_layer_class_name](
layer, self._weight_bits, self._activation_bits, self._moving_rate,
self._weight_quantize_type, self._activation_quantize_type,
self._weight_pre_layer, self._act_pre_layer,
......
......@@ -24,9 +24,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.nn import functional as F
__all__ = [
'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
'QuantizedLinear', 'FakeChannelWiseQuantDequantAbsMax',
'MovingAverageAbsMaxScale'
'FakeQuantMovingAverage', 'FakeQuantAbsMax',
'FakeChannelWiseQuantDequantAbsMax', 'QuantizedConv2D', 'QuantizedLinear',
'QuantizedNoweightLayer', 'MovingAverageAbsMaxScale'
]
......@@ -478,6 +478,30 @@ class QuantizedLinear(layers.Layer):
return out
class QuantizedNoweightLayer(layers.Layer):
def __init__(self,
layer,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
*args,
**kwargs):
super(QuantizedNoweightLayer, self).__init__()
self._layer = layer
self._fake_quant_input = _get_fake_quant_type(
'moving_average_abs_max',
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype,
quant_on_weight=False)
def forward(self, input):
quant_input = self._fake_quant_input(input)
return self._layer.forward(quant_input)
class MovingAverageAbsMaxScale(layers.Layer):
def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
r"""
......
......@@ -270,6 +270,30 @@ list(REMOVE_ITEM TEST_OPS
LIST(REMOVE_ITEM TEST_OPS test_auto_pruning)
LIST(REMOVE_ITEM TEST_OPS test_filter_pruning)
# only tests on singal GPU environment
LIST(REMOVE_ITEM TEST_OPS test_imperative_qat_addquantdequant)
py_test_modules(test_imperative_qat_addquantdequant MODULES test_imperative_qat_addquantdequant ENVS
CUDA_VISIBLE_DEVICES=0)
# fix
if(WIN32)
SET(SINGLE_CARD_TEST_OPS
test_user_defined_quantization
test_quantization_scale_pass
test_quantization_pass
test_moving_average_abs_max_scale_op
test_imperative_qat_channelwise
test_imperative_qat
test_imperative_out_scale
test_graph)
LIST(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS})
foreach(src ${SINGLE_CARD_TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS CUDA_VISIBLE_DEVICES=0)
endforeach()
endif()
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
......@@ -288,6 +312,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 120)
set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat_addquantdequant PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 120)
if(LINUX AND WITH_MKLDNN)
set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT 120)
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import six
import paddle.fluid as fluid
from paddle.nn import functional
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm
from paddle.fluid.layers import nn
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware, QuantizationTransformPass, AddQuantDequantPass
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Pool2D
from paddle.nn.layer.activation import ReLU, LeakyReLU, ReLU6, Tanh, Swish
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def StaticLenet(data, num_classes=10):
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
conv2d_w3_attr = fluid.ParamAttr(name="conv2d_w_3")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
conv2d_b3_attr = fluid.ParamAttr(name="conv2d_b_3")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv1 = fluid.layers.conv2d(
data,
num_filters=6,
filter_size=3,
stride=1,
padding=1,
param_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr)
conv1 = fluid.layers.leaky_relu(conv1, alpha=0.02)
pool1 = fluid.layers.pool2d(
conv1, pool_size=2, pool_type='max', pool_stride=2)
conv2 = fluid.layers.conv2d(
pool1,
num_filters=16,
filter_size=5,
stride=1,
padding=0,
param_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr)
pool2 = fluid.layers.pool2d(
conv2, pool_size=2, pool_type='max', pool_stride=2)
pool2 = fluid.layers.relu(pool2)
pool2 = fluid.layers.swish(pool2)
conv3 = fluid.layers.conv2d(
pool2,
num_filters=16,
filter_size=1,
stride=1,
padding=0,
param_attr=conv2d_w3_attr,
bias_attr=conv2d_b3_attr)
conv3 = fluid.layers.relu6(conv3)
conv3 = paddle.tensor.math.tanh(conv3)
fc1 = fluid.layers.fc(input=conv3,
size=120,
param_attr=fc_w1_attr,
bias_attr=fc_b1_attr)
fc2 = fluid.layers.fc(input=fc1,
size=84,
param_attr=fc_w2_attr,
bias_attr=fc_b2_attr)
fc3 = fluid.layers.fc(input=fc2,
size=num_classes,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
fc3 = fluid.layers.softmax(fc3, use_cudnn=True)
return fc3
class ImperativeLenet(fluid.dygraph.Layer):
def __init__(self, num_classes=10):
super(ImperativeLenet, self).__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
conv2d_w3_attr = fluid.ParamAttr(name="conv2d_w_3")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
conv2d_b3_attr = fluid.ParamAttr(name="conv2d_b_3")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1,
weight_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr),
LeakyReLU(negative_slope=0.02),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
weight_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
ReLU(),
Swish(),
Conv2D(
in_channels=16,
out_channels=16,
kernel_size=1,
stride=1,
padding=0,
weight_attr=conv2d_w3_attr,
bias_attr=conv2d_b3_attr),
ReLU6(),
Tanh())
self.fc = Sequential(
Linear(
in_features=400,
out_features=120,
weight_attr=fc_w1_attr,
bias_attr=fc_b1_attr),
Linear(
in_features=120,
out_features=84,
weight_attr=fc_w2_attr,
bias_attr=fc_b2_attr),
Linear(
in_features=84,
out_features=num_classes,
weight_attr=fc_w3_attr,
bias_attr=fc_b3_attr),
Softmax())
def forward(self, inputs):
x = self.features(inputs)
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class TestImperativeAddQuantDequant(unittest.TestCase):
def test_qat_save(self):
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
quantizable_layer_type=[
'Conv2D', 'Linear', 'ReLU', 'Pool2D', 'LeakyReLU', 'ReLU6',
'Tanh', 'Swish'
])
with fluid.dygraph.guard():
lenet = ImperativeLenet()
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=lenet.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32)
epoch_num = 1
for epoch in range(epoch_num):
lenet.train()
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
lenet.eval()
for batch_id, data in enumerate(test_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc_top1 = fluid.layers.accuracy(
input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=out, label=label, k=5)
if batch_id % 100 == 0:
_logger.info(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".
format(epoch, batch_id,
acc_top1.numpy(), acc_top5.numpy()))
# save weights
model_dict = lenet.state_dict()
fluid.save_dygraph(model_dict, "save_temp")
# test the correctness of `paddle.jit.save`
data = next(test_reader())
test_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
test_img = fluid.dygraph.to_variable(test_data)
lenet.eval()
before_save = lenet(test_img)
# save inference quantized model
path = "./qat_infer_model/lenet"
save_dir = "./qat_infer_model"
paddle.jit.save(
layer=lenet,
path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
dirname=save_dir,
executor=exe,
model_filename="lenet" + INFER_MODEL_SUFFIX,
params_filename="lenet" + INFER_PARAMS_SUFFIX)
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
self.assertTrue(
np.allclose(after_save, before_save.numpy()),
msg='Failed to save the inference quantized model.')
def test_qat_acc(self):
def _build_static_lenet(main, startup, is_test=False, seed=1000):
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
main.random_seed = seed
startup.random_seed = seed
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
prediction = StaticLenet(img)
if not is_test:
loss = fluid.layers.cross_entropy(
input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
else:
avg_loss = prediction
return img, label, avg_loss
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
weight_quantize_type = 'abs_max'
activation_quant_type = 'moving_average_abs_max'
param_init_map = {}
seed = 1000
lr = 0.001
# imperative train
_logger.info(
"--------------------------dynamic graph qat--------------------------"
)
imperative_qat = ImperativeQuantAware(
weight_quantize_type=weight_quantize_type,
activation_quantize_type=activation_quant_type,
quantizable_layer_type=[
'Conv2D', 'Linear', 'ReLU', 'LeakyReLU', 'ReLU6', 'Tanh',
'Swish'
])
with fluid.dygraph.guard():
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
lenet = ImperativeLenet()
fixed_state = {}
for name, param in lenet.named_parameters():
p_shape = param.numpy().shape
p_value = param.numpy()
if name.endswith("bias"):
value = np.zeros_like(p_value).astype('float32')
else:
value = np.random.normal(
loc=0.0, scale=0.01, size=np.product(p_shape)).reshape(
p_shape).astype('float32')
fixed_state[name] = value
param_init_map[param.name] = value
lenet.set_dict(fixed_state)
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters())
dynamic_loss_rec = []
lenet.train()
for batch_id, data in enumerate(reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
dynamic_loss_rec.append(avg_loss.numpy()[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
if batch_id > 500:
break
lenet.eval()
paddle.jit.save(
layer=lenet,
path="./dynamic_mnist/model",
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
# static graph train
_logger.info(
"--------------------------static graph qat--------------------------"
)
static_loss_rec = []
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
main = fluid.Program()
infer = fluid.Program()
startup = fluid.Program()
static_img, static_label, static_loss = _build_static_lenet(
main, startup, False, seed)
infer_img, _, infer_pre = _build_static_lenet(infer, startup, True,
seed)
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
opt = AdamOptimizer(learning_rate=lr)
opt.minimize(static_loss)
scope = core.Scope()
with fluid.scope_guard(scope):
exe.run(startup)
for param in main.all_parameters():
param_tensor = scope.var(param.name).get_tensor()
param_tensor.set(param_init_map[param.name], place)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
infer_graph = IrGraph(core.Graph(infer.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
add_quant_dequant_pass = AddQuantDequantPass(
scope=scope,
place=place,
quantizable_op_type=[
'relu', 'leaky_relu', 'relu6', 'tanh', 'swish'
])
transform_pass.apply(main_graph)
transform_pass.apply(infer_graph)
add_quant_dequant_pass.apply(main_graph)
add_quant_dequant_pass.apply(infer_graph)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=static_loss.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(
feed_list=[static_img, static_label], place=place)
with fluid.scope_guard(scope):
for batch_id, data in enumerate(reader()):
loss_v, = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[static_loss])
static_loss_rec.append(loss_v[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', loss_v))
save_program = infer_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model("./static_mnist", [infer_img.name],
[infer_pre], exe, save_program)
rtol = 1e-08
atol = 1e-10
for i, (loss_d,
loss_s) in enumerate(zip(dynamic_loss_rec, static_loss_rec)):
diff = np.abs(loss_d - loss_s)
if diff > (atol + rtol * np.abs(loss_s)):
_logger.info(
"diff({}) at {}, dynamic loss = {}, static loss = {}".
format(diff, i, loss_d, loss_s))
break
self.assertTrue(
np.allclose(
np.array(dynamic_loss_rec),
np.array(static_loss_rec),
rtol=rtol,
atol=atol,
equal_nan=True),
msg='Failed to do the imperative qat.')
if __name__ == '__main__':
unittest.main()
......@@ -86,9 +86,9 @@ def StaticLenet(data, num_classes=10):
size=num_classes,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
fc4 = fluid.layers.softmax(fc3, use_cudnn=True)
fc3 = fluid.layers.softmax(fc3, use_cudnn=True)
return fc4
return fc3
class ImperativeLenet(fluid.dygraph.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册