test_quantization_pass.py 36.5 KB
Newer Older
W
WangZhen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
Zhen Wang 已提交
15
import os
W
WangZhen 已提交
16
import random
17 18
import unittest

W
WangZhen 已提交
19
import numpy as np
20

W
WangZhen 已提交
21
import paddle
22
import paddle.fluid as fluid
23
from paddle.fluid.framework import IrGraph
24 25 26 27 28 29 30 31 32
from paddle.framework import core
from paddle.static.quantization import (
    AddQuantDequantPass,
    ConvertToInt8Pass,
    QuantizationFreezePass,
    QuantizationTransformPass,
    QuantizationTransformPassV2,
    TransformForMobilePass,
)
W
WangZhen 已提交
33

P
pangyoki 已提交
34 35
# Switch Paddle to static-graph mode at import time; every network in this
# module is built through the paddle.static API.
paddle.enable_static()

Z
Zhen Wang 已提交
36 37 38
# Process-wide environment set at import time: CUDA_VISIBLE_DEVICES is "0" and
# CPU_NUM is "1" (presumably pins the tests to one GPU / one CPU worker --
# confirm against Paddle's handling of these variables).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"

W
WangZhen 已提交
39 40

def linear_fc(num):
    """Build a stack of ``num`` fully connected layers and return its loss.

    The network reads an 'image' input declared as [-1, 1, 32, 32] float32 and
    a 'label' input declared as [-1, 1] int64. Each layer is a 128-unit fc
    with ReLU activation. The per-sample cross entropy (computed without an
    internal softmax) is averaged into the returned loss variable.

    Args:
        num: number of fc layers to stack.

    Returns:
        The mean cross-entropy loss variable.
    """
    # Inputs are declared in the same order as before: image first, then label.
    image = paddle.static.data(
        name='image', shape=[-1, 1, 32, 32], dtype='float32'
    )
    target = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')

    features = image
    for _layer in range(num):
        features = paddle.static.nn.fc(features, size=128, activation='relu')

    per_sample_loss = paddle.nn.functional.cross_entropy(
        input=features, label=target, reduction='none', use_softmax=False
    )
    return paddle.mean(per_sample_loss)


55
def residual_block(num, quant_skip_pattern=None):
    """Build ``num`` residual units followed by matmul, pooling and an fc head.

    Args:
        num: number of residual units (two conv+batch_norm branches that are
            added and passed through ReLU).
        quant_skip_pattern: when truthy, the average pooling op is created
            inside ``paddle.static.name_scope(quant_skip_pattern)``; otherwise
            it is created without an extra name scope.

    Returns:
        The mean cross-entropy loss variable.
    """

    def conv_then_bn(x, channels, kernel, stride, pad, act='relu', bias=False):
        # conv2d without activation, followed by batch_norm carrying `act`.
        conv_out = paddle.static.nn.conv2d(
            input=x,
            filter_size=kernel,
            num_filters=channels,
            stride=stride,
            padding=pad,
            act=None,
            bias_attr=bias,
        )
        return paddle.static.nn.batch_norm(input=conv_out, act=act)

    image = paddle.static.data(
        name='image',
        shape=[1, 1, 32, 32],
        dtype='float32',
    )
    target = paddle.static.data(name='label', shape=[1, 1], dtype='int64')

    features = image
    for _unit in range(num):
        # Main branch: 3x3 conv with bias; shortcut branch: 1x1 conv.
        main_branch = conv_then_bn(features, 16, 3, 1, 1, act=None, bias=True)
        shortcut = conv_then_bn(features, 16, 1, 1, 0, act=None)
        summed = paddle.add(x=main_branch, y=shortcut)
        features = paddle.nn.functional.relu(summed)

    weight = paddle.static.create_parameter(
        shape=[1, 16, 32, 32], dtype='float32'
    )
    # Both operands are transposed (third and fourth positional arguments).
    features = paddle.matmul(features, weight, True, True)

    if not quant_skip_pattern:
        pooled = paddle.nn.functional.avg_pool2d(
            features, kernel_size=2, stride=2
        )
    else:
        with paddle.static.name_scope(quant_skip_pattern):
            pooled = paddle.nn.functional.avg_pool2d(
                features, kernel_size=2, stride=2
            )

    logits = paddle.static.nn.fc(pooled, size=10)
    per_sample_loss = paddle.nn.functional.cross_entropy(
        input=logits, label=target, reduction='none', use_softmax=False
    )
    return paddle.mean(per_sample_loss)


101
def conv_net(img, label, quant_skip_pattern):
    """Build a two-stage conv/pool classifier and return its mean loss.

    Stage one is conv2d (5x5, 20 filters, ReLU) -> max pool -> batch_norm;
    stage two is conv2d (5x5, 20 filters, ReLU) -> average pool. A 100-unit
    ReLU fc layer follows, and the final 10-way softmax fc layer is created
    inside ``paddle.static.name_scope(quant_skip_pattern)``.

    Args:
        img: input image variable.
        label: label variable consumed by the cross-entropy loss.
        quant_skip_pattern: name-scope string wrapped around the last fc layer.

    Returns:
        The averaged cross-entropy loss variable.
    """
    stage = paddle.static.nn.conv2d(
        input=img,
        filter_size=5,
        num_filters=20,
        act='relu',
    )
    stage = paddle.nn.functional.max_pool2d(stage, kernel_size=2, stride=2)
    stage = paddle.static.nn.batch_norm(stage)

    stage = paddle.static.nn.conv2d(
        input=stage,
        filter_size=5,
        num_filters=20,
        act='relu',
    )
    stage = paddle.nn.functional.avg_pool2d(stage, kernel_size=2, stride=2)

    hidden = paddle.static.nn.fc(stage, size=100, activation='relu')
    with paddle.static.name_scope(quant_skip_pattern):
        prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')

    # The prediction is already a softmax output, hence use_softmax=False.
    per_sample_loss = paddle.nn.functional.cross_entropy(
        input=prediction, label=label, reduction='none', use_softmax=False
    )
    return paddle.mean(per_sample_loss)


132
class TestQuantizationTransformPass(unittest.TestCase):
    """Applies QuantizationTransformPass to training graphs and checks that
    the inputs of the inspected ops carry the '.quantized.dequantized' suffix.
    """

    # Op types handed to the transform pass by the residual-block cases.
    _RESIDUAL_QUANT_OPS = ('conv2d', 'depthwise_conv2d', 'mul', 'matmul')

    def setUp(self):
        """Declare the forward ops and grad-op input slots that are checked."""
        slots_by_op = (
            ('conv2d', ('Input', 'Filter')),
            ('depthwise_conv2d', ('Input', 'Filter')),
            ('mul', ('X', 'Y')),
        )
        self.quantizable_op_and_inputs = {
            op_type: list(slots) for op_type, slots in slots_by_op
        }
        self.quantizable_grad_op_inputs = {
            op_type + '_grad': list(slots) for op_type, slots in slots_by_op
        }

    def check_program(self, program):
        """Assert that inspected ops in ``program`` read quantized inputs.

        Every input argument of a forward op listed in
        ``quantizable_op_and_inputs`` must end with '.quantized.dequantized'.
        For each grad op listed in ``quantizable_grad_op_inputs``, the first
        argument of every listed slot must carry the same suffix and must
        already have been seen as a forward-op input.
        """
        suffix = '.quantized.dequantized'
        seen_forward_inputs = set()
        for block in program.blocks:
            # Forward ops first, so the set is populated for the grad check.
            for op in block.ops:
                if op.type not in self.quantizable_op_and_inputs:
                    continue
                for name in op.input_arg_names:
                    self.assertTrue(name.endswith(suffix))
                    seen_forward_inputs.add(name)

            for op in block.ops:
                slots = self.quantizable_grad_op_inputs.get(op.type)
                if slots is None:
                    continue
                for slot in slots:
                    name = op.input(slot)[0]
                    self.assertTrue(name.endswith(suffix))
                    self.assertTrue(name in seen_forward_inputs)

    @staticmethod
    def _quantize_nodes(graph):
        """Return the set of op nodes whose name contains 'quantize'."""
        return {
            node for node in graph.all_op_nodes() if 'quantize' in node.name()
        }

    @staticmethod
    def _build_training_program(net_builder, depth):
        """Build ``net_builder(depth)`` plus an Adam step in a fresh program."""
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            loss = net_builder(depth)
            optimizer = paddle.optimizer.Adam(learning_rate=0.001)
            optimizer.minimize(loss)
        return main

    def _transform_and_check(
        self, main, activation_quant_type, for_ci, net_tag, **pass_kwargs
    ):
        """Run the transform pass on ``main`` and verify the resulting program.

        ``net_tag`` ('fc_' or 'residual_') is only used to name the graph
        drawings that are produced when ``for_ci`` is False.
        """
        place = paddle.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform = QuantizationTransformPass(
            scope=paddle.static.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            **pass_kwargs,
        )
        transform.apply(graph)
        if not for_ci:
            highlighted = self._quantize_nodes(graph)
            graph.draw(
                '.',
                'quantize_' + net_tag + activation_quant_type,
                highlighted,
            )

        program = graph.to_program()
        self.check_program(program)

        # Rebuild a graph from the converted program (drawn only off-CI).
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_highlighted = self._quantize_nodes(val_graph)
            val_graph.draw(
                '.',
                'val_' + net_tag + activation_quant_type,
                val_highlighted,
            )

    def linear_fc_quant(
        self, activation_quant_type, weight_quantize_type, for_ci=True
    ):
        """Quantize a three-layer fc training program and check it."""
        main = self._build_training_program(linear_fc, 3)
        self._transform_and_check(
            main,
            activation_quant_type,
            for_ci,
            'fc_',
            weight_quantize_type=weight_quantize_type,
        )

    # activation 'abs_max' with weight 'abs_max'
    def test_linear_fc_quant_abs_max(self):
        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)

    # activation 'range_abs_max' with weight 'abs_max'
    def test_linear_fc_quant_range_abs_max(self):
        self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)

    # activation 'moving_average_abs_max' with weight 'channel_wise_abs_max'
    def test_linear_fc_quant_moving_average_abs_max(self):
        self.linear_fc_quant(
            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True
        )

    def residual_block_quant(
        self,
        activation_quant_type,
        weight_quantize_type,
        quantizable_op_type,
        for_ci=True,
    ):
        """Quantize a two-unit residual training program and check it."""
        main = self._build_training_program(residual_block, 2)
        self._transform_and_check(
            main,
            activation_quant_type,
            for_ci,
            'residual_',
            weight_quantize_type=weight_quantize_type,
            quantizable_op_type=quantizable_op_type,
        )

    # activation 'abs_max' with weight 'abs_max'
    def test_residual_block_abs_max(self):
        self.residual_block_quant(
            'abs_max', 'abs_max', list(self._RESIDUAL_QUANT_OPS), for_ci=True
        )

    # activation 'range_abs_max' with weight 'abs_max'
    def test_residual_block_range_abs_max(self):
        self.residual_block_quant(
            'range_abs_max',
            'abs_max',
            list(self._RESIDUAL_QUANT_OPS),
            for_ci=True,
        )

    # activation 'moving_average_abs_max' with weight 'channel_wise_abs_max'
    def test_residual_block_moving_average_abs_max(self):
        self.residual_block_quant(
            'moving_average_abs_max',
            'channel_wise_abs_max',
            list(self._RESIDUAL_QUANT_OPS),
            for_ci=True,
        )
279

W
WangZhen 已提交
280

W
WangZhen 已提交
281
class TestQuantizationFreezePass(unittest.TestCase):
    """End-to-end check of the transform -> freeze -> int8 -> mobile passes
    on the small MNIST conv net built by ``conv_net``.
    """

    @staticmethod
    def _quantize_nodes(graph):
        """Return the set of op nodes whose name contains 'quantize'."""
        return {
            node for node in graph.all_op_nodes() if 'quantize' in node.name()
        }

    def freeze_graph(
        self,
        use_cuda,
        seed,
        activation_quant_type,
        bias_correction=False,
        weight_quant_type='abs_max',
        for_ci=True,
        quant_skip_pattern='skip_quant',
    ):
        """Train a quantized conv net briefly, then freeze and export it.

        Steps: build train/test programs, apply QuantizationTransformPass to
        both, train for five mini-batches, evaluate the quantized test
        program, apply QuantizationFreezePass and require the loss to stay
        within 5e-3, apply ConvertToInt8Pass and check the int8 weights, save
        and reload the int8 model, and finally apply TransformForMobilePass
        and save the resulting model.

        Args:
            use_cuda: run on ``CUDAPlace(0)`` when true, else on CPU.
            seed: random seed assigned to every program that is built.
            activation_quant_type: activation quantization type for the
                transform pass.
            bias_correction: forwarded to QuantizationFreezePass.
            weight_quant_type: weight quantization type for the transform and
                freeze passes.
            for_ci: when False, additionally draw the graphs into '.' and
                print intermediate losses and weight sums.
            quant_skip_pattern: name scope around the last fc layer, also
                given to the transform pass as ``skip_pattern``.
        """

        def build_net(program, startup_program, with_optimizer):
            # Same network for train and test; only the train program gets
            # the Adam update ops.
            program.random_seed = seed
            startup_program.random_seed = seed
            with paddle.utils.unique_name.guard():
                with paddle.static.program_guard(program, startup_program):
                    image = paddle.static.data(
                        name='image', shape=[-1, 1, 28, 28], dtype='float32'
                    )
                    target = paddle.static.data(
                        name='label', shape=[-1, 1], dtype='int64'
                    )
                    net_loss = conv_net(image, target, quant_skip_pattern)
                    if with_optimizer:
                        optimizer = paddle.optimizer.Adam(learning_rate=0.001)
                        optimizer.minimize(net_loss)
            return [image, target], net_loss

        random.seed(0)
        np.random.seed(0)

        main = paddle.static.Program()
        startup = paddle.static.Program()
        test_program = paddle.static.Program()
        feeds, loss = build_net(main, startup, True)
        build_net(test_program, startup, False)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        scope = paddle.static.global_scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup)

        # A fresh transform pass per graph: training graph first, then test.
        for graph in (main_graph, test_graph):
            transform_pass = QuantizationTransformPass(
                scope=scope,
                place=place,
                activation_quantize_type=activation_quant_type,
                weight_quantize_type=weight_quant_type,
                skip_pattern=quant_skip_pattern,
            )
            transform_pass.apply(graph)

        # Suffix shared by every drawing name, log label and model directory.
        dev_name = '_gpu_' if use_cuda else '_cpu_'
        tag = dev_name + activation_quant_type + '_' + weight_quant_type

        def draw(graph, prefix):
            # Debug-only visualisation with the quantize ops highlighted.
            if for_ci:
                return
            graph.draw('.', prefix + tag, self._quantize_nodes(graph))

        def report(label, value):
            print('{}: {}'.format(label + tag, value))

        def save_model(program, prefix):
            # Export `program` with 'image' and 'label' as its feed variables.
            block = program.global_block()
            feed_vars = [block.var(name) for name in ['image', 'label']]
            paddle.static.save_inference_model(
                prefix + tag + '/model',
                feed_vars,
                [loss],
                exe,
                program=program,
            )

        draw(main_graph, 'main')
        draw(test_graph, 'test')

        build_strategy = paddle.static.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        build_strategy.fuse_all_reduce_ops = False
        binary = paddle.static.CompiledProgram(
            main_graph.graph
        ).with_data_parallel(loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()

        num_iters = 5
        samples_per_batch = 8
        train_batches = paddle.batch(
            paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
            batch_size=samples_per_batch,
        )
        test_batches = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=samples_per_batch
        )
        feeder = paddle.fluid.DataFeeder(feed_list=feeds, place=place)

        # Training: each step restarts the reader and takes its first batch.
        with paddle.static.scope_guard(scope):
            for _step in range(num_iters):
                batch = next(train_batches())
                loss_v = exe.run(
                    binary, feed=feeder.feed(batch), fetch_list=[loss]
                )
                if not for_ci:
                    report('loss', loss_v)

        test_data = next(test_batches())
        with paddle.static.program_guard(quantized_test_program):
            w_var = fluid.framework._get_var(
                'conv2d_1.w_0.quantized', quantized_test_program
            )

        # Reference evaluation on the quantized (not yet frozen) program.
        with paddle.static.scope_guard(scope):
            test_loss1, w_quant = exe.run(
                program=quantized_test_program,
                feed=feeder.feed(test_data),
                fetch_list=[loss, w_var],
            )

        # Freeze graph for inference, but the weight of fc/conv is still
        # float type.
        freeze_pass = QuantizationFreezePass(
            scope=scope,
            place=place,
            bias_correction=bias_correction,
            weight_quantize_type=weight_quant_type,
        )
        freeze_pass.apply(test_graph)
        draw(test_graph, 'test_freeze')

        server_program = test_graph.to_program()
        with paddle.static.scope_guard(scope):
            (test_loss2,) = exe.run(
                program=server_program,
                feed=feeder.feed(test_data),
                fetch_list=[loss],
            )
        self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
        if not for_ci:
            report('test_loss1', test_loss1)
            report('test_loss2', test_loss2)

        w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        # The frozen weight sum is deliberately not asserted against the sum
        # of w_quant: that comparison may fail due to calculation precision.
        if not for_ci:
            report('w_freeze', np.sum(w_freeze))
            report('w_quant', np.sum(w_quant))

        # Convert parameter to 8-bit.
        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
        convert_int8_pass.apply(test_graph)
        draw(test_graph, 'test_int8')

        server_program_int8 = test_graph.to_program()
        with paddle.static.scope_guard(scope):
            # Save the 8-bit parameter and model file.
            save_model(server_program_int8, 'server_int8')
            # Test whether the 8-bit parameter and model file can be loaded
            # successfully.
            _infer, _feed, _fetch = paddle.static.load_inference_model(
                'server_int8' + tag + '/model',
                exe,
            )

        # Check the loaded 8-bit weight.
        w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
        self.assertEqual(w_8bit.dtype, np.int8)
        self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
        if not for_ci:
            report('w_8bit', np.sum(w_8bit))
            report('w_freeze', np.sum(w_freeze))

        mobile_pass = TransformForMobilePass()
        mobile_pass.apply(test_graph)
        draw(test_graph, 'test_mobile')

        mobile_program = test_graph.to_program()
        with paddle.static.scope_guard(scope):
            save_model(mobile_program, 'mobile_int8')

    # 'abs_max' activations on GPU; a no-op when Paddle is built without CUDA.
    def test_freeze_graph_cuda_dynamic(self):
        if not core.is_compiled_with_cuda():
            return
        for weight_type in ('abs_max', 'channel_wise_abs_max'):
            with paddle.utils.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='abs_max',
                    weight_quant_type=weight_type,
                    for_ci=True,
                )

    # 'abs_max' activations on CPU, for both weight quantization types.
    def test_freeze_graph_cpu_dynamic(self):
        with paddle.utils.unique_name.guard():
            for weight_type in ('abs_max', 'channel_wise_abs_max'):
                self.freeze_graph(
                    False,
                    seed=2,
                    activation_quant_type='abs_max',
                    weight_quant_type=weight_type,
                    for_ci=True,
                )

    # Static activation quantization on GPU; a no-op without CUDA.
    def test_freeze_graph_cuda_static(self):
        if not core.is_compiled_with_cuda():
            return
        # (activation type, weight type, bias_correction), in execution order.
        cases = (
            ('range_abs_max', 'abs_max', True),
            ('range_abs_max', 'abs_max', False),
            ('moving_average_abs_max', 'abs_max', False),
            ('range_abs_max', 'channel_wise_abs_max', False),
            ('moving_average_abs_max', 'channel_wise_abs_max', False),
            ('moving_average_abs_max', 'channel_wise_abs_max', True),
        )
        with paddle.utils.unique_name.guard():
            for activation_type, weight_type, correct_bias in cases:
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type=activation_type,
                    bias_correction=correct_bias,
                    weight_quant_type=weight_type,
                    for_ci=True,
                )

    # Static activation quantization on CPU.
    def test_freeze_graph_cpu_static(self):
        with paddle.utils.unique_name.guard():
            for weight_type in ('abs_max', 'channel_wise_abs_max'):
                for activation_type in (
                    'range_abs_max',
                    'moving_average_abs_max',
                ):
                    self.freeze_graph(
                        False,
                        seed=2,
                        activation_quant_type=activation_type,
                        weight_quant_type=weight_type,
                        for_ci=True,
                    )
W
WangZhen 已提交
721 722


723
def quant_dequant_residual_block(num, quant_skip_pattern=None):
    """Build a residual net whose tail is avg-pool + max-pool -> add -> ReLU.

    Args:
        num: number of residual units before the matmul.
        quant_skip_pattern: controls the name scopes around the tail ops.
            A ``str`` wraps both poolings, the add and the ReLU in that one
            scope. A ``list`` (more than one entry required) puts the two
            poolings in the scope named by entry 0 and the add/ReLU in the
            scope named by entry 1. Any other value adds no name scope.

    Returns:
        The mean cross-entropy loss variable.
    """

    def conv_then_bn(x, channels, kernel, stride, pad, bias=False):
        # Every call in this network disables activation on both ops.
        conv_out = paddle.static.nn.conv2d(
            input=x,
            filter_size=kernel,
            num_filters=channels,
            stride=stride,
            padding=pad,
            act=None,
            bias_attr=bias,
        )
        return paddle.static.nn.batch_norm(input=conv_out, act=None)

    def pool_pair(x):
        # Average pooling is created before max pooling.
        averaged = paddle.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        maxed = paddle.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        return averaged, maxed

    def add_then_relu(left, right):
        return paddle.nn.functional.relu(paddle.add(left, right))

    image = paddle.static.data(
        name='image', shape=[-1, 1, 32, 32], dtype='float32'
    )
    matmul_operand = paddle.static.data(
        name='matmul_input', shape=[-1, 16, 32, 32], dtype='float32'
    )
    target = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')

    features = image
    for _unit in range(num):
        main_branch = conv_then_bn(features, 16, 3, 1, 1, bias=True)
        shortcut = conv_then_bn(features, 16, 1, 1, 0)
        summed = paddle.add(x=main_branch, y=shortcut)
        features = paddle.nn.functional.relu(summed)
    # Both operands are transposed (third and fourth positional arguments).
    features = paddle.matmul(features, matmul_operand, True, True)

    if isinstance(quant_skip_pattern, str):
        with paddle.static.name_scope(quant_skip_pattern):
            averaged, maxed = pool_pair(features)
            pool_add = add_then_relu(averaged, maxed)
    elif isinstance(quant_skip_pattern, list):
        assert (
            len(quant_skip_pattern) > 1
        ), 'test config error: the len of quant_skip_pattern list should be greater than 1.'
        with paddle.static.name_scope(quant_skip_pattern[0]):
            averaged, maxed = pool_pair(features)
        with paddle.static.name_scope(quant_skip_pattern[1]):
            pool_add = add_then_relu(averaged, maxed)
    else:
        averaged, maxed = pool_pair(features)
        pool_add = add_then_relu(averaged, maxed)

    logits = paddle.static.nn.fc(pool_add, size=10)
    per_sample_loss = paddle.nn.functional.cross_entropy(
        input=logits, label=target, reduction='none', use_softmax=False
    )
    return paddle.mean(per_sample_loss)


789 790 791 792 793
class TestAddQuantDequantPass(unittest.TestCase):
    """Apply ``AddQuantDequantPass`` to a residual network and check inputs.

    After the pass has run, every input of a target op
    (``elementwise_add`` / ``pool2d``) must carry the ``.quant_dequant``
    suffix, unless the op was skipped through ``skip_pattern`` or one of its
    inputs is persistable.
    """

    def setUp(self):
        self._target_ops = {'elementwise_add', 'pool2d'}
        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}

    @staticmethod
    def _is_user_skipped(op_node, skip_pattern):
        """Return True if the op's name scope contains a skip pattern.

        Args:
            op_node: graph op node exposing ``op().has_attr`` / ``op().attr``.
            skip_pattern: a single pattern (``str``), several patterns
                (``list``), or anything else (e.g. ``None``) for "no skip".
        """
        if isinstance(skip_pattern, str):
            patterns = [skip_pattern]
        elif isinstance(skip_pattern, list):
            patterns = skip_pattern
        else:
            return False

        op = op_node.op()
        if not op.has_attr("op_namescope"):
            return False
        namescope = op.attr("op_namescope")
        return any(pattern in namescope for pattern in patterns)

    @staticmethod
    def _draw_marked(graph, graph_name, keyword):
        """Draw ``graph`` into the current directory, marking matching ops.

        An op node is marked when ``keyword`` occurs in its name.
        """
        marked_nodes = {
            op for op in graph.all_op_nodes() if keyword in op.name()
        }
        graph.draw('.', graph_name, marked_nodes)

    def check_graph(self, graph, skip_pattern=None):
        """Assert that the inputs of every checked target op are rewritten.

        Args:
            graph: the graph the pass has been applied to.
            skip_pattern: the same ``str`` / ``list`` / ``None`` value that
                was handed to the pass; ops whose name scope matches it are
                not checked.
        """
        for op_node in graph.all_op_nodes():
            if op_node.name() not in self._target_ops:
                continue
            if self._is_user_skipped(op_node, skip_pattern):
                continue

            input_names = op_node.input_arg_names()
            # Ops that consume at least one persistable input are exempt
            # from the suffix check.
            has_persistable_input = any(
                graph._find_node_by_name(op_node.inputs, name).persistable()
                for name in input_names
            )
            if has_persistable_input:
                continue

            for input_name in input_names:
                self.assertTrue(
                    input_name.endswith('.quant_dequant'),
                    msg="input '{}' of op '{}' lacks the '.quant_dequant' "
                    "suffix".format(input_name, op_node.name()),
                )

    def residual_block_quant(
        self, quantizable_op_type, skip_pattern=None, for_ci=True
    ):
        """Build the residual network, run the pass and verify the graph.

        Args:
            quantizable_op_type: op types forwarded to the pass.
            skip_pattern: name-scope pattern(s) forwarded both to the network
                builder and to the pass.
            for_ci: when False, additionally draw the graphs into the
                current directory.
        """
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            loss = quant_dequant_residual_block(2, skip_pattern)
            opt = paddle.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)

        place = paddle.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        add_quant_dequant_pass = AddQuantDequantPass(
            scope=paddle.static.global_scope(),
            place=place,
            skip_pattern=skip_pattern,
            quantizable_op_type=quantizable_op_type,
        )
        add_quant_dequant_pass.apply(graph)
        if not for_ci:
            self._draw_marked(graph, 'add_quant_dequant_graph', 'quant')
        self.check_graph(graph, skip_pattern)

        # Convert back to a program and rebuild a graph from it; this is
        # done regardless of ``for_ci`` and only drawn when requested.
        program = graph.to_program()
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            self._draw_marked(
                val_graph, 'val_add_quant_dequant_graph', 'quant'
            )

    def test_residual_block(self):
        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
        self.residual_block_quant(
            quantizable_op_type, skip_pattern=None, for_ci=True
        )

    def test_residual_block_skip_pattern(self):
        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
        self.residual_block_quant(
            quantizable_op_type, skip_pattern='skip_quant', for_ci=True
        )

    def test_residual_block_skip_pattern_1(self):
        quantizable_op_type = ['elementwise_add', 'pool2d', 'mul', 'matmul']
        self.residual_block_quant(
            quantizable_op_type,
            skip_pattern=['skip_quant1', 'skip_quant2'],
            for_ci=True,
        )
884

885

886 887 888 889 890
class TestQuantizationTransformPassV2(unittest.TestCase):
    """Check the variable naming produced by the quantization transform pass.

    After the pass has run, every input of a quantizable forward op must end
    with ``.quantized.dequantized``, and the listed inputs of the matching
    grad ops must reuse exactly those forward inputs.
    """

    def setUp(self):
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y'],
        }
        self.quantizable_grad_op_inputs = {
            'conv2d_grad': ['Input', 'Filter'],
            'depthwise_conv2d_grad': ['Input', 'Filter'],
            'mul_grad': ['X', 'Y'],
        }

    @staticmethod
    def _draw_marked(graph, graph_name, keyword):
        """Draw ``graph`` into the current directory, marking matching ops.

        An op node is marked when ``keyword`` occurs in its name.
        """
        marked_nodes = {
            op for op in graph.all_op_nodes() if keyword in op.name()
        }
        graph.draw('.', graph_name, marked_nodes)

    def check_program(self, program):
        """Assert the quantized naming convention on ``program``.

        Args:
            program: the program obtained from the transformed graph.
        """
        suffix = '.quantized.dequantized'
        # Forward inputs seen so far; shared across all blocks.
        quantized_args = set()
        for block in program.blocks:
            # Forward pass: every input of a quantizable op is rewritten.
            for op in block.ops:
                if op.type not in self.quantizable_op_and_inputs:
                    continue
                for arg_name in op.input_arg_names:
                    self.assertTrue(
                        arg_name.endswith(suffix),
                        msg="input '{}' of op '{}' lacks the '{}' "
                        "suffix".format(arg_name, op.type, suffix),
                    )
                    quantized_args.add(arg_name)

            # Backward pass: grad ops must consume the same rewritten inputs.
            for op in block.ops:
                if op.type not in self.quantizable_grad_op_inputs:
                    continue
                for pname in self.quantizable_grad_op_inputs[op.type]:
                    arg_name = op.input(pname)[0]
                    self.assertTrue(
                        arg_name.endswith(suffix),
                        msg="input '{}' of op '{}' lacks the '{}' "
                        "suffix".format(arg_name, op.type, suffix),
                    )
                    self.assertIn(arg_name, quantized_args)

    def linear_fc_quant(
        self, activation_quant_type, weight_quantize_type, for_ci=True
    ):
        """Transform a 3-layer fc network with the V2 pass and check it.

        Args:
            activation_quant_type: activation quantize type for the pass.
            weight_quantize_type: weight quantize type for the pass.
            for_ci: when False, additionally draw the graphs into the
                current directory.
        """
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            loss = linear_fc(3)
            opt = paddle.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)

        place = paddle.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform_pass = QuantizationTransformPassV2(
            scope=paddle.static.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type,
        )
        transform_pass.apply(graph)
        if not for_ci:
            self._draw_marked(
                graph, 'quantize_fc_' + activation_quant_type, 'quantize'
            )

        program = graph.to_program()
        self.check_program(program)
        # The graph is rebuilt from the program regardless of ``for_ci``;
        # it is only drawn when requested.
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            self._draw_marked(
                val_graph, 'val_fc_' + activation_quant_type, 'quantize'
            )

    def test_linear_fc_quant_abs_max(self):
        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)

    def test_linear_fc_quant_channel_wise_abs_max(self):
        self.linear_fc_quant('abs_max', 'channel_wise_abs_max', for_ci=True)

    def residual_block_quant(
        self,
        activation_quant_type,
        weight_quantize_type,
        quantizable_op_type,
        for_ci=True,
    ):
        """Transform a residual network and check the resulting program.

        Args:
            activation_quant_type: activation quantize type for the pass.
            weight_quantize_type: weight quantize type for the pass.
            quantizable_op_type: op types forwarded to the pass.
            for_ci: when False, additionally draw the graphs into the
                current directory.
        """
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            loss = residual_block(2)
            opt = paddle.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)

        place = paddle.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        # NOTE(review): this V2 test class builds QuantizationTransformPass
        # here, while linear_fc_quant uses QuantizationTransformPassV2.
        # Left unchanged on purpose; confirm whether V2 was intended.
        transform_pass = QuantizationTransformPass(
            scope=paddle.static.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type,
            quantizable_op_type=quantizable_op_type,
        )
        transform_pass.apply(graph)
        if not for_ci:
            self._draw_marked(
                graph,
                'quantize_residual_' + activation_quant_type,
                'quantize',
            )

        program = graph.to_program()
        self.check_program(program)
        # The graph is rebuilt from the program regardless of ``for_ci``;
        # it is only drawn when requested.
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            self._draw_marked(
                val_graph,
                'val_residual_' + activation_quant_type,
                'quantize',
            )

    def test_residual_block_abs_max(self):
        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
        self.residual_block_quant(
            'abs_max', 'abs_max', quantizable_op_type, for_ci=True
        )

    def test_residual_block_channel_wise_abs_max(self):
        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
        self.residual_block_quant(
            'abs_max', 'channel_wise_abs_max', quantizable_op_type, for_ci=True
        )
1019 1020


W
WangZhen 已提交
1021 1022
# Run the unittest test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()