#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"


def linear_fc(num):
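    """Build `num` stacked fully connected layers with a cross-entropy loss."""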
    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = data
    for _ in six.moves.xrange(num):
        hidden = fluid.layers.fc(hidden, size=128, act='relu')
    loss = fluid.layers.cross_entropy(input=hidden, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def residual_block(num, quant_skip_pattern=None):
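    """Build `num` conv + batch_norm residual units.

    When `quant_skip_pattern` is given, the final pooling op is created inside
    that name scope so the quantization passes can recognize and skip it.
    """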
    def conv_bn_layer(input,
                      ch_out,
                      filter_size,
                      stride,
                      padding,
                      act='relu',
                      bias_attr=False):
        tmp = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=bias_attr)
        return fluid.layers.batch_norm(input=tmp, act=act)

    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = data
    for _ in six.moves.xrange(num):
        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')

    if quant_skip_pattern:
        with fluid.name_scope(quant_skip_pattern):
            pool = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
    else:
        pool = fluid.layers.pool2d(
            input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
    fc = fluid.layers.fc(input=pool, size=10)
    loss = fluid.layers.cross_entropy(input=fc, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def conv_net(img, label, quant_skip_pattern):
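    """Build a LeNet-style conv/pool network for MNIST-shaped input.

    The last FC layer is created inside `quant_skip_pattern`, so it is expected
    to be skipped by the quantization transform pass.
    """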
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        pool_type='max',
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        pool_type='avg',
        act="relu")
    hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
    with fluid.name_scope(quant_skip_pattern):
        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return avg_loss


class TestQuantizationTransformPass(unittest.TestCase):
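    """Check that QuantizationTransformPass rewrites the inputs of quantizable
    ops (conv2d, depthwise_conv2d, mul) and their grad ops to the
    '.quantized.dequantized' variables."""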
    def setUp(self):
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y']
        }
        self.quantizable_grad_op_inputs = {
            'conv2d_grad': ['Input', 'Filter'],
            'depthwise_conv2d_grad': ['Input', 'Filter'],
            'mul_grad': ['X', 'Y']
        }

    def check_program(self, program):
        quantized_ops = set()
        for block in program.blocks:
            for op in block.ops:
                # check forward
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.input_arg_names:
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'))
                        quantized_ops.add(arg_name)

            for op in block.ops:
                # check backward
                if op.type in self.quantizable_grad_op_inputs:
                    for pname in self.quantizable_grad_op_inputs[op.type]:
                        arg_name = op.input(pname)[0]
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'))
                        self.assertTrue(arg_name in quantized_ops)

    def linear_fc_quant(self,
                        activation_quant_type,
                        weight_quantize_type,
                        for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = linear_fc(3)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type)
        transform_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'quantize_fc_' + activation_quant_type,
                       marked_nodes)
        program = graph.to_program()
        self.check_program(program)
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_fc_' + activation_quant_type,
                           val_marked_nodes)

    def test_linear_fc_quant_abs_max(self):
        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)

    def test_linear_fc_quant_range_abs_max(self):
        self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)

    def test_linear_fc_quant_moving_average_abs_max(self):
        self.linear_fc_quant(
            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)

    def residual_block_quant(self,
                             activation_quant_type,
                             weight_quantize_type,
                             for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = residual_block(2)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type)
        transform_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'quantize_residual_' + activation_quant_type,
                       marked_nodes)
        program = graph.to_program()
        self.check_program(program)
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_residual_' + activation_quant_type,
                           val_marked_nodes)

    def test_residual_block_abs_max(self):
        self.residual_block_quant('abs_max', 'abs_max', for_ci=True)

    def test_residual_block_range_abs_max(self):
        self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)

    def test_residual_block_moving_average_abs_max(self):
        self.residual_block_quant(
            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)


class TestQuantizationFreezePass(unittest.TestCase):
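    """End-to-end check of the pass pipeline: train with fake quant/dequant ops,
    freeze the graph for inference, convert weights to int8, and apply the
    mobile transform, saving and reloading the inference models along the way.
    """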
    def freeze_graph(self,
                     use_cuda,
                     seed,
                     activation_quant_type,
                     weight_quant_type='abs_max',
                     for_ci=True,
                     quant_skip_pattern='skip_quant'):
        def build_program(main, startup, is_test):
            main.random_seed = seed
            startup.random_seed = seed
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    img = fluid.layers.data(
                        name='image', shape=[1, 28, 28], dtype='float32')
                    label = fluid.layers.data(
                        name='label', shape=[1], dtype='int64')
                    loss = conv_net(img, label, quant_skip_pattern)
                    if not is_test:
                        opt = fluid.optimizer.Adam(learning_rate=0.001)
                        opt.minimize(loss)
            return [img, label], loss

        random.seed(0)
        np.random.seed(0)

        main = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = build_program(main, startup, False)
        build_program(test_program, startup, True)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
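        # Insert fake quantize/dequantize ops into both the train and test graphs.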
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type,
            skip_pattern=quant_skip_pattern)
        transform_pass.apply(main_graph)
        transform_pass.apply(test_graph)
        dev_name = '_gpu_' if use_cuda else '_cpu_'
        if not for_ci:
            marked_nodes = set()
            for op in main_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_'
                            + weight_quant_type, marked_nodes)
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_'
                            + weight_quant_type, marked_nodes)

        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                if not for_ci:
                    print('{}: {}'.format('loss' + dev_name +
                                          activation_quant_type + '_' +
                                          weight_quant_type, loss_v))

        test_data = next(test_reader())
        with fluid.program_guard(quantized_test_program):
            w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
                                             quantized_test_program)
        # Testing
        with fluid.scope_guard(scope):
            test_loss1, w_quant = exe.run(program=quantized_test_program,
                                          feed=feeder.feed(test_data),
                                          fetch_list=[loss, w_var])

        # Freeze the graph for inference; the fc/conv weights are still kept as float here.
        freeze_pass = QuantizationFreezePass(
            scope=scope, place=place, weight_quantize_type=weight_quant_type)
        freeze_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_freeze' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)

        server_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            test_loss2, = exe.run(program=server_program,
                                  feed=feeder.feed(test_data),
                                  fetch_list=[loss])
        self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
        if not for_ci:
            print(
                '{}: {}'.format('test_loss1' + dev_name + activation_quant_type
                                + '_' + weight_quant_type, test_loss1))
            print(
                '{}: {}'.format('test_loss2' + dev_name + activation_quant_type
                                + '_' + weight_quant_type, test_loss2))
        w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        # This check may fail due to limited calculation precision:
        # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
        if not for_ci:
            print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_freeze)))
            print('{}: {}'.format('w_quant' + dev_name + activation_quant_type +
                                  '_' + weight_quant_type, np.sum(w_quant)))

        # Convert parameter to 8-bit.
        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
        convert_int8_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type
                            + '_' + weight_quant_type, marked_nodes)
        server_program_int8 = test_graph.to_program()
        # Save the 8-bit parameter and model file.
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model(
                'server_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, ['image', 'label'], [loss], exe,
                server_program_int8)
            # Test whether the 8-bit parameter and model file can be loaded successfully.
            [infer, feed, fetch] = fluid.io.load_inference_model(
                'server_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, exe)
        # Check the loaded 8-bit weight.
        w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
        self.assertEqual(w_8bit.dtype, np.int8)
        self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
        if not for_ci:
            print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type +
                                  '_' + weight_quant_type, np.sum(w_8bit)))
            print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_freeze)))

        mobile_pass = TransformForMobilePass()
        mobile_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_mobile' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)

        mobile_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model(
                'mobile_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, ['image', 'label'], [loss], exe,
                mobile_program)

    def test_freeze_graph_cuda_dynamic(self):
        if fluid.core.is_compiled_with_cuda():
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)

    def test_freeze_graph_cpu_dynamic(self):
        with fluid.unique_name.guard():
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)

    def test_freeze_graph_cuda_static(self):
        if fluid.core.is_compiled_with_cuda():
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='range_abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='moving_average_abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='range_abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='moving_average_abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)

    def test_freeze_graph_cpu_static(self):
        with fluid.unique_name.guard():
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)


def quant_dequant_residual_block(num, quant_skip_pattern=None):
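    """Residual block whose pool2d/elementwise_add ops can be placed under one
    or more skip name scopes, depending on the type of `quant_skip_pattern`."""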
    def conv_bn_layer(input,
                      ch_out,
                      filter_size,
                      stride,
                      padding,
                      act='relu',
                      bias_attr=False):
        tmp = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=bias_attr)
        return fluid.layers.batch_norm(input=tmp, act=act)

    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = data
    for _ in six.moves.xrange(num):
        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')

    if isinstance(quant_skip_pattern, str):
        with fluid.name_scope(quant_skip_pattern):
            pool1 = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
            pool2 = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='max', pool_stride=2)
            pool_add = fluid.layers.elementwise_add(
                x=pool1, y=pool2, act='relu')
    elif isinstance(quant_skip_pattern, list):
        assert len(
            quant_skip_pattern
        ) > 1, 'test config error: the len of quant_skip_pattern list should be greater than 1.'
        with fluid.name_scope(quant_skip_pattern[0]):
            pool1 = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
            pool2 = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='max', pool_stride=2)
        with fluid.name_scope(quant_skip_pattern[1]):
            pool_add = fluid.layers.elementwise_add(
                x=pool1, y=pool2, act='relu')
    else:
        pool1 = fluid.layers.pool2d(
            input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
        pool2 = fluid.layers.pool2d(
            input=hidden, pool_size=2, pool_type='max', pool_stride=2)
        pool_add = fluid.layers.elementwise_add(x=pool1, y=pool2, act='relu')
    fc = fluid.layers.fc(input=pool_add, size=10)
    loss = fluid.layers.cross_entropy(input=fc, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestAddQuantDequantPass(unittest.TestCase):
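    """Check that AddQuantDequantPass renames the inputs of elementwise_add and
    pool2d ops to '.quant_dequant' variables, except for ops placed under a
    user-provided skip pattern."""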
    def setUp(self):
        self._target_ops = {'elementwise_add', 'pool2d'}
        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}

    def check_graph(self, graph, skip_pattern=None):
        ops = graph.all_op_nodes()
        for op_node in ops:
            if op_node.name() in self._target_ops:
                user_skipped = False
                if isinstance(skip_pattern, list):
                    user_skipped = op_node.op().has_attr("op_namescope") and \
                                   any(pattern in op_node.op().attr("op_namescope") for pattern in skip_pattern)
                elif isinstance(skip_pattern, str):
                    user_skipped = op_node.op().has_attr("op_namescope") and \
                                   op_node.op().attr("op_namescope").find(skip_pattern) != -1

                if user_skipped:
                    continue

                in_nodes_all_not_persistable = True
                for input_name in op_node.input_arg_names():
                    in_node = graph._find_node_by_name(op_node.inputs,
                                                       input_name)
                    in_nodes_all_not_persistable = (
                        in_nodes_all_not_persistable and
                        not in_node.persistable())
                if not in_nodes_all_not_persistable:
                    continue
                input_names = op_node.input_arg_names()
                for input_name in input_names:
                    self.assertTrue(input_name.endswith('.quant_dequant'))

    def residual_block_quant(self, skip_pattern=None, for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = quant_dequant_residual_block(2, skip_pattern)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        add_quant_dequant_pass = AddQuantDequantPass(
            scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
        add_quant_dequant_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quant') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
        self.check_graph(graph, skip_pattern)
        program = graph.to_program()
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quant') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)

    def test_residual_block(self):
        self.residual_block_quant(skip_pattern=None, for_ci=True)

    def test_residual_block_skip_pattern(self):
        self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)

    def test_residual_block_skip_pattern_1(self):
        self.residual_block_quant(
            skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)


if __name__ == '__main__':
    unittest.main()