# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import List
sys.path.append("../")
import unittest
import paddle
from paddleslim.quant import quant_aware, convert
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import numpy as np


class TestQuantAwareCase(StaticCase):
    def test_accuracy(self):
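        # Build a MobileNet classifier for MNIST: 1x28x28 inputs, 10 classes.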
        image = paddle.static.data(
            name='image', shape=[None, 1, 28, 28], dtype='float32')
        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
        model = MobileNet()
        out = model.net(input=image, class_dim=10)
        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
        avg_cost = paddle.mean(x=cost)
        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
        optimizer = paddle.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            weight_decay=paddle.regularizer.L2Decay(4e-5))
        optimizer.minimize(avg_cost)
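        # Keep the default main program for training and a for_test clone for evaluation.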
        main_prog = paddle.static.default_main_program()
        val_prog = paddle.static.default_main_program().clone(for_test=True)

        place = paddle.CUDAPlace(
            0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())

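        # Reshape each 28x28 MNIST image to CHW [1, 28, 28] to match the input placeholder.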
        def transform(x):
            return np.reshape(x, [1, 28, 28])

        train_dataset = paddle.vision.datasets.MNIST(
            mode='train', backend='cv2', transform=transform)
        test_dataset = paddle.vision.datasets.MNIST(
            mode='test', backend='cv2', transform=transform)
        train_loader = paddle.io.DataLoader(
            train_dataset,
            places=place,
            feed_list=[image, label],
            drop_last=True,
            return_list=False,
            batch_size=64)
        valid_loader = paddle.io.DataLoader(
            test_dataset,
            places=place,
            feed_list=[image, label],
            batch_size=64,
            return_list=False)

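        # One pass over the training set, logging metrics every 100 iterations.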
        def train(program):
            iter = 0
            for data in train_loader():
                cost, top1, top5 = exe.run(
                    program,
                    feed=data,
                    fetch_list=[avg_cost, acc_top1, acc_top5])
                iter += 1
                if iter % 100 == 0:
                    print(
                        'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                        format(iter, cost, top1, top5))

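        # One pass over the test set; returns mean top-1 and top-5 accuracy.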
        def test(program):
            iter = 0
            result = [[], [], []]
            for data in valid_loader():
                cost, top1, top5 = exe.run(
                    program,
                    feed=data,
                    fetch_list=[avg_cost, acc_top1, acc_top5])
                iter += 1
                if iter % 100 == 0:
                    print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                          format(iter, cost, top1, top5))
                result[0].append(cost)
                result[1].append(top1)
                result[2].append(top5)
            print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
                np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
            return np.mean(result[1]), np.mean(result[2])

        train(main_prog)
        top1_1, top5_1 = test(main_prog)

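        # Ops with weights get a weight quantizer plus an input quantizer;
        # weight-free ops get an input quantizer only.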
        ops_with_weights = [
            'depthwise_conv2d',
            'mul',
            'conv2d',
        ]
        ops_without_weights = [
            'relu',
        ]

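        # QAT config: per-channel abs-max quantization for weights,
        # moving-average abs-max for activations.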
        config = {
            'weight_quantize_type': 'channel_wise_abs_max',
            'activation_quantize_type': 'moving_average_abs_max',
            'quantize_op_types': ops_with_weights + ops_without_weights,
        }
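        # quant_aware rewrites the programs, inserting fake quantize/dequantize ops.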
        quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
        quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)

        # Step 1: check the number of quantizers in the QAT graph.
        quantizers_count_in_qat = self.count_op(quant_eval_prog,
                                                ['quantize_linear'])
        ops_with_weights_count = self.count_op(quant_eval_prog,
                                               ops_with_weights)
        ops_without_weights_count = self.count_op(quant_eval_prog,
                                                  ops_without_weights)
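        # Each op with weights contributes two quantizers (weight + input);
        # each op without weights contributes one.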
        self.assertEqual(ops_with_weights_count * 2 + ops_without_weights_count,
                         quantizers_count_in_qat)

        with paddle.static.program_guard(quant_eval_prog):
            paddle.static.save_inference_model("./models/mobilenet_qat", [
                image, label
            ], [avg_cost, acc_top1, acc_top5], exe)

        train(quant_train_prog)
        convert_eval_prog = convert(quant_eval_prog, place, config)

        with paddle.static.program_guard(convert_eval_prog):
            paddle.static.save_inference_model("./models/mobilenet_onnx", [
                image, label
            ], [avg_cost, acc_top1, acc_top5], exe)

        top1_2, top5_2 = test(convert_eval_prog)
        # The accuracy before and after quantization should be close.
        print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
        print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))

        # Step 2: check the number of quantizers in the converted (ONNX-style) graph.
        quantizers_count = self.count_op(convert_eval_prog, ['quantize_linear'])
        observers_count = self.count_op(quant_eval_prog,
                                        ['moving_average_abs_max_scale'])
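        # After convert, one quantize_linear remains per quantized op,
        # plus one per activation observer.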
        self.assertEqual(quantizers_count, ops_with_weights_count +
                         ops_without_weights_count + observers_count)

        # Step 3: check that ops matching 'not_quant_pattern' are skipped.
        config['not_quant_pattern'] = ['last_fc']
        skip_quant_prog = quant_aware(
            main_prog, place, config=config, for_test=True)
        skip_quantizers_count_in_qat = self.count_op(skip_quant_prog,
                                                     ['quantize_linear'])
        skip_ops_with_weights_count = self.count_op(skip_quant_prog,
                                                    ops_with_weights)
        skip_ops_without_weights_count = self.count_op(skip_quant_prog,
                                                       ops_without_weights)
        self.assertEqual(skip_ops_without_weights_count,
                         ops_without_weights_count)
        self.assertEqual(skip_ops_with_weights_count, ops_with_weights_count)
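        # Skipping 'last_fc' removes one weighted op's quantizer pair
        # (weight + input), hence the difference of 2.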
        self.assertEqual(skip_quantizers_count_in_qat + 2,
                         quantizers_count_in_qat)

        skip_quant_prog_onnx = convert(skip_quant_prog, place, config=config)
        skip_quantizers_count_in_onnx = self.count_op(skip_quant_prog_onnx,
                                                      ['quantize_linear'])
        self.assertEqual(quantizers_count, skip_quantizers_count_in_onnx)

    def count_op(self, prog, ops: List[str]):
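        """Count the ops in prog whose type appears in ops."""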
        graph = paddle.fluid.framework.IrGraph(
            paddle.framework.core.Graph(prog.desc), for_test=False)
        op_nums = 0
        for op in graph.all_op_nodes():
            if op.name() in ops:
                op_nums += 1
        return op_nums


if __name__ == '__main__':
    unittest.main()