#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
    QuantizedConv2D,
    QuantizedConv2DTranspose,
)
from paddle.fluid.framework import _test_eager_guard
from imperative_test_utils import fix_model_dict, ImperativeLenet

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
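# Make cuDNN deterministic so the accuracy checks below are reproducible.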
if core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class TestImperativeQat(unittest.TestCase):
    """
    QAT = quantization-aware training
    """

    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = False
        self.check_export_model_accuracy = True
        # The original model and the quantized model may produce different
        # predictions. There are 32 samples per test batch and we allow at
        # most one to differ, hence diff_threshold = 1 / 32 = 0.03125.
        self.diff_threshold = 0.03125
        self.fuse_conv_bn = False

    def func_qat(self):
        self.set_vars()

        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=self.weight_quantize_type,
            activation_quantize_type=self.activation_quantize_type,
            fuse_conv_bn=self.fuse_conv_bn,
            onnx_format=self.onnx_format,
        )

        with fluid.dygraph.guard():
            # For CI coverage: a conv layer with padding_mode='replicate'.
            conv1 = Conv2D(
                in_channels=3,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode='replicate',
            )
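            # Wrap the float Conv2D with QuantizedConv2D and run a forward pass.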
            quant_conv1 = QuantizedConv2D(conv1)
            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
            quant_conv1(fluid.dygraph.to_variable(data))

            conv_transpose = Conv2DTranspose(4, 6, (3, 3))
            quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
            x_var = paddle.uniform(
                (2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0
            )
            quant_conv_transpose(x_var)

            seed = 1
            np.random.seed(seed)
            fluid.default_main_program().random_seed = seed
            fluid.default_startup_program().random_seed = seed

            lenet = ImperativeLenet()
            lenet = fix_model_dict(lenet)
            imperative_qat.quantize(lenet)
            adam = AdamOptimizer(
                learning_rate=0.001, parameter_list=lenet.parameters()
            )

            train_reader = paddle.batch(
                paddle.dataset.mnist.train(), batch_size=32, drop_last=True
            )
            test_reader = paddle.batch(
                paddle.dataset.mnist.test(), batch_size=32
            )

            epoch_num = 1
            for epoch in range(epoch_num):
                lenet.train()
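                # Train with fake-quant ops inserted; the moving-average
                # observers update activation scales during these passes.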
                for batch_id, data in enumerate(train_reader()):
                    x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape(-1, 1)
                    )

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)
                    out = lenet(img)
                    acc = paddle.static.accuracy(out, label)
                    loss = paddle.nn.functional.cross_entropy(
                        out, label, reduction='none', use_softmax=False
                    )
                    avg_loss = paddle.mean(loss)
                    avg_loss.backward()
                    adam.minimize(avg_loss)
                    lenet.clear_gradients()
                    if batch_id % 100 == 0:
                        _logger.info(
                            "Train | At epoch {} step {}: loss = {:}, acc = {:}".format(
                                epoch, batch_id, avg_loss.numpy(), acc.numpy()
                            )
                        )
                    if batch_id == 500:  # For shortening CI time
                        break

                lenet.eval()
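                # Evaluate in inference mode; top-1 accuracy is sampled every
                # 100 batches and averaged below.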
                eval_acc_top1_list = []
                for batch_id, data in enumerate(test_reader()):
                    x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape(-1, 1)
                    )

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)

                    out = lenet(img)
                    acc_top1 = paddle.static.accuracy(
                        input=out, label=label, k=1
                    )
                    acc_top5 = paddle.static.accuracy(
                        input=out, label=label, k=5
                    )

                    if batch_id % 100 == 0:
                        eval_acc_top1_list.append(float(acc_top1.numpy()))
                        _logger.info(
                            "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".format(
                                epoch,
                                batch_id,
                                acc_top1.numpy(),
                                acc_top5.numpy(),
                            )
                        )

                # check eval acc
                eval_acc_top1 = sum(eval_acc_top1_list) / len(
                    eval_acc_top1_list
                )
                print('eval_acc_top1', eval_acc_top1)
                self.assertTrue(
                    eval_acc_top1 > 0.9,
                    msg="The test acc %f is less than 0.9." % eval_acc_top1,
                )

            # test the correctness of `paddle.jit.save`
            data = next(test_reader())
            test_data = np.array(
                [x[0].reshape(1, 28, 28) for x in data]
            ).astype('float32')
            y_data = (
                np.array([x[1] for x in data]).astype('int64').reshape(-1, 1)
            )
            test_img = fluid.dygraph.to_variable(test_data)
            label = fluid.dygraph.to_variable(y_data)
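            # Reference accuracy of the dygraph model before export.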
            lenet.eval()
            fp32_out = lenet(test_img)
            fp32_acc = paddle.static.accuracy(fp32_out, label).numpy()

        with tempfile.TemporaryDirectory(prefix="qat_save_path_") as tmpdir:
            # save inference quantized model
            imperative_qat.save_quantized_model(
                layer=lenet,
                path=os.path.join(tmpdir, "lenet"),
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None, 1, 28, 28], dtype='float32'
                    )
                ],
            )
            print('Quantized model saved in %s' % tmpdir)

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            exe = fluid.Executor(place)
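            # Reload the exported inference program and run it with the
            # static-graph executor on the same test batch.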
            [
                inference_program,
                feed_target_names,
                fetch_targets,
            ] = fluid.io.load_inference_model(
                dirname=tmpdir,
                executor=exe,
                model_filename="lenet" + INFER_MODEL_SUFFIX,
                params_filename="lenet" + INFER_PARAMS_SUFFIX,
            )
            (quant_out,) = exe.run(
                inference_program,
                feed={feed_target_names[0]: test_data},
                fetch_list=fetch_targets,
            )
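            # Compare the reloaded quantized model's accuracy against the
            # dygraph reference; they may differ by at most diff_threshold.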
            paddle.disable_static()
            quant_out = fluid.dygraph.to_variable(quant_out)
            quant_acc = paddle.static.accuracy(quant_out, label).numpy()
            paddle.enable_static()
            delta_value = fp32_acc - quant_acc
            self.assertLessEqual(delta_value, self.diff_threshold)

    def test_qat(self):
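        # Run once under eager mode, then once under legacy dygraph mode.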
        with _test_eager_guard():
            self.func_qat()
        self.func_qat()


class TestImperativeQatONNXFormat(TestImperativeQat):
    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = True
        self.diff_threshold = 0.03125
        self.fuse_conv_bn = False


if __name__ == '__main__':
    unittest.main()