#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"


def linear_fc(num):
    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = data
    for _ in six.moves.xrange(num):
        hidden = fluid.layers.fc(hidden, size=128, act='relu')
    loss = fluid.layers.cross_entropy(input=hidden, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def residual_block(num, quant_skip_pattern=None):
    def conv_bn_layer(input,
                      ch_out,
                      filter_size,
                      stride,
                      padding,
                      act='relu',
                      bias_attr=False):
        tmp = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=bias_attr)
        return fluid.layers.batch_norm(input=tmp, act=act)

    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = data
    for _ in six.moves.xrange(num):
        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')

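    # Place the pooling op under a name scope so quantization passes
    # configured with a matching skip_pattern can recognize and skip it.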
    if quant_skip_pattern:
        with fluid.name_scope(quant_skip_pattern):
            pool = fluid.layers.pool2d(
                input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
    else:
        pool = fluid.layers.pool2d(
            input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
    fc = fluid.layers.fc(input=pool, size=10)
    loss = fluid.layers.cross_entropy(input=fc, label=label)
    loss = fluid.layers.mean(loss)
    return loss


def conv_net(img, label, quant_skip_pattern):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        pool_type='max',
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        pool_type='avg',
        act="relu")
    hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
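    # The last FC layer sits inside the `quant_skip_pattern` name scope, so a
    # QuantizationTransformPass constructed with a matching skip_pattern
    # should leave it un-quantized.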
    with fluid.name_scope(quant_skip_pattern):
        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return avg_loss


class TestQuantizationTransformPass(unittest.TestCase):
    def setUp(self):
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y']
        }
        self.quantizable_grad_op_inputs = {
            'conv2d_grad': ['Input', 'Filter'],
            'depthwise_conv2d_grad': ['Input', 'Filter'],
            'mul_grad': ['X', 'Y']
        }

    def check_program(self, program):
        quantized_ops = set()
        for block in program.blocks:
            for op in block.ops:
                # check forward
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.input_arg_names:
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'))
                        quantized_ops.add(arg_name)

            for op in block.ops:
                # check backward
                if op.type in self.quantizable_grad_op_inputs:
                    for pname in self.quantizable_grad_op_inputs[op.type]:
                        arg_name = op.input(pname)[0]
                        self.assertTrue(
                            arg_name.endswith('.quantized.dequantized'))
                        self.assertTrue(arg_name in quantized_ops)

    def linear_fc_quant(self,
                        activation_quant_type,
                        weight_quantize_type,
                        for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = linear_fc(3)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
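        # QuantizationTransformPass inserts fake quantize/dequantize ops on the
        # weights and activations of quantizable ops (conv2d, depthwise_conv2d,
        # mul) to emulate int8 behaviour during training.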
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type)
        transform_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'quantize_fc_' + activation_quant_type,
                       marked_nodes)
        program = graph.to_program()
        self.check_program(program)
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_fc_' + activation_quant_type,
                           val_marked_nodes)

    def test_linear_fc_quant_abs_max(self):
        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)

    def test_linear_fc_quant_range_abs_max(self):
        self.linear_fc_quant('range_abs_max', 'abs_max', for_ci=True)

    def test_linear_fc_quant_moving_average_abs_max(self):
        self.linear_fc_quant(
            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)

    def residual_block_quant(self,
                             activation_quant_type,
                             weight_quantize_type,
                             for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = residual_block(2)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quantize_type)
        transform_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'quantize_residual_' + activation_quant_type,
                       marked_nodes)
        program = graph.to_program()
        self.check_program(program)
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_residual_' + activation_quant_type,
                           val_marked_nodes)

    def test_residual_block_abs_max(self):
        self.residual_block_quant('abs_max', 'abs_max', for_ci=True)

    def test_residual_block_range_abs_max(self):
        self.residual_block_quant('range_abs_max', 'abs_max', for_ci=True)

    def test_residual_block_moving_average_abs_max(self):
        self.residual_block_quant(
            'moving_average_abs_max', 'channel_wise_abs_max', for_ci=True)


class TestQuantizationFreezePass(unittest.TestCase):
    def freeze_graph(self,
                     use_cuda,
                     seed,
                     activation_quant_type,
                     weight_quant_type='abs_max',
                     for_ci=True,
                     quant_skip_pattern='skip_quant'):
        def build_program(main, startup, is_test):
            main.random_seed = seed
            startup.random_seed = seed
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    img = fluid.layers.data(
                        name='image', shape=[1, 28, 28], dtype='float32')
                    label = fluid.layers.data(
                        name='label', shape=[1], dtype='int64')
                    loss = conv_net(img, label, quant_skip_pattern)
                    if not is_test:
                        opt = fluid.optimizer.Adam(learning_rate=0.001)
                        opt.minimize(loss)
            return [img, label], loss

        random.seed(0)
        np.random.seed(0)

        main = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = build_program(main, startup, False)
        build_program(test_program, startup, True)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type,
            skip_pattern=quant_skip_pattern)
        transform_pass.apply(main_graph)
        transform_pass.apply(test_graph)
        dev_name = '_gpu_' if use_cuda else '_cpu_'
        if not for_ci:
            marked_nodes = set()
            for op in main_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_'
                            + weight_quant_type, marked_nodes)
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_'
                            + weight_quant_type, marked_nodes)

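        # Memory optimization and inplace reuse are disabled here, presumably
        # so the quantization-related variables read back from the scope below
        # are not recycled during execution.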
        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                if not for_ci:
                    print('{}: {}'.format('loss' + dev_name +
                                          activation_quant_type + '_' +
                                          weight_quant_type, loss_v))

        test_data = next(test_reader())
        with fluid.program_guard(quantized_test_program):
            w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
                                             quantized_test_program)
        # Testing
        with fluid.scope_guard(scope):
            test_loss1, w_quant = exe.run(program=quantized_test_program,
                                          feed=feeder.feed(test_data),
                                          fetch_list=[loss, w_var])

        # Freeze graph for inference, but the weight of fc/conv is still float type.
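        # QuantizationFreezePass removes the fake quant/dequant ops and
        # rewrites the weights in the scope to their quantized integer values
        # (still stored in a float tensor at this stage).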
        freeze_pass = QuantizationFreezePass(
            scope=scope, place=place, weight_quantize_type=weight_quant_type)
        freeze_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_freeze' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)

        server_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            test_loss2, = exe.run(program=server_program,
                                  feed=feeder.feed(test_data),
                                  fetch_list=[loss])
        self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
        if not for_ci:
            print(
                '{}: {}'.format('test_loss1' + dev_name + activation_quant_type
                                + '_' + weight_quant_type, test_loss1))
            print(
                '{}: {}'.format('test_loss2' + dev_name + activation_quant_type
                                + '_' + weight_quant_type, test_loss2))
        w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        # Maybe failed, this is due to the calculation precision
        # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
        if not for_ci:
            print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_freeze)))
            print('{}: {}'.format('w_quant' + dev_name + activation_quant_type +
                                  '_' + weight_quant_type, np.sum(w_quant)))

        # Convert parameter to 8-bit.
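        # ConvertToInt8Pass casts the frozen float weights to int8 tensors so
        # the saved parameters are smaller.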
        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
        convert_int8_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type
                            + '_' + weight_quant_type, marked_nodes)
        server_program_int8 = test_graph.to_program()
        # Save the 8-bit parameter and model file.
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model(
                'server_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, ['image', 'label'], [loss], exe,
                server_program_int8)
            # Test whether the 8-bit parameter and model file can be loaded successfully.
            [infer, feed, fetch] = fluid.io.load_inference_model(
                'server_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, exe)
        # Check the loaded 8-bit weight.
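        # After the freeze pass the weights already hold integer values, so
        # casting to int8 should preserve their sum exactly.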
        w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
        self.assertEqual(w_8bit.dtype, np.int8)
        self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
        if not for_ci:
            print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type +
                                  '_' + weight_quant_type, np.sum(w_8bit)))
            print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_freeze)))

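        # TransformForMobilePass rewrites the fake quant/dequant ops into the
        # quantize/dequantize ops expected by the mobile (Paddle-Mobile)
        # runtime.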
        mobile_pass = TransformForMobilePass()
        mobile_pass.apply(test_graph)
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_mobile' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)

        mobile_program = test_graph.to_program()
        with fluid.scope_guard(scope):
            fluid.io.save_inference_model(
                'mobile_int8' + dev_name + activation_quant_type + '_' +
                weight_quant_type, ['image', 'label'], [loss], exe,
                mobile_program)

    def test_freeze_graph_cuda_dynamic(self):
        if fluid.core.is_compiled_with_cuda():
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)

    def test_freeze_graph_cpu_dynamic(self):
        with fluid.unique_name.guard():
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)

    def test_freeze_graph_cuda_static(self):
        if fluid.core.is_compiled_with_cuda():
            with fluid.unique_name.guard():
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='range_abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='moving_average_abs_max',
                    weight_quant_type='abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='range_abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)
                self.freeze_graph(
                    True,
                    seed=1,
                    activation_quant_type='moving_average_abs_max',
                    weight_quant_type='channel_wise_abs_max',
                    for_ci=True)

    def test_freeze_graph_cpu_static(self):
        with fluid.unique_name.guard():
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)
            self.freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)


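# AddQuantDequantPass inserts quant_dequant ops in front of ops that carry no
# weights (here elementwise_add and pool2d) so that their input activations
# are also quantized during training.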
class TestAddQuantDequantPass(unittest.TestCase):
    def setUp(self):
        self._target_ops = {'elementwise_add', 'pool2d'}
        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}

    def check_graph(self, graph, skip_pattern=None):
        ops = graph.all_op_nodes()
        for op_node in ops:
            if op_node.name() in self._target_ops:
                if skip_pattern and op_node.op().has_attr("op_namescope") and \
                    op_node.op().attr("op_namescope").find(skip_pattern) != -1:
                    continue

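                # The pass only wraps ops whose inputs are all non-persistable
                # (i.e. activations); ops fed by a persistable variable such
                # as a parameter are left untouched, so skip them here.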
                in_nodes_all_not_persistable = True
                for input_name in op_node.input_arg_names():
                    in_node = graph._find_node_by_name(op_node.inputs,
                                                       input_name)
                    in_nodes_all_not_persistable = (
                        in_nodes_all_not_persistable and
                        not in_node.persistable())
                if not in_nodes_all_not_persistable:
                    continue
                input_names = op_node.input_arg_names()
                for input_name in input_names:
                    self.assertTrue(input_name.endswith('.quant_dequant'))

    def residual_block_quant(self, skip_pattern=None, for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = residual_block(2, skip_pattern)
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        place = fluid.CPUPlace()
        graph = IrGraph(core.Graph(main.desc), for_test=False)
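        # The pass renames each eligible input with a '.quant_dequant' suffix,
        # which check_graph verifies after it is applied.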
        add_quant_dequant_pass = AddQuantDequantPass(
            scope=fluid.global_scope(), place=place)
        add_quant_dequant_pass.apply(graph)
        if not for_ci:
            marked_nodes = set()
            for op in graph.all_op_nodes():
                if op.name().find('quant') > -1:
                    marked_nodes.add(op)
            graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
        self.check_graph(graph, skip_pattern)
        program = graph.to_program()
        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
        if not for_ci:
            val_marked_nodes = set()
            for op in val_graph.all_op_nodes():
                if op.name().find('quant') > -1:
                    val_marked_nodes.add(op)
            val_graph.draw('.', 'val_add_quant_dequant_graph', val_marked_nodes)

    def test_residual_block(self):
        self.residual_block_quant(skip_pattern=None, for_ci=True)

    def test_residual_block_skip_pattern(self):
        self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)


if __name__ == '__main__':
    unittest.main()