#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import unittest

import numpy as np

import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization import (
    QuantInt8MkldnnPass,
    QuantizationFreezePass,
    QuantizationTransformPass,
)

paddle.enable_static()
os.environ["CPU_NUM"] = "1"


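# A small LeNet-style network: two conv + max-pool blocks, with batch norm
# after the first block, followed by a softmax FC layer; returns the mean
# cross-entropy loss.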
def conv_net(img, label):
    conv_out_1 = paddle.static.nn.conv2d(
        input=img,
        filter_size=5,
        num_filters=20,
        act='relu',
    )
    conv_pool_1 = paddle.nn.functional.max_pool2d(
        conv_out_1, kernel_size=2, stride=2
    )
    conv_pool_1 = paddle.static.nn.batch_norm(conv_pool_1)

    conv_out_2 = paddle.static.nn.conv2d(
        input=conv_pool_1,
        filter_size=5,
        num_filters=20,
        act='relu',
    )
    conv_pool_2 = paddle.nn.functional.max_pool2d(
        conv_out_2, kernel_size=2, stride=2
    )
    prediction = paddle.static.nn.fc(conv_pool_2, size=10, activation='softmax')
    loss = paddle.nn.functional.cross_entropy(
        input=prediction, label=label, reduction='none', use_softmax=False
    )
    avg_loss = paddle.mean(loss)
    return avg_loss


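# End-to-end test of the QuantizationTransformPass -> QuantizationFreezePass
# -> QuantInt8MkldnnPass pipeline on a small conv net trained on MNIST.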
class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
    def setUp(self):
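        # Map each quantizable op type to the input slots that the transform
        # pass quantizes.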
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y'],
        }

    def check_program(self, program):
        for block in program.blocks:
            for op in block.ops:
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.output_arg_names:
                        # Check that the quantizable op's output is linked to
                        # a fake_dequantize output
                        self.assertTrue(arg_name.endswith('.dequantized'))

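    # Element-wise check for whole numbers; used below to verify that the
    # MKL-DNN pass dequantized the weights back to non-integer floats.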
    def isinteger(self, x):
        return np.equal(np.mod(x, 1), 0)

    def build_program(self, main, startup, is_test, seed):
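        # Build a train or test program around conv_net; fixing the seeds
        # makes the two programs initialize parameters identically.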
        main.random_seed = seed
        startup.random_seed = seed
        with paddle.utils.unique_name.guard():
            with paddle.static.program_guard(main, startup):
                img = paddle.static.data(
                    name='image', shape=[-1, 1, 28, 28], dtype='float32'
                )
                label = paddle.static.data(
                    name='label', shape=[-1, 1], dtype='int64'
                )
                loss = conv_net(img, label)
                if not is_test:
                    opt = paddle.optimizer.Adam(learning_rate=0.001)
                    opt.minimize(loss)
        return [img, label], loss

    def mkldnn_based_freeze_graph(
        self,
        use_cuda,
        seed,
        activation_quant_type,
        weight_quant_type='abs_max',
        quant_perf=False,
        for_ci=False,
    ):
        random.seed(0)
        np.random.seed(0)

        main = paddle.static.Program()
        startup = paddle.static.Program()
        test_program = paddle.static.Program()
        feeds, loss = self.build_program(main, startup, False, seed)
        self.build_program(test_program, startup, True, seed)
        test_program = test_program.clone(for_test=True)
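        # Wrap the program descs in IrGraph so the IR passes below can
        # rewrite them.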
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        scope = paddle.static.global_scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup)
        # Apply the QuantizationTransformPass
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type,
        )
        transform_pass.apply(main_graph)
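        # Apply the same transformation to the inference graph so it also
        # contains the fake quantize/dequantize ops.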
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type,
        )
        transform_pass.apply(test_graph)

        build_strategy = paddle.static.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
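        # Compile the training graph; memory optimization and in-place reuse
        # are disabled above so that scope variables stay intact for the
        # weight checks below.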
        binary = paddle.static.CompiledProgram(
            main_graph.graph
        ).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

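        # MNIST readers: shuffled batches for training, sequential batches
        # for testing.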
        train_reader = paddle.batch(
            paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size,
        )
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size
        )
        feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)

        # Train the model to obtain the weight values
        with paddle.static.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(
                    binary, feed=feeder.feed(data), fetch_list=[loss]
                )

        # Freeze the graph for inference, but the fc/conv weights are still
        # stored in float type.
        freeze_pass = QuantizationFreezePass(
            scope=scope, place=place, weight_quantize_type=weight_quant_type
        )
        freeze_pass.apply(test_graph)

        # Transform the quantized graph for MKL-DNN INT8 inference
        mkldnn_int8_pass = QuantInt8MkldnnPass(_scope=scope, _place=place)
        mkldnn_int8_pass.apply(test_graph)
        dev_name = '_cpu_'
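        # Outside CI runs, dump the graph to a .dot file with all
        # quantize-related ops highlighted.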
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw(
                '.',
                'test_mkldnn'
                + dev_name
                + activation_quant_type
                + '_'
                + weight_quant_type,
                marked_nodes,
            )
        mkldnn_program = test_graph.to_program()

        # Check the transformed weights of conv2d and mul
        conv_w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        mul_w_mkldnn = np.array(scope.find_var('fc_0.w_0').get_tensor())
        # Check that the weights are no longer integer-valued after
        # dequantization
        self.assertFalse(self.isinteger(np.sum(conv_w_mkldnn)))
        self.assertFalse(self.isinteger(np.sum(mul_w_mkldnn)))

        # Check that the conv2d and mul outputs are correctly linked to
        # fake_dequantize outputs
        self.check_program(mkldnn_program)
        if not for_ci:
            print(
                '{}: {}'.format(
                    'w_mkldnn'
                    + dev_name
                    + activation_quant_type
                    + '_'
                    + weight_quant_type,
                    np.sum(mul_w_mkldnn),
                )
            )

    def test_mkldnn_graph_cpu_static(self):
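        # Run the full transform -> train -> freeze -> MKL-DNN pipeline for
        # both 'range_abs_max' and 'moving_average_abs_max' activation
        # quantization.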
        with paddle.utils.unique_name.guard():
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='abs_max',
                for_ci=True,
            )
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='abs_max',
                for_ci=True,
            )


if __name__ == '__main__':
    unittest.main()