# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import abc
import os
import enum
import time
import logging
import shutil
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.fluid.core import PassVersionChecker
import paddle.fluid.core as core
from paddle import compat as cpt
import paddle.inference as paddle_infer
from typing import Optional, List, Callable, Dict, Any, Set
from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model

import hypothesis
from hypothesis import given, settings, seed, reproduce_failure
import hypothesis.strategies as st

logging.basicConfig(level=logging.INFO, format="%(message)s")

settings.register_profile(
    "ci",
    max_examples=100,
    suppress_health_check=hypothesis.HealthCheck.all(),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False)
settings.register_profile(
    "dev",
    max_examples=1000,
    suppress_health_check=hypothesis.HealthCheck.all(),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False)
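
# The active profile is selected below from the environment. A sketch of the
# assumed shell usage (the test file name is hypothetical):
#
#     HYPOTHESIS_TEST_PROFILE=ci python test_some_fuse_pass.py
#     TEST_NUM_PERCENT_CASES=0.1 python test_some_fuse_pass.py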
if float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1 or \
    os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci':
    settings.load_profile("ci")
else:
    settings.load_profile("dev")


class IgnoreReasons(enum.Enum):
    # Paddle does not support it, but TRT does; the feature needs to be added.
    TRT_NOT_IMPLEMENTED = 0
    # TRT does not support it.
    TRT_NOT_SUPPORT = 1
    # Accuracy is abnormal after enabling pass.
    PASS_ACCURACY_ERROR = 2
    # Accuracy is abnormal after enabling mkldnn.
    MKLDNN_ACCURACY_ERROR = 3


# TODO(wilber): kept only for backward compatibility.
SkipReasons = IgnoreReasons


class AutoScanTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        np.random.seed(1024)
        paddle.enable_static()
        super(AutoScanTest, self).__init__(*args, **kwargs)
        self.ignore_cases = []
        abs_dir = os.path.abspath(os.path.dirname(__file__))
        self.cache_dir = os.path.join(abs_dir,
                                      str(self.__module__) + '_cache_dir')
        self.available_passes_in_framework = set()
        self.num_ran_programs = 0
        self.num_invalid_programs = 0
        self.num_ignore_tests = 0
        self.num_predictor_kinds = 0

    @abc.abstractmethod
    def sample_program_configs(self):
        '''
        Generate all program configs by combining different input tensor
        shapes and different attribute values.
        '''
        raise NotImplementedError
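
    # A minimal sketch of a subclass implementation (op, tensor, and shape
    # choices below are hypothetical; the config classes come from
    # program_config.py):
    #
    #     def sample_program_configs(self):
    #         relu_op = OpConfig(
    #             type="relu",
    #             inputs={"X": ["input_data"]},
    #             outputs={"Out": ["relu_out"]},
    #             attrs={})
    #         yield ProgramConfig(
    #             ops=[relu_op],
    #             weights={},
    #             inputs={"input_data": TensorConfig(shape=[1, 3, 32, 32])},
    #             outputs=["relu_out"])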

    @abc.abstractmethod
    def sample_predictor_configs(self):
        raise NotImplementedError

    @abc.abstractmethod
    def add_ignore_check_case(
            self,
            teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
            reason: IgnoreReasons,
            note: str):
        self.ignore_cases.append((teller, reason, note))
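
    # A usage sketch: a teller receives the generated program and the
    # predictor config and returns True when the case should be ignored
    # (the attribute name checked below is hypothetical):
    #
    #     self.add_ignore_check_case(
    #         lambda prog_config, pred_config: prog_config.ops[0].attrs["axis"] == -1,
    #         IgnoreReasons.TRT_NOT_SUPPORT,
    #         "TRT does not support this attribute value.")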

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        return True

    def run_test_config(self, model, params, prog_config, pred_config,
                        feed_data) -> Dict[str, np.ndarray]:
        '''
        Test a single case.
        '''
        pred_config.set_model_buffer(model, len(model), params, len(params))
        predictor = paddle_infer.create_predictor(pred_config)
        self.available_passes_in_framework = self.available_passes_in_framework | set(
            pred_config.pass_builder().all_passes())
        for name, _ in prog_config.inputs.items():
            input_tensor = predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(feed_data[name]['data'])
            if feed_data[name]['lod'] is not None:
                input_tensor.set_lod(feed_data[name]['lod'])
        predictor.run()
        result = {}
        for out_name, o_name in zip(prog_config.outputs,
                                    predictor.get_output_names()):
            result[out_name] = predictor.get_output_handle(o_name).copy_to_cpu()
        return result

    @abc.abstractmethod
    def assert_tensors_near(self,
                            atol: float,
                            rtol: float,
                            tensor: Dict[str, np.ndarray],
                            baseline: Dict[str, np.ndarray]):
        for key, arr in tensor.items():
            self.assertTrue(
                baseline[key].shape == arr.shape,
                "The output shapes are not equal, the baseline shape is " +
                str(baseline[key].shape) + ', but got ' + str(arr.shape))
            diff = abs(baseline[key] - arr)
            self.assertTrue(
                np.allclose(
                    baseline[key], arr, atol=atol, rtol=rtol),
                "Output has diff. Maximum absolute error: {}".format(
                    np.amax(diff)))

    @abc.abstractmethod
    def run_test(self, quant=False):
        raise NotImplementedError

    def generate_op_config(self,
                           ops_config: List[Dict[str, Any]]) -> List[OpConfig]:
        ops = []
        for i in range(len(ops_config)):
            op_config = ops_config[i]
            ops.append(
                OpConfig(
                    type=op_config['op_type'],
                    inputs=op_config['op_inputs'],
                    outputs=op_config['op_outputs'],
                    attrs=op_config['op_attrs']))
        return ops
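
    # A sketch of the ops_config layout this helper consumes (op and tensor
    # names are hypothetical):
    #
    #     ops = self.generate_op_config([{
    #         "op_type": "relu",
    #         "op_inputs": {"X": ["input_data"]},
    #         "op_outputs": {"Out": ["relu_out"]},
    #         "op_attrs": {}
    #     }])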

    @abc.abstractmethod
    def ignore_log(self, msg: str):
        logging.warning("SKIP: " + msg)

    @abc.abstractmethod
    def fail_log(self, msg: str):
        logging.error("FAIL: " + msg)

    @abc.abstractmethod
    def success_log(self, msg: str):
        logging.info("SUCCESS: " + msg)

    @abc.abstractmethod
    def create_inference_config(self,
                                passes: Optional[List[str]]=None,
                                use_gpu: bool=False,
                                use_mkldnn: bool=False,
                                ir_optim: Optional[bool]=None):
        config = paddle_infer.Config()
        config.switch_ir_debug(True)
        config.set_optim_cache_dir(self.cache_dir)
        config.disable_glog_info()
        if ir_optim is not None:
            config.switch_ir_optim(ir_optim)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        if use_mkldnn:
            config.enable_mkldnn()
        if passes is not None:
            config.pass_builder().set_passes(passes)
            self.passes = passes
        return config
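
    # For example, a subclass might request a CPU config that applies a single
    # pass (the pass name here is illustrative):
    #
    #     config = self.create_inference_config(
    #         passes=["fc_fuse_pass"], use_mkldnn=True)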


class MkldnnAutoScanTest(AutoScanTest):
    def __init__(self, *args, **kwargs):
        super(MkldnnAutoScanTest, self).__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        status = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # if the program is invalid, skip this case.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }
            results: List[Dict[str, np.ndarray]] = []

            # baseline: cpu no ir_optim run
            base_config = self.create_inference_config(ir_optim=False)
            logging.info('RUN program_config: ' + str(prog_config))
            results.append(
                self.run_test_config(model, params, prog_config, base_config,
                                     feed_data))
            self.success_log('RUN_CPU_BASELINE done')

            for pred_config, (
                    atol, rtol) in self.sample_predictor_configs(prog_config):
                # check whether this case should be ignored
                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        if ignore_info[
                                1] == IgnoreReasons.MKLDNN_ACCURACY_ERROR:
                            self.ignore_log("[MKLDNN_ACCURACY_ERROR] " +
                                            ignore_info[2] + ' ' + ' vs ' +
                                            self.inference_config_str(
                                                pred_config))
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    results.append(
                        self.run_test_config(model, params, prog_config,
                                             pred_config, feed_data))
                    self.assert_tensors_near(atol, rtol, results[-1],
                                             results[0])
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log('RUN predictor_config ' + self.
                                 inference_config_str(pred_config) + ' done')

        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        dic = {}
        enable_mkldnn = config.mkldnn_enabled()
        dic['use_mkldnn'] = enable_mkldnn
        enable_gpu = config.use_gpu()
        dic['use_gpu'] = enable_gpu
        return str(dic)


class PassAutoScanTest(AutoScanTest):
    def __init__(self, *args, **kwargs):
        super(PassAutoScanTest, self).__init__(*args, **kwargs)
        self.passes = []

    def check_op_version(self):
        status = True
        for pass_name in self.passes:
            if pass_name not in self.available_passes_in_framework:
                continue
            if not PassVersionChecker.IsCompatible(pass_name):
                self.fail_log('{} version check failed.'.format(pass_name))
                status = False
        return status

    def add_ignore_pass_case(self):
        return

    def assert_op_list(self, op_list_after_fusion):
        if not self.passes:
            raise ValueError(
                "In PassAutoScan you should give a valid pass name.")
        last_passed_program = os.path.join(self.cache_dir,
                                           self.passes[-1] + ".pdmodel")
        if not os.path.exists(last_passed_program):
            raise ValueError(
                "Cannot find file {}, please make sure that your pass name is correct".
                format(last_passed_program))
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        after_op_list = list()
        for i in range(main_block.op_size()):
            if main_block.op(i).type() in ["feed", "fetch"]:
                continue
            after_op_list.append(main_block.op(i).type())
        self.assertTrue(
            op_list_after_fusion == after_op_list,
            "Expected operator list after fusion is {}, but now it's {}".format(
                op_list_after_fusion, after_op_list), )

    def run_and_statis(self,
                       quant=False,
                       max_examples=100,
                       reproduce=None,
                       min_success_num=25,
                       max_duration=180,
                       passes=None):
        if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
            max_examples *= 10
            min_success_num *= 10
            # at the CE phase there is no time limit
            max_duration = -1
        start_time = time.time()
        settings.register_profile(
            "ci",
            max_examples=max_examples,
            suppress_health_check=hypothesis.HealthCheck.all(),
            deadline=None,
            print_blob=True,
            derandomize=True,
            report_multiple_bugs=False, )
        settings.load_profile("ci")
        assert passes is not None, "Parameter 'passes' must be defined in function run_and_statis."
        self.passes = passes

        self.add_ignore_pass_case()

        def program_generator(draw):
            return self.sample_program_config(draw)

        def run_test(prog_config):
            return self.run_test(quant=quant, prog_configs=[prog_config])

        generator = st.composite(program_generator)
        loop_func = given(generator())(run_test)
        if reproduce is not None:
            loop_func = reproduce(loop_func)
        logging.info("Start to running test of {}".format(type(self)))
        loop_func()
        logging.info(
            "===================Statistical Information===================")
        logging.info("Number of Generated Programs: {}".format(
            self.num_ran_programs + self.num_invalid_programs))
        logging.info("Number of Invalid Programs: {}".format(
            self.num_invalid_programs))
        logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
        logging.info("Number of Ignore Tests: {}".format(self.num_ignore_tests))
        successful_ran_programs = int(self.num_ran_programs -
                                      self.num_ignore_tests / max(
                                          self.num_predictor_kinds, 1))
        logging.info(
            "Number of successfully ran programs approximately equal to {}".
            format(successful_ran_programs))
        if successful_ran_programs < min_success_num:
            logging.warning(
382
                "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
J
            logging.error(
                "At least {} programs need to run successfully, but now only about {} programs succeeded.".
                format(min_success_num, successful_ran_programs))
            assert False
        used_time = time.time() - start_time
        if max_duration > 0 and used_time > max_duration:
            logging.error(
                "The duration exceeds {} seconds; if this is necessary, try to set a larger value for parameter `max_duration`.".
                format(max_duration))
            assert False
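
    # A sketch of how a concrete pass test typically drives this entry point
    # (the pass name and numbers are illustrative):
    #
    #     def test(self):
    #         self.run_and_statis(
    #             quant=False,
    #             max_examples=50,
    #             min_success_num=25,
    #             passes=["fc_fuse_pass"])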

    def run_test(self, quant=False, prog_configs=None):
        status = True

        for prog_config in prog_configs:
            # if the program is invalid, skip this case.
            if not self.is_program_valid(prog_config):
                self.num_invalid_programs += 1
                continue
            self.num_ran_programs += 1
            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }

            logging.info('RUN program_config: ' + str(prog_config))
            self.num_predictor_kinds = 0
            for pred_config, op_list, (
                    atol, rtol) in self.sample_predictor_configs(prog_config):
                self.num_predictor_kinds += 1

                # check whether this case should be ignored
                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        self.num_ignore_tests += 1
                        if ignore_info[1] == IgnoreReasons.PASS_ACCURACY_ERROR:
                            self.ignore_log("[PASS_ACCURACY_ERROR] " +
                                            ignore_info[2] + ' ' + ' vs ' +
                                            self.inference_config_str(
                                                pred_config))
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                # baseline: no ir_optim run
                base_config = self.create_inference_config(
                    ir_optim=False, use_gpu=pred_config.use_gpu())
                try:
                    # baseline
                    base_result = self.run_test_config(
                        model, params, prog_config, base_config, feed_data)
                    self.success_log('RUN_BASELINE ' +
                                     self.inference_config_str(
                                         base_config) + ' done')

                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data)
                    self.assert_tensors_near(atol, rtol, pred_result,
                                             base_result)
                    if not ignore_flag:
                        self.assert_op_list(op_list)

                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log('RUN predictor_config ' + self.
                                 inference_config_str(pred_config) + ' done')

        status = self.check_op_version() and status
        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        dic = {}
        enable_mkldnn = config.mkldnn_enabled()
        dic['use_mkldnn'] = enable_mkldnn
        enable_gpu = config.use_gpu()
        dic['use_gpu'] = enable_gpu
        if self.passes:
            dic['passes'] = self.passes

        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic['use_trt'] = True
            dic['trt_precision'] = trt_precision
            dic['use_dynamic_shape'] = trt_dynamic_shape
        else:
            dic['use_trt'] = False
        return str(dic)

    def create_trt_inference_config(self) -> paddle_infer.Config:
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        config.switch_ir_debug()
        return config


class TrtLayerAutoScanTest(AutoScanTest):
    class TensorRTParam:
        '''
        TensorRT subgraph engine parameters. 
        '''

        def __init__(self, workspace_size, max_batch_size, min_subgraph_size,
                     precision, use_static, use_calib_mode):
            self.workspace_size = workspace_size
            self.max_batch_size = max_batch_size
            self.min_subgraph_size = min_subgraph_size
            self.precision = precision
            self.use_static = use_static
            self.use_calib_mode = use_calib_mode

    class DynamicShapeParam:
        '''
        Prepare TensorRT subgraph engine dynamic shape parameters.
        '''

        def __init__(self, min_input_shape, max_input_shape, opt_input_shape,
                     disable_trt_plugin_fp16):
            self.min_input_shape = min_input_shape
            self.max_input_shape = max_input_shape
            self.opt_input_shape = opt_input_shape
            self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16

    def __init__(self, *args, **kwargs):
        super(TrtLayerAutoScanTest, self).__init__(*args, **kwargs)
        self.trt_param = self.TensorRTParam(
            workspace_size=1024,
            max_batch_size=4,
            min_subgraph_size=0,
            precision=paddle_infer.PrecisionType.Float32,
            use_static=True,
            use_calib_mode=False)
        self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
        self.num_percent_cases = float(
            os.getenv(
                'TEST_NUM_PERCENT_CASES', default='1.0'))
        # Choose a different random subset of tests each week (seeded by week number)
        np.random.seed(int(time.strftime("%W")))
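
        # A subclass may override these defaults, e.g. with per-input dynamic
        # shapes (the tensor name and shapes are illustrative):
        #
        #     self.dynamic_shape = self.DynamicShapeParam(
        #         min_input_shape={"input_data": [1, 3, 32, 32]},
        #         max_input_shape={"input_data": [4, 3, 64, 64]},
        #         opt_input_shape={"input_data": [1, 3, 64, 64]},
        #         disable_trt_plugin_fp16=False)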

    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        if use_trt:
            config.switch_ir_debug()
            config.enable_tensorrt_engine(
                max_batch_size=self.trt_param.max_batch_size,
                workspace_size=self.trt_param.workspace_size,
                min_subgraph_size=self.trt_param.min_subgraph_size,
                precision_mode=self.trt_param.precision,
                use_static=self.trt_param.use_static,
                use_calib_mode=self.trt_param.use_calib_mode)
            if len(self.dynamic_shape.min_input_shape
                   ) != 0 and self.dynamic_shape.min_input_shape.keys(
                   ) == self.dynamic_shape.max_input_shape.keys(
                   ) and self.dynamic_shape.min_input_shape.keys(
                   ) == self.dynamic_shape.opt_input_shape.keys():
                config.set_trt_dynamic_shape_info(
                    self.dynamic_shape.min_input_shape,
                    self.dynamic_shape.max_input_shape,
                    self.dynamic_shape.opt_input_shape,
                    self.dynamic_shape.disable_trt_plugin_fp16)
        return config

    def assert_op_size(self, trt_engine_num, paddle_op_num):
        last_passed_program = os.path.join(
            self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel')
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        op_size = main_block.op_size()
        op_types = [
            main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size)
        ]
        trt_engine_size = sum(op_types)
        paddle_op_size = op_size - trt_engine_size
        self.assertTrue(trt_engine_size == trt_engine_num,
                        'expected trt_engine_num is {}, but got {}!'.format(
                            trt_engine_num, trt_engine_size))
        self.assertTrue(paddle_op_size == paddle_op_num,
                        'expected paddle_op_num is {}, but got {}!'.format(
                            paddle_op_num, paddle_op_size))

    def inference_config_str(self, config: paddle_infer.Config) -> str:
        dic = {}
        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic['use_trt'] = True
            dic['trt_precision'] = trt_precision
            dic['use_dynamic_shape'] = trt_dynamic_shape
        else:
            dic['use_trt'] = False
        return str(dic)

    def run_test(self, quant=False, *args, **kwargs):
        status = True
        run_flags = []
        for prog_config in self.sample_program_configs(*args, **kwargs):
            # In CI, run only a TEST_NUM_PERCENT_CASES fraction of the cases
            if np.random.rand() < self.num_percent_cases:
                run_flags.append(True)
            else:
                run_flags.append(False)

        for prog_config, run_flag in zip(
                self.sample_program_configs(*args, **kwargs), run_flags):
            if not run_flag:
                continue

            # if the program is invalid, skip this case.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }

            results: List[Dict[str, np.ndarray]] = []

            # baseline: gpu run
            logging.info('RUN program_config: ' + str(prog_config))
            gpu_config = self.create_inference_config(use_trt=False)
            results.append(
                self.run_test_config(model, params, prog_config, gpu_config,
                                     feed_data))
            self.success_log('RUN_GPU_BASELINE done')

            for pred_config, nodes_num, threshold in self.sample_predictor_configs(
                    prog_config):

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)

                if isinstance(threshold, float):
                    atol = threshold
                    rtol = 1e-8
                elif isinstance(threshold, list) or isinstance(threshold,
                                                               tuple):
                    atol = threshold[0]
                    rtol = threshold[1]
                else:
                    raise NotImplementedError

                if quant and pred_config.tensorrt_precision_mode(
                ) != paddle_infer.PrecisionType.Int8:
                    continue
                if pred_config.tensorrt_precision_mode(
                ) == paddle_infer.PrecisionType.Int8 and not quant:
                    continue

                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        if ignore_info[1] == IgnoreReasons.TRT_NOT_IMPLEMENTED:
                            self.ignore_log("[TRT_NOT_IMPLEMENTED] " +
                                            ignore_info[2] + ' ' + ' vs ' +
                                            self.inference_config_str(
                                                pred_config))
                        elif ignore_info[1] == IgnoreReasons.TRT_NOT_SUPPORT:
                            self.ignore_log("[TRT_NOT_SUPPORT] " + ignore_info[
                                2] + ' ' + ' vs ' + self.inference_config_str(
                                    pred_config))
                        else:
                            raise NotImplementedError
                        break

                try:
                    pred_config_deserialize = paddle_infer.Config(pred_config)
                    results.append(
                        self.run_test_config(model, params, prog_config,
                                             pred_config, feed_data))
                    self.assert_tensors_near(atol, rtol, results[-1],
                                             results[0])
                    if not ignore_flag:
                        self.assert_op_size(nodes_num[0], nodes_num[1])
                    # deserialize test
                    if nodes_num[0] > 0:
                        self.run_test_config(model, params, prog_config,
                                             pred_config_deserialize, feed_data)
                except Exception as e:
                    self.fail_log(
                        str(prog_config) + ' vs ' + self.inference_config_str(
                            pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log('RUN predictor_config ' + self.
                                 inference_config_str(pred_config) + ' done')

        self.assertTrue(status)

    # TODO(wilber): kept only for backward compatibility.
    def add_skip_case(
            self,
            teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
            reason: IgnoreReasons,
            note: str):
        self.ignore_cases.append((teller, reason, note))