# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
# Keep the sys.path.append calls before the imports that follow them; the
# local modules imported afterwards may only be reachable via these paths.
sys.path.append("../")
import unittest
import tempfile
import paddle
from paddleslim.quant import quant_post_static
from static_case import StaticCase
sys.path.append("../demo")
from models import *
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
import numpy as np
from paddleslim.quant import quant_recon_static


class ReconPTQ(unittest.TestCase):
    """Base fixture for the reconstruction-based PTQ tests.

    Constructing a test case does the following:

    * switches Paddle to static-graph mode;
    * trains a MobileNetV2 classifier for one pass over MNIST;
    * exports it into a fresh temporary directory as ``model.pdmodel`` /
      ``params.pdiparams``;
    * exposes ``self.data_loader``, a sample-generator creator over the
      MNIST test split that yields ``(image, label)`` pairs.
    """

    def __init__(self, *args, **kwargs):
        super(ReconPTQ, self).__init__(*args, **kwargs)
        paddle.enable_static()
        self.tmpdir = tempfile.TemporaryDirectory(prefix="test_")
        # NOTE(review): unittest creates one instance per test method, so
        # the model is retrained for every test; setUpClass could share it.
        self._gen_model()

    def _gen_model(self):
        """Build the float model, train it for one pass, and export it.

        Side effects: writes the inference model into ``self.tmpdir`` and
        sets ``self.train_loader`` and ``self.data_loader``.
        """
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            image = paddle.static.data(
                name='image', shape=[None, 1, 28, 28], dtype='float32')
            label = paddle.static.data(
                name='label', shape=[None, 1], dtype='int64')
            model = MobileNetV2()
            out = model.net(input=image, class_dim=10)
            cost = paddle.nn.functional.loss.cross_entropy(
                input=out, label=label)
            avg_cost = paddle.mean(x=cost)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

            # Clone before minimize() so the program exported below does not
            # include the optimizer ops that minimize() appends.
            val_program = main_program.clone(for_test=True)
            optimizer = paddle.optimizer.Momentum(
                momentum=0.9,
                learning_rate=0.01,
                weight_decay=paddle.regularizer.L2Decay(4e-5))
            optimizer.minimize(avg_cost)
        exe.run(startup_program)

        def transform(x):
            # Reshape each sample to the [1, 28, 28] layout of the 'image' input.
            return np.reshape(x, [1, 28, 28])

        train_dataset = paddle.vision.datasets.MNIST(
            mode='train', backend='cv2', transform=transform)
        test_dataset = paddle.vision.datasets.MNIST(
            mode='test', backend='cv2', transform=transform)
        self.train_loader = paddle.io.DataLoader(
            train_dataset,
            places=place,
            feed_list=[image, label],
            drop_last=True,
            batch_size=64,
            return_list=False)

        def sample_generator_creator():
            """Return a reader that yields (image, label) test samples."""
            def __reader__():
                for data in test_dataset:
                    sample_image, sample_label = data
                    yield sample_image, sample_label

            return __reader__

        def train(program):
            """Run one pass over the training set, logging every 100 batches."""
            for batch_id, data in enumerate(self.train_loader(), start=1):
                cost, top1, top5 = exe.run(
                    program,
                    feed=data,
                    fetch_list=[avg_cost, acc_top1, acc_top5])
                if batch_id % 100 == 0:
                    print(
                        'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                        format(batch_id, cost, top1, top5))

        train(main_program)
        paddle.fluid.io.save_inference_model(
            dirname=self.tmpdir.name,
            feeded_var_names=[image.name],
            target_vars=[out],
            main_program=val_program,
            executor=exe,
            model_filename='model.pdmodel',
            params_filename='params.pdiparams')
        print(f"saved infer model to [{self.tmpdir.name}]")
        self.data_loader = sample_generator_creator()

    def __del__(self):
        # __init__ may have raised before tmpdir was assigned; don't turn
        # that into a second AttributeError during garbage collection.
        tmpdir = getattr(self, "tmpdir", None)
        if tmpdir is not None:
            tmpdir.cleanup()


class TestReconRegion(ReconPTQ):
    """Smoke test: region-wise reconstruction PTQ with abs_max calibration."""

    def test_qdrop_region(self):
        """Run quant_recon_static in region-wise mode on the exported model."""
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        quant_recon_static(
            exe,
            self.tmpdir.name,
            # Write under the per-test temp dir so no artifacts leak into CWD.
            quantize_model_path=os.path.join(self.tmpdir.name,
                                             'output_region'),
            sample_generator=self.data_loader,
            model_filename='model.pdmodel',
            params_filename='params.pdiparams',
            batch_nums=1,
            epochs=1,
            algo='abs_max',
            regions=None,
            region_weights_names=None,
            recon_level='region-wise',
            simulate_activation_quant=True)


class TestReconLayer(ReconPTQ):
    """Smoke test: layer-wise reconstruction PTQ with KL calibration."""

    def test_qdrop_layer(self):
        """Run quant_recon_static in layer-wise mode with bias correction."""
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        quant_recon_static(
            exe,
            self.tmpdir.name,
            # Write under the per-test temp dir so no artifacts leak into CWD.
            quantize_model_path=os.path.join(self.tmpdir.name,
                                             'output_layer'),
            sample_generator=self.data_loader,
            model_filename='model.pdmodel',
            params_filename='params.pdiparams',
            batch_nums=1,
            epochs=1,
            algo='KL',
            regions=None,
            region_weights_names=None,
            recon_level='layer-wise',
            simulate_activation_quant=True,
            bias_correction=True)


if __name__ == '__main__':
    # Discover and run the TestCase classes defined in this module.
    unittest.main()