auto_scan_test.py 29.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import abc
18
import os
19
import enum
W
Wilber 已提交
20
import time
W
Wilber 已提交
21
import logging
W
Wilber 已提交
22
import shutil
23
import paddle
W
Wilber 已提交
24
from paddle.fluid.core import PassVersionChecker
25
import paddle.inference as paddle_infer
26 27
from typing import Optional, List, Callable, Dict, Any
from program_config import OpConfig, ProgramConfig, create_fake_model, create_quant_model
28

W
Wilber 已提交
29
import hypothesis
30
from hypothesis import given, settings
J
Jason 已提交
31
import hypothesis.strategies as st
W
Wilber 已提交
32

W
Wilber 已提交
33 34
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Two Hypothesis profiles that differ only in the number of generated examples.
# ``list(hypothesis.HealthCheck)`` is the supported way to suppress every health
# check; ``HealthCheck.all()`` is deprecated in recent Hypothesis releases.
settings.register_profile("ci",
                          max_examples=100,
                          suppress_health_check=list(hypothesis.HealthCheck),
                          deadline=None,
                          print_blob=True,
                          derandomize=True,
                          report_multiple_bugs=False)
settings.register_profile("dev",
                          max_examples=1000,
                          suppress_health_check=list(hypothesis.HealthCheck),
                          deadline=None,
                          print_blob=True,
                          derandomize=True,
                          report_multiple_bugs=False)
# Use the smaller "ci" profile when only a fraction of the cases is requested
# (TEST_NUM_PERCENT_CASES < 1) or when it is selected explicitly through
# HYPOTHESIS_TEST_PROFILE; otherwise fall back to "dev".
if float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1 or \
    os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci':
    settings.load_profile("ci")
else:
    settings.load_profile("dev")

55

56
class IgnoreReasons(enum.Enum):
    '''
    Reasons that can be attached to an ignore-check case.
    '''
    # Supported by TensorRT but not yet wired up on the Paddle side;
    # the feature still needs to be added.
    TRT_NOT_IMPLEMENTED = 0
    # Not supported by TensorRT.
    TRT_NOT_SUPPORT = 1
    # Accuracy is abnormal after enabling the pass.
    PASS_ACCURACY_ERROR = 2
    # Accuracy is abnormal after enabling mkldnn.
    MKLDNN_ACCURACY_ERROR = 3
65 66


67 68 69 70
# TODO(wilber): alias kept only for backward compatibility.
SkipReasons = IgnoreReasons


71
class AutoScanTest(unittest.TestCase):
72

W
Wilber 已提交
73
    def __init__(self, *args, **kwargs):
        '''
        Seed NumPy, switch Paddle to static mode and reset the per-test state.
        '''
        np.random.seed(1024)
        paddle.enable_static()
        super(AutoScanTest, self).__init__(*args, **kwargs)

        # (teller, reason, note) tuples registered through the ignore helpers.
        self.ignore_cases = []

        # The optimization cache directory sits next to this file and is
        # named after the module of the concrete test class.
        this_dir = os.path.abspath(os.path.dirname(__file__))
        cache_name = str(self.__module__) + '_cache_dir'
        self.cache_dir = os.path.join(this_dir, cache_name)

        # Pass names reported by the predictor configs seen so far.
        self.available_passes_in_framework = set()

        # Run statistics.
        self.num_ran_programs = 0
        self.num_invalid_programs = 0
        self.num_ignore_tests = 0
        self.num_predictor_kinds = 0
86 87

    @abc.abstractmethod
    def sample_program_configs(self):
        '''
        Produce the program configs to test, covering the combinations of
        input tensor shapes and attribute values.

        Must be provided by subclasses; the base implementation always raises
        NotImplementedError.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def sample_predictor_configs(self):
        '''
        Produce the predictor configs to run for a program.

        Must be provided by subclasses; the base implementation always raises
        NotImplementedError. The run_test implementations in this file call
        the override with the current program config as argument.
        '''
        raise NotImplementedError

99
    @abc.abstractmethod
    def add_ignore_check_case(
            self, teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
            reason: IgnoreReasons, note: str):
        '''
        Register a case whose check result should be ignored.

        Args:
            teller: predicate called with (program_config, predictor_config);
                a true result marks the combination as ignored.
            reason: why the combination is ignored.
            note: free-form text included in the ignore log line.
        '''
        # The annotation used to be a list literal (``[Callable[...]]``), which
        # is not a type; a single callable is what every caller passes.
        self.ignore_cases.append((teller, reason, note))
104

W
Wilber 已提交
105
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        '''
        Tell whether ``program_config`` should be run.

        The base implementation accepts every program.
        '''
        return True
107

108 109 110 111 112 113 114
    def run_test_config(self, model, params, prog_config, pred_config,
                        feed_data) -> Dict[str, np.ndarray]:
        '''
        Run one program with one predictor config and collect its outputs.

        Args:
            model: serialized model buffer.
            params: serialized parameter buffer.
            prog_config: program config; supplies the input and output names.
            pred_config: predictor config the model buffers are loaded into.
            feed_data: maps an input name to a dict with 'data' and 'lod'.

        Returns:
            A dict mapping each name in ``prog_config.outputs`` to the output
            copied back to the CPU.
        '''
        pred_config.set_model_buffer(model, len(model), params, len(params))
        predictor = paddle_infer.create_predictor(pred_config)

        # Remember every pass this config knows about.
        passes_in_config = set(pred_config.pass_builder().all_passes())
        self.available_passes_in_framework = (
            self.available_passes_in_framework | passes_in_config)

        for name in prog_config.inputs.keys():
            handle = predictor.get_input_handle(name)
            feed = feed_data[name]
            handle.copy_from_cpu(feed['data'])
            if feed['lod'] is not None:
                handle.set_lod(feed['lod'])

        predictor.run()

        # Program output names are paired positionally with predictor outputs.
        return {
            out_name: predictor.get_output_handle(o_name).copy_to_cpu()
            for out_name, o_name in zip(prog_config.outputs,
                                        predictor.get_output_names())
        }

W
Wilber 已提交
129
    @abc.abstractmethod
    def assert_tensors_near(self, atol: float, rtol: float,
                            tensor: Dict[str, np.array],
                            baseline: Dict[str, np.array]):
        '''
        Check that every array in ``tensor`` matches its baseline.

        Shapes must be equal, and values must agree within ``atol`` / ``rtol``
        as judged by ``np.testing.assert_allclose``.
        '''
        for key in tensor:
            arr = tensor[key]
            expected = baseline[key]
            shape_msg = (
                "The output shapes are not equal, the baseline shape is " +
                str(expected.shape) + ', but got ' + str(arr.shape))
            self.assertTrue(expected.shape == arr.shape, shape_msg)

            # The largest absolute error is reported in the failure message.
            max_abs_err = np.amax(abs(expected - arr))
            np.testing.assert_allclose(
                expected,
                arr,
                rtol=rtol,
                atol=atol,
                err_msg='Output has diff, Maximum absolute error: {}'.format(
                    max_abs_err))
146

147 148 149
    @abc.abstractmethod
    def run_test(self, quant=False):
        '''
        Run the scan; must be provided by subclasses.

        The base implementation always raises NotImplementedError.
        '''
        raise NotImplementedError
W
Wilber 已提交
150

151 152
    def generate_op_config(self, ops_config: List[Dict[str,
                                                       Any]]) -> List[OpConfig]:
        '''
        Convert plain op descriptions into ``OpConfig`` objects.

        Each entry of ``ops_config`` must provide the keys 'op_type',
        'op_inputs', 'op_outputs' and 'op_attrs'. The order of the entries is
        preserved in the result.
        '''
        return [
            OpConfig(type=entry['op_type'],
                     inputs=entry['op_inputs'],
                     outputs=entry['op_outputs'],
                     attrs=entry['op_attrs']) for entry in ops_config
        ]

    @abc.abstractmethod
    def ignore_log(self, msg: str):
        '''
        Log ``msg`` at WARNING level, prefixed with "SKIP: ".
        '''
        line = "SKIP: " + msg
        logging.warning(line)

    @abc.abstractmethod
    def fail_log(self, msg: str):
        '''
        Log ``msg`` at ERROR level, prefixed with "FAIL: ".
        '''
        line = "FAIL: " + msg
        logging.error(line)
W
Wilber 已提交
170 171 172 173 174 175 176

    @abc.abstractmethod
    def success_log(self, msg: str):
        '''
        Log ``msg`` at INFO level, prefixed with "SUCCESS: ".
        '''
        line = "SUCCESS: " + msg
        logging.info(line)

    @abc.abstractmethod
    def create_inference_config(self,
                                passes: Optional[List[str]] = None,
                                use_gpu: bool = False,
                                use_mkldnn: bool = False,
                                ir_optim: Optional[bool] = None):
        '''
        Build a predictor config with IR debug enabled and this test's cache dir.

        Args:
            passes: when given, replaces the pass list of the config and is
                also stored on ``self.passes``.
            use_gpu: enable GPU execution.
            use_mkldnn: enable mkldnn.
            ir_optim: when not None, forwarded to ``switch_ir_optim``.

        Returns:
            The configured ``paddle_infer.Config``.
        '''
        config = paddle_infer.Config()
        config.switch_ir_debug(True)
        config.set_optim_cache_dir(self.cache_dir)
        config.disable_glog_info()

        # Optional switches are only touched when explicitly requested.
        if ir_optim is not None:
            config.switch_ir_optim(ir_optim)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        if use_mkldnn:
            config.enable_mkldnn()

        if passes is None:
            return config
        config.pass_builder().set_passes(passes)
        self.passes = passes
        return config


class MkldnnAutoScanTest(AutoScanTest):
198

W
Wilber 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
    def __init__(self, *args, **kwargs):
        '''Forward all arguments to the AutoScanTest initializer.'''
        super().__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        '''
        Compare every sampled predictor config against a CPU baseline.

        For each valid program a baseline is computed with IR optimization
        disabled; each predictor config is then run and its outputs compared
        with that baseline. A failure fails the test unless the combination
        was registered as an ignore case.

        Args:
            quant: convert the fake model into a quantized model before running.
            *args, **kwargs: forwarded to ``sample_program_configs``.
        '''
        status = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # Invalid programs are skipped entirely.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }
                for name, tensor_config in prog_config.inputs.items()
            }

            # baseline: cpu no ir_optim run
            base_config = self.create_inference_config(ir_optim=False)
            logging.info('RUN program_config: ' + str(prog_config))
            base_result = self.run_test_config(model, params, prog_config,
                                               base_config, feed_data)
            self.success_log('RUN_CPU_BASELINE done')

            for pred_config, (
                    atol, rtol) in self.sample_predictor_configs(prog_config):
                # The first matching ignore case decides the outcome.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if not teller(prog_config, pred_config):
                        continue
                    ignore_flag = True
                    if reason != IgnoreReasons.MKLDNN_ACCURACY_ERROR:
                        raise NotImplementedError
                    self.ignore_log("[MKLDNN_ACCURACY_ERROR] " + note + ' ' +
                                    ' vs ' +
                                    self.inference_config_str(pred_config))
                    break

                # Start every predictor run from an empty cache directory.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    pred_result = self.run_test_config(model, params,
                                                       prog_config,
                                                       pred_config, feed_data)
                    self.assert_tensors_near(atol, rtol, pred_result,
                                             base_result)
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    # Failures of ignored combinations do not fail the test.
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log('RUN predictor_config ' +
                                 self.inference_config_str(pred_config) +
                                 ' done')

        self.assertTrue(status)

W
Wilber 已提交
271
    def inference_config_str(self, config) -> str:
        '''
        Describe ``config`` as the string form of a dict with the keys
        'use_mkldnn' and 'use_gpu'.
        '''
        summary = {
            'use_mkldnn': config.mkldnn_enabled(),
            'use_gpu': config.use_gpu(),
        }
        return str(summary)


class PassAutoScanTest(AutoScanTest):
281

W
Wilber 已提交
282 283 284 285 286 287 288
    def __init__(self, *args, **kwargs):
        '''Initialize the base class and start with an empty pass list.'''
        super().__init__(*args, **kwargs)
        # Names of the passes under test; filled in by run_and_statis.
        self.passes = []

    def check_op_version(self):
        '''
        Check version compatibility for every pass under test.

        Passes that are not in ``available_passes_in_framework`` are skipped.
        Each incompatible pass is reported through ``fail_log``.

        Returns:
            True when no checked pass is incompatible, False otherwise.
        '''
        all_compatible = True
        for pass_name in self.passes:
            known = pass_name in self.available_passes_in_framework
            if known and not PassVersionChecker.IsCompatible(pass_name):
                self.fail_log('{} version check failed.'.format(pass_name))
                all_compatible = False
        return all_compatible

296
    def add_ignore_pass_case(self):
        '''
        Hook called by run_and_statis before the programs are generated.

        The default implementation does nothing and returns None.
        '''
        return None

    def assert_op_list(self, op_list_after_fusion):
        '''
        Check the operator types left in the program after the last pass ran.

        The program is read from ``<cache_dir>/<last pass name>.pdmodel``;
        "feed" and "fetch" operators are left out of the comparison.

        Raises:
            ValueError: when no pass is configured or the program file is
                missing.
        '''
        if not self.passes:
            raise ValueError(
                "In PassAutoScan you should give a valid pass name.")

        last_passed_program = os.path.join(self.cache_dir,
                                           self.passes[-1] + ".pdmodel")
        if not os.path.exists(last_passed_program):
            raise ValueError(
                "Cannot find file {}, please make sure that your pass name is correct"
                .format(last_passed_program))

        model_bytes = paddle.static.load_from_file(last_passed_program)
        program = paddle.static.deserialize_program(model_bytes)
        main_block = program.desc.block(0)

        op_types = (main_block.op(idx).type()
                    for idx in range(main_block.op_size()))
        after_op_list = [
            op_type for op_type in op_types
            if op_type not in ["feed", "fetch"]
        ]
        self.assertTrue(
            op_list_after_fusion == after_op_list,
            "Expected operator list after fusion is {}, but now it's {}".format(
                op_list_after_fusion, after_op_list),
        )
W
Wilber 已提交
322

323 324 325 326 327 328
    def run_and_statis(self,
                       quant=False,
                       max_examples=100,
                       reproduce=None,
                       min_success_num=25,
                       max_duration=180,
                       passes=None):
        '''
        Run ``run_test`` on Hypothesis-generated programs and check the statistics.

        Programs are drawn through ``self.sample_program_config(draw)``, which
        the concrete test class has to provide.

        Args:
            quant: forwarded to ``run_test``.
            max_examples: number of examples Hypothesis generates.
            reproduce: optional decorator applied to the generated test
                function before it is run.
            min_success_num: minimum number of programs that must have run
                successfully (approximated, see below).
            max_duration: time budget in seconds; values <= 0 disable the check.
            passes: names of the passes under test; required.

        When HYPOTHESIS_TEST_PROFILE is "dev", ``max_examples`` and
        ``min_success_num`` are multiplied by 10 and the time budget is disabled.

        Raises:
            AssertionError: if ``passes`` is None, too few programs ran
                successfully, or the time budget was exceeded.
        '''
        # Validate before touching global Hypothesis settings. Explicit raises
        # are used throughout because ``assert`` statements are stripped under
        # ``python -O``, which would silently disable these checks.
        if passes is None:
            raise AssertionError(
                "Parameter of passes must be defined in function run_and_statis."
            )

        if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
            max_examples *= 10
            min_success_num *= 10
            # while at ce phase, there's no limit on time
            max_duration = -1
        start_time = time.time()
        settings.register_profile(
            "ci",
            max_examples=max_examples,
            # ``HealthCheck.all()`` is deprecated; listing the enum members is
            # the supported equivalent.
            suppress_health_check=list(hypothesis.HealthCheck),
            deadline=None,
            print_blob=True,
            derandomize=True,
            report_multiple_bugs=False,
        )
        settings.load_profile("ci")
        self.passes = passes

        self.add_ignore_pass_case()

        def program_generator(draw):
            return self.sample_program_config(draw)

        def run_test(prog_config):
            return self.run_test(quant=quant, prog_configs=[prog_config])

        generator = st.composite(program_generator)
        loop_func = given(generator())(run_test)
        if reproduce is not None:
            loop_func = reproduce(loop_func)
        logging.info("Start to running test of {}".format(type(self)))
        loop_func()
        logging.info(
            "===================Statistical Information===================")
        logging.info("Number of Generated Programs: {}".format(
            self.num_ran_programs + self.num_invalid_programs))
        logging.info("Number of Invalid Programs: {}".format(
            self.num_invalid_programs))
        logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
        logging.info("Number of Ignore Tests: {}".format(self.num_ignore_tests))
        # Ignored tests are counted per predictor config, so they are scaled
        # back to an approximate number of programs.
        successful_ran_programs = int(self.num_ran_programs -
                                      self.num_ignore_tests /
                                      max(self.num_predictor_kinds, 1))
        logging.info(
            "Number of successfully ran programs approximately equal to {}".
            format(successful_ran_programs))
        if successful_ran_programs < min_success_num:
            logging.warning(
                "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
            )
            too_few_msg = (
                "At least {} programs need to ran successfully, but now only about {} programs satisfied."
                .format(min_success_num, successful_ran_programs))
            logging.error(too_few_msg)
            raise AssertionError(too_few_msg)
        used_time = time.time() - start_time
        if max_duration > 0 and used_time > max_duration:
            too_slow_msg = (
                "The duration exceeds {} seconds, if this is necessary, try to set a larger number for parameter `max_duration`."
                .format(max_duration))
            logging.error(too_slow_msg)
            raise AssertionError(too_slow_msg)

392
    def run_test(self, quant=False, prog_configs=None):
        '''
        Run every program with and without the passes under test and compare.

        For each valid program and each sampled predictor config, a baseline
        with IR optimization disabled is run on the same device, then the
        predictor config itself. Outputs are compared and, unless the
        combination is ignored, the operator list after fusion is checked.
        Finally the pass versions are checked.

        Args:
            quant: convert the fake model into a quantized model before running.
            prog_configs: iterable of program configs; ``None`` means no programs.
        '''
        status = True

        # The default ``None`` used to be iterated directly, which raised a
        # TypeError; treat it as "no programs" instead.
        if prog_configs is None:
            prog_configs = []

        for prog_config in prog_configs:
            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                self.num_invalid_programs += 1
                continue
            self.num_ran_programs += 1
            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }
                for name, tensor_config in prog_config.inputs.items()
            }

            logging.info('RUN program_config: ' + str(prog_config))
            # Counted per program; only the last program's count is kept.
            self.num_predictor_kinds = 0
            for pred_config, op_list, (
                    atol, rtol) in self.sample_predictor_configs(prog_config):
                self.num_predictor_kinds += 1

                # The first matching ignore case decides the outcome.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if not teller(prog_config, pred_config):
                        continue
                    ignore_flag = True
                    self.num_ignore_tests += 1
                    if reason != IgnoreReasons.PASS_ACCURACY_ERROR:
                        raise NotImplementedError
                    self.ignore_log("[PASS_ACCURACY_ERROR] " + note + ' ' +
                                    ' vs ' +
                                    self.inference_config_str(pred_config))
                    break

                # Start from an empty cache directory.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                # baseline: no ir_optim run
                base_config = self.create_inference_config(
                    ir_optim=False, use_gpu=pred_config.use_gpu())
                try:
                    # baseline
                    base_result = self.run_test_config(model, params,
                                                       prog_config, base_config,
                                                       feed_data)
                    self.success_log('RUN_BASELINE ' +
                                     self.inference_config_str(base_config) +
                                     ' done')

                    # Drop whatever the baseline run left in the cache.
                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    pred_result = self.run_test_config(model, params,
                                                       prog_config, pred_config,
                                                       feed_data)
                    self.assert_tensors_near(atol, rtol, pred_result,
                                             base_result)
                    if not ignore_flag:
                        self.assert_op_list(op_list)

                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    # Failures of ignored combinations do not fail the test.
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log('RUN predictor_config ' +
                                 self.inference_config_str(pred_config) +
                                 ' done')

        status = self.check_op_version() and status
        self.assertTrue(status)

W
Wilber 已提交
475
    def inference_config_str(self, config) -> str:
        '''
        Describe ``config`` as the string form of a dict, for log messages.

        The dict always holds 'use_mkldnn', 'use_gpu' and 'use_trt'. 'passes'
        is added when passes are configured on this test, and 'trt_precision'
        and 'use_dynamic_shape' are added when TensorRT is enabled.
        '''
        dic = {}
        dic['use_mkldnn'] = config.mkldnn_enabled()
        dic['use_gpu'] = config.use_gpu()
        # The condition used to be inverted (``if not self.passes``), so the
        # pass list was only recorded when it was empty.
        if self.passes:
            dic['passes'] = self.passes

        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic['use_trt'] = True
            dic['trt_precision'] = trt_precision
            dic['use_dynamic_shape'] = trt_dynamic_shape
        else:
            dic['use_trt'] = False
        return str(dic)

    def create_trt_inference_config(self) -> paddle_infer.Config:
        '''
        Build a GPU predictor config that uses this test's cache directory.

        glog info output is disabled and IR debug is switched on. The
        TensorRT engine itself is not enabled here.
        '''
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        config.switch_ir_debug()
        return config


class TrtLayerAutoScanTest(AutoScanTest):
505

W
Wilber 已提交
506 507
    class TensorRTParam:
        '''
        Plain holder for the TensorRT subgraph engine parameters.
        '''

        def __init__(self, workspace_size, max_batch_size, min_subgraph_size,
                     precision, use_static, use_calib_mode):
            # Sizes and limits.
            self.workspace_size = workspace_size
            self.max_batch_size = max_batch_size
            self.min_subgraph_size = min_subgraph_size
            # Precision mode and engine switches.
            self.precision = precision
            self.use_static = use_static
            self.use_calib_mode = use_calib_mode

    class DynamicShapeParam:
        '''
        Plain holder for the TensorRT subgraph engine dynamic shape parameters.
        '''

        def __init__(self, min_input_shape, max_input_shape, opt_input_shape,
                     disable_trt_plugin_fp16):
            # Shape information for the minimum, maximum and optimal cases.
            self.min_input_shape = min_input_shape
            self.max_input_shape = max_input_shape
            self.opt_input_shape = opt_input_shape
            self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16

    def __init__(self, *args, **kwargs):
        '''
        Set up default TensorRT parameters, empty dynamic shape info and the
        random source used to skip a share of the cases.
        '''
        super(TrtLayerAutoScanTest, self).__init__(*args, **kwargs)

        default_precision = paddle_infer.PrecisionType.Float32
        self.trt_param = self.TensorRTParam(workspace_size=1024,
                                            max_batch_size=4,
                                            min_subgraph_size=0,
                                            precision=default_precision,
                                            use_static=True,
                                            use_calib_mode=False)
        self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)

        # Fraction of the cases to run, read from TEST_NUM_PERCENT_CASES.
        percent = os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')
        self.num_percent_cases = float(percent)

        # Use a separate random generator for skipping tests, seeded with the
        # week number of the year.
        week_of_year = int(time.strftime("%W"))
        self.skip_rng = np.random.default_rng(week_of_year)
W
Wilber 已提交
547 548 549 550 551 552 553 554 555 556 557 558 559 560 561

    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
        '''
        Build a GPU predictor config, optionally with the TensorRT engine.

        With ``use_trt`` the engine is enabled from ``self.trt_param``.
        Dynamic shape info is set only when ``min_input_shape`` is non-empty
        and the min, max and opt shape dicts share the same keys.
        '''
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        if not use_trt:
            return config

        config.switch_ir_debug()
        trt = self.trt_param
        config.enable_tensorrt_engine(
            max_batch_size=trt.max_batch_size,
            workspace_size=trt.workspace_size,
            min_subgraph_size=trt.min_subgraph_size,
            precision_mode=trt.precision,
            use_static=trt.use_static,
            use_calib_mode=trt.use_calib_mode)

        shapes = self.dynamic_shape
        if shapes.min_input_shape and (shapes.min_input_shape.keys() ==
                                       shapes.max_input_shape.keys() ==
                                       shapes.opt_input_shape.keys()):
            config.set_trt_dynamic_shape_info(shapes.min_input_shape,
                                              shapes.max_input_shape,
                                              shapes.opt_input_shape,
                                              shapes.disable_trt_plugin_fp16)
        return config

Z
zlsh80826 已提交
573 574 575 576 577 578 579 580 581 582
    def assert_tensors_near(self, atol: float, rtol: float,
                            tensor: Dict[str, np.array],
                            baseline: Dict[str, np.array]):
        '''
        Check that every array in ``tensor`` matches its baseline.

        Shapes must be equal, and values must agree within ``atol`` / ``rtol``
        as judged by ``np.testing.assert_allclose``.
        '''
        for key in tensor:
            arr = tensor[key]
            expected = baseline[key]
            shape_msg = (
                'The output shapes are not equal, the baseline shape is ' +
                str(expected.shape) + ', but got ' + str(arr.shape))
            self.assertEqual(expected.shape, arr.shape, shape_msg)
            np.testing.assert_allclose(expected, arr, rtol=rtol, atol=atol)

W
Wilber 已提交
583 584 585 586 587 588 589 590 591 592 593 594
    def assert_op_size(self, trt_engine_num, paddle_op_num):
        '''
        Check how many TensorRT engine ops and other ops the program contains.

        The program is read from
        ``<cache_dir>/transpose_flatten_concat_fuse_pass.pdmodel``. Ops of
        type 'tensorrt_engine' count as TensorRT engines; every other op in
        block 0 counts as a Paddle op.
        '''
        program_path = os.path.join(
            self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel')
        program = paddle.static.deserialize_program(
            paddle.static.load_from_file(program_path))
        main_block = program.desc.block(0)

        op_size = main_block.op_size()
        trt_engine_size = sum(
            main_block.op(idx).type() == 'tensorrt_engine'
            for idx in range(op_size))
        paddle_op_size = op_size - trt_engine_size

        self.assertEqual(
            trt_engine_num, trt_engine_size,
            'Expected trt_engine_num is {}, but got {}!'.format(
                trt_engine_num, trt_engine_size))
        self.assertEqual(
            paddle_op_num, paddle_op_size,
            'Expected paddle_op_num is {}, but got {}!'.format(
                paddle_op_num, paddle_op_size))
W
Wilber 已提交
603

W
Wilber 已提交
604
    def inference_config_str(self, config: paddle_infer.Config) -> str:
        '''
        Describe the TensorRT settings of ``config`` as the string form of a dict.

        When TensorRT is enabled the dict holds 'use_trt', 'trt_precision' and
        'use_dynamic_shape'; otherwise only 'use_trt' (False).
        '''
        enable_trt = config.tensorrt_engine_enabled()
        precision = config.tensorrt_precision_mode()
        dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if not enable_trt:
            return str({'use_trt': False})
        return str({
            'use_trt': True,
            'trt_precision': precision,
            'use_dynamic_shape': dynamic_shape,
        })

617
    def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
        '''
        Run the sampled programs with TensorRT predictor configs and compare.

        A share of the programs is skipped at random according to
        ``num_percent_cases``. Unless ``skip_baseline`` is set, a GPU run
        without TensorRT provides the reference result. Every result is
        compared with the first result collected for the program. Predictor
        configs matched by an ignore case are skipped.

        Args:
            quant: convert the fake model into a quantized model; only Int8
                predictor configs are run when set, and only non-Int8
                configs otherwise.
            skip_baseline: do not run the GPU baseline.
            *args, **kwargs: forwarded to ``sample_program_configs``.
        '''
        all_passes = True
        int8 = paddle_infer.PrecisionType.Int8

        for prog_config in self.sample_program_configs(*args, **kwargs):

            # Draw once per program; run it only when the draw falls below the
            # requested fraction of cases.
            if not (self.skip_rng.random() < self.num_percent_cases):
                continue

            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {
                    'data': tensor_config.data,
                    'lod': tensor_config.lod
                }
                for name, tensor_config in prog_config.inputs.items()
            }

            # results[0] is the reference every later result is compared with.
            results: List[Dict[str, np.ndarray]] = []
            if not skip_baseline:
                #baseline: gpu run
                logging.info('RUN program_config: ' + str(prog_config))
                gpu_config = self.create_inference_config(use_trt=False)
                results.append(
                    self.run_test_config(model, params, prog_config, gpu_config,
                                         feed_data))
                self.success_log('RUN_GPU_BASELINE done')

            for pred_config, nodes_num, threshold in self.sample_predictor_configs(
                    prog_config):

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)

                # A bare float is an absolute tolerance; a pair is (atol, rtol).
                if isinstance(threshold, float):
                    atol, rtol = threshold, 1e-8
                elif isinstance(threshold, (list, tuple)):
                    atol, rtol = threshold[0], threshold[1]
                else:
                    raise NotImplementedError

                # Int8 configs run only for quantized models and vice versa.
                precision = pred_config.tensorrt_precision_mode()
                if precision != int8 and quant:
                    continue
                if precision == int8 and not quant:
                    continue

                # The first matching ignore case decides the outcome.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if not teller(prog_config, pred_config):
                        continue
                    ignore_flag = True
                    if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
                        self.ignore_log('[TRT_NOT_IMPLEMENTED] {} vs {}'.format(
                            note, self.inference_config_str(pred_config)))
                    elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
                        self.ignore_log('[TRT_NOT_SUPPORT] {} vs {}'.format(
                            note, self.inference_config_str(pred_config)))
                    else:
                        raise NotImplementedError
                    break

                if ignore_flag:
                    continue

                try:
                    # Copy the config before the run; the copy is used for the
                    # deserialize test below.
                    pred_config_deserialize = paddle_infer.Config(pred_config)
                    results.append(
                        self.run_test_config(model, params, prog_config,
                                             pred_config, feed_data))
                    self.assert_tensors_near(atol, rtol, results[-1],
                                             results[0])
                    trt_engine_num, paddle_op_num = nodes_num
                    self.assert_op_size(trt_engine_num, paddle_op_num)

                    # deserialize test
                    if trt_engine_num > 0:
                        self.run_test_config(model, params, prog_config,
                                             pred_config_deserialize, feed_data)

                    self.success_log('RUN predictor_config {} done'.format(
                        self.inference_config_str(pred_config)))
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config) +
                        '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e)))
                    all_passes = False

        self.assertTrue(all_passes)
721 722

    # TODO(wilber): kept only for backward compatibility.
    def add_skip_case(
            self, teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
            reason: IgnoreReasons, note: str):
        '''
        Register a case to skip; it is stored in ``ignore_cases``.

        Args:
            teller: predicate called with (program_config, predictor_config);
                a true result marks the combination as skipped.
            reason: why the combination is skipped.
            note: free-form text included in the log line.
        '''
        # The annotation used to be a list literal (``[Callable[...]]``), which
        # is not a type; a single callable is what every caller passes.
        self.ignore_cases.append((teller, reason, note))