#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import tempfile
import time
import unittest

import numpy as np
from imperative_test_utils import ImperativeLenet

import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.framework import set_flags
from paddle.quantization import ImperativeQuantAware
from paddle.static.log_helper import get_logger

os.environ["CPU_NUM"] = "1"
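# cuDNN picks nondeterministic algorithms by default; force deterministic ones
# so the accuracy comparisons below are reproducible.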
if paddle.is_compiled_with_cuda():
    set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class TestImperativeQatAmp(unittest.TestCase):
    """
    Test the combination of quantization-aware training (QAT) and
    automatic mixed precision (AMP) training.
    """

    @classmethod
    def setUpClass(cls):
        cls.root_path = tempfile.TemporaryDirectory(
            prefix="imperative_qat_amp_"
        )
        cls.save_path = os.path.join(cls.root_path.name, "model")

        cls.download_path = 'dygraph_int8/download'
        cls.cache_folder = os.path.expanduser(
            '~/.cache/paddle/dataset/' + cls.download_path
        )

        cls.lenet_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz"
        cls.lenet_md5 = "953b802fb73b52fae42896e3c24f0afb"

        seed = 1
        np.random.seed(seed)
        paddle.static.default_main_program().random_seed = seed
        paddle.static.default_startup_program().random_seed = seed

    @classmethod
    def tearDownClass(cls):
        cls.root_path.cleanup()

    def cache_unzipping(self, target_folder, zip_path):
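        # Unpack the downloaded archive once; later runs reuse the cached folder.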
        if not os.path.exists(target_folder):
            cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
                target_folder, zip_path
            )
            os.system(cmd)

    def download_model(self, data_url, data_md5, folder_name):
        download(data_url, self.download_path, data_md5)
        file_name = data_url.split('/')[-1]
        zip_path = os.path.join(self.cache_folder, file_name)
        print('Data is downloaded at {0}'.format(zip_path))

        data_cache_folder = os.path.join(self.cache_folder, folder_name)
        self.cache_unzipping(data_cache_folder, zip_path)
        return data_cache_folder

    def set_vars(self):
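        # ImperativeQuantAware (default config) rewrites the dygraph model in
        # place when quantize() is called, inserting fake-quant/dequant ops.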
        self.qat = ImperativeQuantAware()

        self.train_batch_num = 30
        self.train_batch_size = 32
        self.test_batch_num = 100
        self.test_batch_size = 32
        self.eval_acc_top1 = 0.99

    def model_train(self, model, batch_num=-1, batch_size=32, use_amp=False):
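        # Short MNIST training loop; with use_amp=True the forward and loss
        # computation run under float16 autocast with dynamic loss scaling.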
        model.train()

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size
        )
        adam = paddle.optimizer.Adam(
            learning_rate=0.001, parameters=model.parameters()
        )
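        # GradScaler scales the loss before backward so small float16 gradients
        # do not underflow; minimize() later unscales them, skips steps whose
        # gradients overflowed, and adjusts the loss scale dynamically.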
        scaler = paddle.amp.GradScaler(init_loss_scaling=500)

        for batch_id, data in enumerate(train_reader()):
            x_data = np.array([x[0].reshape(1, 28, 28) for x in data]).astype(
                'float32'
            )
            y_data = (
                np.array([x[1] for x in data]).astype('int64').reshape(-1, 1)
            )

            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)

            if use_amp:
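                # auto_cast runs compute-heavy ops (e.g. conv, matmul) in
                # float16 and keeps numerically sensitive ops in float32.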
                with paddle.amp.auto_cast():
                    out = model(img)
                    acc = paddle.metric.accuracy(out, label)
                    loss = paddle.nn.functional.cross_entropy(
                        out, label, reduction='none', use_softmax=False
                    )
                    avg_loss = paddle.mean(loss)
                scaled_loss = scaler.scale(avg_loss)
                scaled_loss.backward()

                scaler.minimize(adam, scaled_loss)
                adam.clear_gradients()
            else:
                out = model(img)
                acc = paddle.metric.accuracy(out, label)
                loss = paddle.nn.functional.cross_entropy(
                    out, label, reduction='none', use_softmax=False
                )
                avg_loss = paddle.mean(loss)
                avg_loss.backward()

                adam.minimize(avg_loss)
                model.clear_gradients()

            if batch_id % 100 == 0:
                _logger.info(
                    "Train | step {}: loss = {:}, acc= {:}".format(
                        batch_id, avg_loss.numpy(), acc.numpy()
                    )
                )

            if batch_num > 0 and batch_id + 1 >= batch_num:
                break

    def model_test(self, model, batch_num=-1, batch_size=32, use_amp=False):
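        # Evaluate top-1/top-5 accuracy; with use_amp=True the forward pass
        # also runs under float16 autocast.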
        model.eval()

        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size
        )

        acc_top1_list = []
        for batch_id, data in enumerate(test_reader()):
            x_data = np.array([x[0].reshape(1, 28, 28) for x in data]).astype(
                'float32'
            )
            y_data = (
                np.array([x[1] for x in data]).astype('int64').reshape(-1, 1)
            )

            img = paddle.to_tensor(x_data)
            label = paddle.to_tensor(y_data)

            with paddle.amp.auto_cast(use_amp):
                out = model(img)
                acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
                acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

            acc_top1_list.append(float(acc_top1.numpy()))
            if batch_id % 100 == 0:
                _logger.info(
                    "Test | At step {}: acc1 = {:}, acc5 = {:}".format(
                        batch_id, acc_top1.numpy(), acc_top5.numpy()
                    )
                )

            if batch_num > 0 and batch_id + 1 >= batch_num:
                break

        acc_top1 = sum(acc_top1_list) / len(acc_top1_list)
        return acc_top1

    def test_ptq(self):
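        # End-to-end check: measure fp32 baseline accuracy, quantize, run a
        # short AMP finetune, then require the int8 accuracy to stay within
        # 0.01 of the baseline before exporting the model.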
        start_time = time.time()

        self.set_vars()

        params_path = self.download_model(
            self.lenet_url, self.lenet_md5, "lenet"
        )
        params_path += "/lenet_pretrained/lenet.pdparams"

        with fluid.dygraph.guard():
            model = ImperativeLenet()
            model_state_dict = paddle.load(params_path)
            model.set_state_dict(model_state_dict)

            _logger.info("Test fp32 model")
            fp32_acc_top1 = self.model_test(
                model, self.test_batch_num, self.test_batch_size
            )

            self.qat.quantize(model)
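            # The model now contains fake-quant ops; the AMP training run below
            # calibrates their quantization scales.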

            use_amp = True
            self.model_train(
                model, self.train_batch_num, self.train_batch_size, use_amp
            )

            _logger.info("Test int8 model")
            int8_acc_top1 = self.model_test(
                model, self.test_batch_num, self.test_batch_size, use_amp
            )

            _logger.info(
                'fp32_acc_top1: %f, int8_acc_top1: %f'
                % (fp32_acc_top1, int8_acc_top1)
            )
            self.assertTrue(
                int8_acc_top1 > fp32_acc_top1 - 0.01,
                msg='fp32_acc_top1: %f, int8_acc_top1: %f'
                % (fp32_acc_top1, int8_acc_top1),
            )

        input_spec = [
            paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
        ]
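        # Export the quantized model to the static-graph inference format.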
        paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec)
        print('Quantized model saved in %s' % self.save_path)

        end_time = time.time()
        print("total time: %ss" % (end_time - start_time))


if __name__ == '__main__':
    unittest.main()