#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
from paddle.fluid.log_helper import get_logger
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
    QuantizedConv2D,
    QuantizedConv2DTranspose,
)

from imperative_test_utils import fix_model_dict, ImperativeLenet

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class TestImperativeQat(unittest.TestCase):
    """
    QAT = quantization-aware training
    """

    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = False
        self.check_export_model_accuracy = True
        # The original and quantized models may disagree on a few
        # predictions. Each test batch has 32 samples and we allow at
        # most one mismatch, so diff_threshold = 1 / 32 = 0.03125.
        self.diff_threshold = 0.03125
        self.fuse_conv_bn = False

    def test_qat(self):
        self.set_vars()

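        # 'abs_max' derives the weight scale from the current tensor;
        # 'moving_average_abs_max' tracks activation scales with a moving
        # average while training.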
        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=self.weight_quantize_type,
            activation_quantize_type=self.activation_quantize_type,
            fuse_conv_bn=self.fuse_conv_bn,
            onnx_format=self.onnx_format,
        )

        with fluid.dygraph.guard():
            # For CI coverage: exercise QuantizedConv2D on a standalone
            # Conv2D that uses the non-default 'replicate' padding mode.
            conv1 = Conv2D(
                in_channels=3,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode='replicate',
            )
            quant_conv1 = QuantizedConv2D(conv1)
            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
            quant_conv1(fluid.dygraph.to_variable(data))

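            # The same smoke test for QuantizedConv2DTranspose.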
            conv_transpose = Conv2DTranspose(4, 6, (3, 3))
            quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
            x_var = paddle.uniform(
                (2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0
            )
            quant_conv_transpose(x_var)

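            # Fix the seeds so the training run is reproducible.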
            seed = 1
            np.random.seed(seed)
            fluid.default_main_program().random_seed = seed
            fluid.default_startup_program().random_seed = seed

            lenet = ImperativeLenet()
            lenet = fix_model_dict(lenet)
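            # quantize() rewrites the model in place, replacing supported
            # layers with fake-quant wrappers that collect scales during
            # training.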
            imperative_qat.quantize(lenet)
            adam = AdamOptimizer(
                learning_rate=0.001, parameter_list=lenet.parameters()
            )

            train_reader = paddle.batch(
                paddle.dataset.mnist.train(), batch_size=32, drop_last=True
            )
            test_reader = paddle.batch(
                paddle.dataset.mnist.test(), batch_size=32
            )

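            # Keep training short: one epoch, capped at 500 batches below.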
            epoch_num = 1
            for epoch in range(epoch_num):
                lenet.train()
                for batch_id, data in enumerate(train_reader()):
                    x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape(-1, 1)
                    )

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)
                    out = lenet(img)
                    acc = paddle.static.accuracy(out, label)
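                    # use_softmax=False, presumably because the test LeNet
                    # already ends in a Softmax layer.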
                    loss = paddle.nn.functional.cross_entropy(
                        out, label, reduction='none', use_softmax=False
                    )
                    avg_loss = paddle.mean(loss)
                    avg_loss.backward()
                    adam.minimize(avg_loss)
                    lenet.clear_gradients()
                    if batch_id % 100 == 0:
                        _logger.info(
                            "Train | At epoch {} step {}: loss = {:}, acc = {:}".format(
                                epoch, batch_id, avg_loss.numpy(), acc.numpy()
                            )
                        )
                    if batch_id == 500:  # For shortening CI time
                        break

                lenet.eval()
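                # Top-1 accuracy is sampled every 100 batches (below) to
                # keep CI time down.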
                eval_acc_top1_list = []
                for batch_id, data in enumerate(test_reader()):
                    x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape(-1, 1)
                    )

                    img = fluid.dygraph.to_variable(x_data)
                    label = fluid.dygraph.to_variable(y_data)

                    out = lenet(img)
                    acc_top1 = paddle.static.accuracy(
                        input=out, label=label, k=1
                    )
                    acc_top5 = paddle.static.accuracy(
                        input=out, label=label, k=5
                    )

                    if batch_id % 100 == 0:
                        eval_acc_top1_list.append(float(acc_top1.numpy()))
                        _logger.info(
                            "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".format(
                                epoch,
                                batch_id,
                                acc_top1.numpy(),
                                acc_top5.numpy(),
                            )
                        )

                # check eval acc
                eval_acc_top1 = sum(eval_acc_top1_list) / len(
                    eval_acc_top1_list
                )
                print('eval_acc_top1', eval_acc_top1)
                self.assertTrue(
                    eval_acc_top1 > 0.9,
                    msg="The test acc {%f} is less than 0.9." % eval_acc_top1,
                )

            # test the correctness of `paddle.jit.save`
            data = next(test_reader())
            test_data = np.array(
                [x[0].reshape(1, 28, 28) for x in data]
            ).astype('float32')
            y_data = (
                np.array([x[1] for x in data]).astype('int64').reshape(-1, 1)
            )
            test_img = fluid.dygraph.to_variable(test_data)
            label = fluid.dygraph.to_variable(y_data)
            lenet.eval()
            fp32_out = lenet(test_img)
            fp32_acc = paddle.static.accuracy(fp32_out, label).numpy()

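        # save_quantized_model exports the trained layer as a static-graph
        # inference model (model + params files), converting the fake-quant
        # ops into their inference form.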
        with tempfile.TemporaryDirectory(prefix="qat_save_path_") as tmpdir:
            # save inference quantized model
            imperative_qat.save_quantized_model(
                layer=lenet,
                path=os.path.join(tmpdir, "lenet"),
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None, 1, 28, 28], dtype='float32'
                    )
                ],
            )
            print('Quantized model saved in %s' % tmpdir)

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
            else:
                place = core.CPUPlace()
            exe = fluid.Executor(place)
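            # Reload the exported inference program and run it on the same
            # test batch.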
            [
                inference_program,
                feed_target_names,
                fetch_targets,
            ] = fluid.io.load_inference_model(
                dirname=tmpdir,
                executor=exe,
                model_filename="lenet" + INFER_MODEL_SUFFIX,
                params_filename="lenet" + INFER_PARAMS_SUFFIX,
            )
            (quant_out,) = exe.run(
                inference_program,
                feed={feed_target_names[0]: test_data},
                fetch_list=fetch_targets,
            )
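            # Switch to dygraph briefly so paddle.static.accuracy can score
            # the fetched outputs.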
            paddle.disable_static()
            quant_out = fluid.dygraph.to_variable(quant_out)
            quant_acc = paddle.static.accuracy(quant_out, label).numpy()
            paddle.enable_static()
            delta_value = fp32_acc - quant_acc
            self.assertLessEqual(delta_value, self.diff_threshold)


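# A variant of the test above with onnx_format=True, which exports
# ONNX-style quantize_linear/dequantize_linear ops.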
class TestImperativeQatONNXFormat(TestImperativeQat):
    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = True
        self.diff_threshold = 0.03125
        self.fuse_conv_bn = False


if __name__ == '__main__':
    unittest.main()