#   copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
#     http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core


def linear_fc(num):
    """Build a stack of ReLU fully-connected layers ending in a mean loss.

    Layers are added to whatever program is currently active (callers wrap
    this in ``fluid.program_guard``).

    Args:
        num: how many 128-unit ``fc`` layers to stack on the input.

    Returns:
        The mean cross-entropy loss variable.
    """
    image = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    feature = image
    for _ in six.moves.xrange(num):
        feature = fluid.layers.fc(feature, size=128, act='relu')
    per_sample = fluid.layers.cross_entropy(input=feature, label=label)
    return fluid.layers.mean(per_sample)


def residual_block(num):
    """Build ``num`` residual conv blocks followed by a 10-way fc and a mean loss.

    Each block adds a 3x3 conv+batch_norm path to a 1x1 conv+batch_norm
    shortcut and applies ReLU to the sum. Layers are added to the currently
    active program.

    Args:
        num: number of residual blocks to stack.

    Returns:
        The mean cross-entropy loss variable.
    """

    def _conv_bn(inp,
                 n_filters,
                 ksize,
                 step,
                 pad,
                 act='relu',
                 bias_attr=False):
        # The conv itself has no activation; the activation (if any) is
        # applied by batch_norm.
        conv_out = fluid.layers.conv2d(
            input=inp,
            filter_size=ksize,
            num_filters=n_filters,
            stride=step,
            padding=pad,
            act=None,
            bias_attr=bias_attr)
        return fluid.layers.batch_norm(input=conv_out, act=act)

    image = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    feature = image
    for _ in six.moves.xrange(num):
        main_path = _conv_bn(feature, 16, 3, 1, 1, act=None, bias_attr=True)
        shortcut = _conv_bn(feature, 16, 1, 1, 0, act=None)
        feature = fluid.layers.elementwise_add(
            x=main_path, y=shortcut, act='relu')
    logits = fluid.layers.fc(input=feature, size=10)
    per_sample = fluid.layers.cross_entropy(input=logits, label=label)
    return fluid.layers.mean(per_sample)


def conv_net(img, label):
    """Two conv+pool stages, a softmax fc classifier and a mean loss.

    Args:
        img: image input variable.
        label: integer label variable.

    Returns:
        The mean cross-entropy loss variable.
    """
    # Settings shared by both conv+pool stages.
    stage_args = dict(filter_size=5, pool_size=2, pool_stride=2, act="relu")
    stage1 = fluid.nets.simple_img_conv_pool(
        input=img, num_filters=20, **stage_args)
    stage1 = fluid.layers.batch_norm(stage1)
    stage2 = fluid.nets.simple_img_conv_pool(
        input=stage1, num_filters=50, **stage_args)
    prediction = fluid.layers.fc(input=stage2, size=10, act='softmax')
    per_sample = fluid.layers.cross_entropy(input=prediction, label=label)
    return fluid.layers.mean(per_sample)


class TestQuantizationTransformPass(unittest.TestCase):
    """Tests for QuantizationTransformPass on an fc net and a residual net.

    Each test builds a trainable program, applies the transform pass to its
    IrGraph, and checks that the quantizable forward ops and their grad ops
    read ``.quantized.dequantized`` variables.
    """

    def setUp(self):
        # Forward op type -> input slot names expected to be quantized.
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y']
        }
        # Grad op type -> input slot names that must read the same
        # quantized/dequantized variables as the forward op.
        self.quantizable_grad_op_inputs = {
            'conv2d_grad': ['Input', 'Filter'],
            'depthwise_conv2d_grad': ['Input', 'Filter'],
            'mul_grad': ['X', 'Y']
        }

    @staticmethod
    def _draw_quantized(graph, name):
        """Render ``graph`` under the current directory as ``name``,
        highlighting every op whose name contains 'quantize'."""
        marked_nodes = set(
            op for op in graph.all_ops() if 'quantize' in op.name())
        graph.draw('.', name, marked_nodes)

    def check_program(self, transform_pass, program):
        """Assert that ``program`` was correctly rewritten by the pass.

        Every input of each quantizable forward op must end with
        '.quantized.dequantized'. For each matching grad op, the listed input
        slots must refer to one of those same forward variables.

        Args:
            transform_pass: the pass that produced ``program``; unused, kept
                for signature compatibility.
            program: the ``fluid.Program`` to inspect.
        """
        quantized_ops = set()
        for block in program.blocks:
            # check forward
            for op in block.ops:
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.input_arg_names:
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'),
                            arg_name)
                        quantized_ops.add(arg_name)

            # check backward
            for op in block.ops:
                if op.type in self.quantizable_grad_op_inputs:
                    for pname in self.quantizable_grad_op_inputs[op.type]:
                        arg_name = op.input(pname)[0]
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'),
                            arg_name)
                        self.assertIn(arg_name, quantized_ops)

    def _transform_and_check(self, build_net, tag, quant_type):
        """Build a trainable net, apply the transform pass and validate it.

        Args:
            build_net: zero-argument callable that adds the net to the active
                program and returns its loss.
            tag: name fragment for the dumped graphs ('quantize_<tag>_...',
                'val_<tag>_...').
            quant_type: activation quantize type passed to the pass.
        """
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = build_net()
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)

        exe = fluid.Executor(fluid.CPUPlace())
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            program_exe=exe,
            activation_quantize_type=quant_type)
        transform_pass.apply(graph)
        self._draw_quantized(graph, 'quantize_' + tag + '_' + quant_type)

        program = graph.to_program()
        self.check_program(transform_pass, program)

        # The test methods record which activation quantize op the pass should
        # have inserted; verify it is actually present.
        expected_act_op = getattr(self, 'act_quant_op_type', None)
        if expected_act_op is not None:
            op_types = set(
                op.type for block in program.blocks for op in block.ops)
            self.assertIn(expected_act_op, op_types)

        # Rebuild a graph from the emitted program and dump it, exercising the
        # graph -> program -> graph round trip.
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        self._draw_quantized(val_graph, 'val_' + tag + '_' + quant_type)

    def linear_fc_quant(self, quant_type):
        """Quantize a 3-layer fc net with activation quantize ``quant_type``."""
        self._transform_and_check(lambda: linear_fc(3), 'fc', quant_type)

    def test_linear_fc_quant_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_abs_max'
        self.linear_fc_quant('abs_max')

    def test_linear_fc_quant_range_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_range_abs_max'
        self.linear_fc_quant('range_abs_max')

    def residual_block_quant(self, quant_type):
        """Quantize a 2-block residual net with activation quantize ``quant_type``."""
        self._transform_and_check(lambda: residual_block(2), 'residual',
                                  quant_type)

    def test_residual_block_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_abs_max'
        self.residual_block_quant('abs_max')

    def test_residual_block_range_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_range_abs_max'
        self.residual_block_quant('range_abs_max')


class TestQuantizationFreezePass(unittest.TestCase):
    """End-to-end test of quantize-aware training followed by the freeze,
    int8-conversion and mobile-transform passes on a small conv net."""

    @staticmethod
    def _draw_quantized(graph, name):
        """Render ``graph`` under the current directory as ``name``,
        highlighting every op whose name contains 'quantize'."""
        marked_nodes = set(
            op for op in graph.all_ops() if 'quantize' in op.name())
        graph.draw('.', name, marked_nodes)

    def freeze_graph(self, use_cuda, seed, quant_type):
        """Run the full quantization pipeline and check each stage.

        Stages:
        1. Train the quantized program for a few iterations.
        2. Evaluate the quantized test program.
        3. Freeze it and compare the losses.
        4. Convert the weights to int8 and save/load the model.
        5. Transform for mobile and save the model.

        Graph dumps and inference models ('server_int8...', 'mobile_int8...')
        are written under the current directory.

        Args:
            use_cuda: run on ``CUDAPlace(0)`` when true, else on the CPU.
            seed: random seed assigned to the main/test/startup programs.
            quant_type: activation quantize type for the transform pass.
        """

        def build_program(main, startup, is_test):
            main.random_seed = seed
            startup.random_seed = seed
            # A fresh unique_name scope makes both programs get identical
            # variable names, so vars from one can be fetched in the other.
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    img = fluid.layers.data(
                        name='image', shape=[1, 28, 28], dtype='float32')
                    label = fluid.layers.data(
                        name='label', shape=[1], dtype='int64')
                    loss = conv_net(img, label)
                    if not is_test:
                        opt = fluid.optimizer.Adam(learning_rate=0.001)
                        opt.minimize(loss)
            return [img, label], loss

        # Seed Python/NumPy RNGs (presumably used by the reader shuffle).
        random.seed(0)
        np.random.seed(0)

        main = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = build_program(main, startup, False)
        build_program(test_program, startup, True)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
        transform_pass = QuantizationTransformPass(
            scope=scope, program_exe=exe, activation_quantize_type=quant_type)
        transform_pass.apply(main_graph)
        transform_pass.apply(test_graph)
        dev_name = '_gpu_' if use_cuda else '_cpu_'
        self._draw_quantized(main_graph, 'main' + dev_name + quant_type)
        self._draw_quantized(test_graph, 'test' + dev_name + quant_type)

        quantized_main_program = main_graph.to_program()
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        # Training: create the reader once so that each iteration consumes the
        # next batch. Calling train_reader() per iteration would rebuild
        # (re-read and reshuffle) the buffer every time.
        train_batches = train_reader()
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_batches)
                exe.run(program=quantized_main_program,
                        feed=feeder.feed(data),
                        fetch_list=[loss])

        # Testing with the quantized (fake quant/dequant) program.
        test_data = next(test_reader())
        w_var = quantized_test_program.global_block().var(
            'conv2d_1.w_0.quantized')
        with fluid.scope_guard(scope):
            test_loss1, w_quant = exe.run(program=quantized_test_program,
                                          feed=feeder.feed(test_data),
                                          fetch_list=[loss, w_var])

        # Freeze graph for inference, but the weight of fc/conv is still
        # float type.
        freeze_pass = QuantizationFreezePass(scope=scope, place=place)
        freeze_pass.apply(test_graph)
        self._draw_quantized(test_graph, 'test_freeze' + dev_name + quant_type)

        server_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            test_loss2, = exe.run(program=server_program,
                                  feed=feeder.feed(test_data),
                                  fetch_list=[loss])
        # The frozen program must reproduce the quantized program's loss.
        self.assertAlmostEqual(
            np.asarray(test_loss1).item(),
            np.asarray(test_loss2).item(),
            delta=5e-3)

        w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        self.assertEqual(w_quant.shape, w_freeze.shape)
        # NOTE: np.sum(w_freeze) vs np.sum(w_quant) is deliberately not
        # asserted; it may fail because of calculation precision.

        # Convert parameter to 8-bit.
        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
        convert_int8_pass.apply(test_graph)
        self._draw_quantized(test_graph, 'test_int8' + dev_name + quant_type)
        server_program_int8 = test_graph.to_program()
        int8_model_dir = 'server_int8' + dev_name + quant_type
        # Save the 8-bit parameter and model file.
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model(int8_model_dir, ['image', 'label'],
                                          [loss], exe, server_program_int8)
            # Test whether the 8-bit parameter and model file can be loaded
            # successfully.
            fluid.io.load_inference_model(int8_model_dir, exe)
        # Check the loaded 8-bit weight.
        w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
        self.assertEqual(w_8bit.dtype, np.int8)
        self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))

        mobile_pass = TransformForMobilePass()
        mobile_pass.apply(test_graph)
        self._draw_quantized(test_graph, 'test_mobile' + dev_name + quant_type)

        mobile_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type,
                                          ['image', 'label'], [loss], exe,
                                          mobile_program)

    @unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                     'Paddle is not compiled with CUDA')
    def test_freeze_graph_cuda_dynamic(self):
        with fluid.unique_name.guard():
            self.freeze_graph(True, seed=1, quant_type='abs_max')

    def test_freeze_graph_cpu_dynamic(self):
        with fluid.unique_name.guard():
            self.freeze_graph(False, seed=2, quant_type='abs_max')

    @unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
                     'Paddle is not compiled with CUDA')
    def test_freeze_graph_cuda_static(self):
        with fluid.unique_name.guard():
            self.freeze_graph(True, seed=1, quant_type='range_abs_max')

    def test_freeze_graph_cpu_static(self):
        with fluid.unique_name.guard():
            self.freeze_graph(False, seed=2, quant_type='range_abs_max')


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()