auto_scan_test.py 34.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
16
import enum
W
Wilber 已提交
17
import logging
18
import os
W
Wilber 已提交
19
import shutil
20 21 22 23 24 25 26 27
import time
import unittest
from typing import Any, Callable, Dict, List, Optional

import hypothesis
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings
28 29 30 31 32 33
from program_config import (
    OpConfig,
    ProgramConfig,
    create_fake_model,
    create_quant_model,
)
34

35 36 37
import paddle
import paddle.inference as paddle_infer
from paddle.fluid.core import PassVersionChecker
W
Wilber 已提交
38

W
Wilber 已提交
39 40
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Register two hypothesis profiles that differ only in how many examples they
# draw: "ci" is the short one, "dev" explores ten times more inputs.
for _profile_name, _profile_examples in (("ci", 100), ("dev", 1000)):
    settings.register_profile(
        _profile_name,
        max_examples=_profile_examples,
        suppress_health_check=hypothesis.HealthCheck.all(),
        deadline=None,
        print_blob=True,
        derandomize=True,
        report_multiple_bugs=False,
    )
del _profile_name, _profile_examples

# Use the lighter "ci" profile when only a fraction of the cases is requested
# or when it is selected explicitly; otherwise fall back to "dev".
settings.load_profile(
    "ci"
    if float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1
    or os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci'
    else "dev"
)

67

68
@enum.unique
class IgnoreReasons(enum.Enum):
    """Reasons for which a failing generated case may be tolerated.

    Every member needs its own value: ``enum.Enum`` silently turns a member
    whose value repeats an earlier one into an alias of that member, making
    the two reasons indistinguishable in comparisons and logs.
    ``@enum.unique`` rejects such duplicates when the module is imported.
    """

    # Paddle does not support this yet but TRT does; the feature needs adding.
    TRT_NOT_IMPLEMENTED = 0
    # TRT does not support this case.
    TRT_NOT_SUPPORT = 1
    # Accuracy is abnormal after enabling the pass.
    PASS_ACCURACY_ERROR = 2
    # Accuracy is abnormal after enabling mkldnn.
    MKLDNN_ACCURACY_ERROR = 3
    # Accuracy is abnormal after enabling cutlass.
    CUTLASS_ACCURACY_ERROR = 4


# TODO(wilber): just for backward compatible
SkipReasons = IgnoreReasons


85
class AutoScanTest(unittest.TestCase):
    """Base class for randomized inference tests.

    Subclasses supply program configurations and predictor configurations.
    This class provides the shared machinery: building predictors from
    in-memory models, running them, comparing outputs and logging results.
    """

    def __init__(self, *args, **kwargs):
        np.random.seed(1024)
        paddle.enable_static()
        super().__init__(*args, **kwargs)
        # (teller, reason, note) triples registered via add_ignore_check_case.
        self.ignore_cases = []
        abs_dir = os.path.abspath(os.path.dirname(__file__))
        # Per-module directory handed to Config.set_optim_cache_dir().
        self.cache_dir = os.path.join(
            abs_dir, str(self.__module__) + '_cache_dir'
        )
        # Union of the pass names reported by every predictor built so far.
        self.available_passes_in_framework = set()
        # Statistics counters; left at zero here and updated by subclasses.
        self.num_ran_programs = 0
        self.num_invalid_programs = 0
        self.num_ignore_tests = 0
        self.num_predictor_kinds = 0

    @abc.abstractmethod
    def sample_program_configs(self):
        '''
        Generate all configs from combinations of different input tensor
        shapes and different attribute values.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def sample_predictor_configs(self):
        raise NotImplementedError

    def add_ignore_check_case(
        self,
        teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
        reason: IgnoreReasons,
        note: str,
    ):
        """Register a case whose failures must not fail the test.

        Args:
            teller: Predicate called with ``(program_config, predictor_config)``;
                a truthy result marks that combination as ignored.
            reason: Why the case is ignored.
            note: Free-form text included in the skip log message.
        """
        self.ignore_cases.append((teller, reason, note))

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return whether ``program_config`` should be run; all are by default."""
        return True

    def run_test_config(
        self, model, params, prog_config, pred_config, feed_data
    ) -> Dict[str, np.ndarray]:
        '''
        Test a single case.

        Loads ``model``/``params`` into ``pred_config``, feeds every input
        named in ``prog_config.inputs`` (plus its LoD when one is given), runs
        the predictor and returns the outputs keyed by the names in
        ``prog_config.outputs``. They are paired positionally with the
        predictor's output names.
        '''
        pred_config.set_model_buffer(model, len(model), params, len(params))
        predictor = paddle_infer.create_predictor(pred_config)
        self.available_passes_in_framework = (
            self.available_passes_in_framework
            | set(pred_config.pass_builder().all_passes())
        )
        for name in prog_config.inputs:
            input_tensor = predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(feed_data[name]['data'])
            if feed_data[name]['lod'] is not None:
                input_tensor.set_lod(feed_data[name]['lod'])
        predictor.run()
        return {
            out_name: predictor.get_output_handle(o_name).copy_to_cpu()
            for out_name, o_name in zip(
                prog_config.outputs, predictor.get_output_names()
            )
        }

    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.ndarray],
        baseline: Dict[str, np.ndarray],
    ):
        """Assert every output in ``tensor`` matches ``baseline`` in shape and
        values within ``atol``/``rtol`` (see ``np.testing.assert_allclose``)."""
        for key, arr in tensor.items():
            self.assertTrue(
                baseline[key].shape == arr.shape,
                "The output shapes are not equal, the baseline shape is "
                + str(baseline[key].shape)
                + ', but got '
                + str(arr.shape),
            )
            np.testing.assert_allclose(
                baseline[key],
                arr,
                rtol=rtol,
                atol=atol,
                err_msg='Output has diff, Maximum absolute error: {}'.format(
                    self._max_abs_error(baseline[key], arr)
                ),
            )

    @staticmethod
    def _max_abs_error(expected, actual):
        """Return max(|expected - actual|) for use in an error message.

        Non-complex operands are cast to float64 first. Otherwise numpy raises
        TypeError when subtracting boolean arrays, and unsigned integer
        differences wrap around. Empty arrays report 0.0 instead of making
        ``np.amax`` raise.
        """
        expected = np.asarray(expected)
        actual = np.asarray(actual)
        if expected.size == 0 or actual.size == 0:
            return 0.0
        if not (np.iscomplexobj(expected) or np.iscomplexobj(actual)):
            expected = expected.astype(np.float64)
            actual = actual.astype(np.float64)
        return np.amax(np.abs(expected - actual))

    @abc.abstractmethod
    def run_test(self, quant=False):
        raise NotImplementedError

    def generate_op_config(
        self, ops_config: List[Dict[str, Any]]
    ) -> List[OpConfig]:
        """Build ``OpConfig`` objects from plain dictionaries.

        Each dictionary provides ``op_type``, ``op_inputs``, ``op_outputs``
        and ``op_attrs``. ``outputs_dtype`` is forwarded only when present, so
        ``OpConfig``'s own default applies otherwise.
        """
        ops = []
        for op_config in ops_config:
            optional_kwargs = {}
            if 'outputs_dtype' in op_config:
                optional_kwargs['outputs_dtype'] = op_config['outputs_dtype']
            ops.append(
                OpConfig(
                    type=op_config['op_type'],
                    inputs=op_config['op_inputs'],
                    outputs=op_config['op_outputs'],
                    attrs=op_config['op_attrs'],
                    **optional_kwargs,
                )
            )
        return ops

    def ignore_log(self, msg: str):
        logging.warning("SKIP: " + msg)

    def fail_log(self, msg: str):
        logging.error("FAIL: " + msg)

    def success_log(self, msg: str):
        logging.info("SUCCESS: " + msg)

    def create_inference_config(
        self,
        passes: Optional[List[str]] = None,
        use_gpu: bool = False,
        use_mkldnn: bool = False,
        use_xpu: bool = False,
        ir_optim: Optional[bool] = None,
    ):
        """Create a predictor config with IR debug output into ``cache_dir``.

        Args:
            passes: When given, replaces the pass list and is stored on
                ``self.passes``.
            use_gpu / use_mkldnn / use_xpu: Enable the matching backend.
            ir_optim: When not None, passed to ``switch_ir_optim``.
        """
        config = paddle_infer.Config()
        config.switch_ir_debug(True)
        config.set_optim_cache_dir(self.cache_dir)
        config.disable_glog_info()
        if ir_optim is not None:
            config.switch_ir_optim(ir_optim)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        if use_mkldnn:
            config.enable_mkldnn()
        if use_xpu:
            config.enable_xpu()
        if passes is not None:
            config.pass_builder().set_passes(passes)
            self.passes = passes
        return config


class MkldnnAutoScanTest(AutoScanTest):
    """Auto-scan test that checks predictor configs against a CPU baseline.

    For every valid program, a CPU predictor with IR optimization disabled
    produces the reference outputs. Each config yielded by
    ``sample_predictor_configs`` must reproduce them within its
    ``(atol, rtol)`` tolerance unless a registered ignore case matches.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _collect_feed_data(prog_config):
        """Map each input name to its data and LoD, as run_test_config expects."""
        return {
            input_name: {'data': spec.data, 'lod': spec.lod}
            for input_name, spec in prog_config.inputs.items()
        }

    def _reset_cache_dir(self):
        """Remove ``self.cache_dir`` if present and create it again, empty."""
        if os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        if not os.path.exists(self.cache_dir):
            os.mkdir(self.cache_dir)

    def _match_ignore_case(self, prog_config, pred_config):
        """Return True (after logging) if a registered ignore case matches.

        Only the first matching case counts. Raises NotImplementedError when
        that case uses a reason other than MKLDNN_ACCURACY_ERROR.
        """
        for teller, reason, note in self.ignore_cases:
            if not teller(prog_config, pred_config):
                continue
            if reason != IgnoreReasons.MKLDNN_ACCURACY_ERROR:
                raise NotImplementedError
            self.ignore_log(
                "[MKLDNN_ACCURACY_ERROR] "
                + note
                + ' '
                + ' vs '
                + self.inference_config_str(pred_config)
            )
            return True
        return False

    def run_test(self, quant=False, *args, **kwargs):
        all_passed = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # Invalid programs are skipped without being counted.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)
            feed_data = self._collect_feed_data(prog_config)

            # Reference outputs: CPU predictor with IR optimization disabled.
            base_config = self.create_inference_config(ir_optim=False)
            logging.info('RUN program_config: ' + str(prog_config))
            baseline = self.run_test_config(
                model, params, prog_config, base_config, feed_data
            )
            self.success_log('RUN_CPU_BASELINE done')

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                prog_config
            ):
                ignored = self._match_ignore_case(prog_config, pred_config)
                self._reset_cache_dir()

                try:
                    outputs = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(atol, rtol, outputs, baseline)
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    # Failures of ignored cases are logged but tolerated.
                    if not ignored:
                        all_passed = False
                else:
                    self.success_log(
                        'RUN predictor_config '
                        + self.inference_config_str(pred_config)
                        + ' done'
                    )

        self.assertTrue(all_passed)

    def inference_config_str(self, config) -> str:
        """Describe the backend switches of ``config`` as a dict string."""
        summary = {
            'use_mkldnn': config.mkldnn_enabled(),
            'use_gpu': config.use_gpu(),
        }
        return str(summary)


class PassAutoScanTest(AutoScanTest):
    """Auto-scan test for IR passes.

    Programs are drawn with hypothesis through ``sample_program_config(draw)``.
    Each one runs once without IR optimization as a baseline and once per
    config yielded by ``sample_predictor_configs``. Outputs must agree within
    tolerance, and the program written by the last pass in ``self.passes``
    must contain the expected operator list.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Names of the passes under test; set by run_and_statis().
        self.passes = []

    def check_op_version(self):
        """Return False if any available pass under test fails its version
        check. Passes never reported by a predictor are not checked."""
        status = True
        for pass_name in self.passes:
            if pass_name not in self.available_passes_in_framework:
                continue
            if not PassVersionChecker.IsCompatible(pass_name):
                self.fail_log('{} version check failed.'.format(pass_name))
                status = False
        return status

    def add_ignore_pass_case(self):
        """Hook for subclasses to register ignore cases; no-op by default."""
        return

    def assert_op_list(self, op_list_after_fusion):
        """Assert the program saved after the last pass has exactly the
        operators ``op_list_after_fusion`` (feed/fetch excluded), in order.

        Raises:
            ValueError: If no pass is set or the saved program is missing.
        """
        if not self.passes:
            raise ValueError(
                "In PassAutoScan you should give a valid pass name."
            )
        last_passed_program = os.path.join(
            self.cache_dir, self.passes[-1] + ".pdmodel"
        )
        if not os.path.exists(last_passed_program):
            raise ValueError(
                "Cannot find file {}, please make sure that your pass name is correct".format(
                    last_passed_program
                )
            )
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        all_op_types = (
            main_block.op(i).type() for i in range(main_block.op_size())
        )
        after_op_list = [
            op_type
            for op_type in all_op_types
            if op_type not in ("feed", "fetch")
        ]
        self.assertTrue(
            op_list_after_fusion == after_op_list,
            "Expected operator list after fusion is {}, but now it's {}".format(
                op_list_after_fusion, after_op_list
            ),
        )

    def run_and_statis(
        self,
        quant=False,
        max_examples=100,
        reproduce=None,
        min_success_num=25,
        max_duration=180,
        passes=None,
    ):
        """Draw programs with hypothesis, run them and check the statistics.

        Args:
            quant: Convert each fake model with ``create_quant_model``.
            max_examples: Hypothesis example count (x10 when
                HYPOTHESIS_TEST_PROFILE is "dev").
            reproduce: Optional decorator applied to the hypothesis test
                function before it is called.
            min_success_num: Minimum approximate number of non-ignored
                programs (x10 when HYPOTHESIS_TEST_PROFILE is "dev").
            max_duration: Time limit in seconds; not enforced when <= 0, and
                set to -1 when HYPOTHESIS_TEST_PROFILE is "dev".
            passes: Names of the passes under test; required.

        Raises:
            AssertionError: If ``passes`` is None, too few programs ran
                successfully, or the time limit was exceeded.
        """
        if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
            max_examples *= 10
            min_success_num *= 10
            # while at ce phase, there's no limit on time
            max_duration = -1
        start_time = time.time()
        settings.register_profile(
            "ci",
            max_examples=max_examples,
            suppress_health_check=hypothesis.HealthCheck.all(),
            deadline=None,
            print_blob=True,
            derandomize=True,
            report_multiple_bugs=False,
        )
        settings.load_profile("ci")
        # Explicit failure instead of ``assert``: asserts are removed under -O.
        if passes is None:
            self.fail(
                "Parameter of passes must be defined in function run_and_statis."
            )
        self.passes = passes

        self.add_ignore_pass_case()

        def program_generator(draw):
            return self.sample_program_config(draw)

        def run_test(prog_config):
            return self.run_test(quant=quant, prog_configs=[prog_config])

        generator = st.composite(program_generator)
        loop_func = given(generator())(run_test)
        if reproduce is not None:
            loop_func = reproduce(loop_func)
        logging.info("Start to running test of {}".format(type(self)))
        loop_func()

        successful_ran_programs = self._log_statistics()
        if successful_ran_programs < min_success_num:
            logging.warning(
                "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
            )
            message = "At least {} programs need to ran successfully, but now only about {} programs satisfied.".format(
                min_success_num, successful_ran_programs
            )
            logging.error(message)
            self.fail(message)
        used_time = time.time() - start_time
        if max_duration > 0 and used_time > max_duration:
            message = "The duration exceeds {} seconds, if this is necessary, try to set a larger number for parameter `max_duration`.".format(
                max_duration
            )
            logging.error(message)
            self.fail(message)

    def _log_statistics(self):
        """Log the run counters and return the approximate success count."""
        logging.info(
            "===================Statistical Information==================="
        )
        logging.info(
            "Number of Generated Programs: {}".format(
                self.num_ran_programs + self.num_invalid_programs
            )
        )
        logging.info(
            "Number of Invalid Programs: {}".format(self.num_invalid_programs)
        )
        logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
        logging.info("Number of Ignore Tests: {}".format(self.num_ignore_tests))
        # Ignored tests are counted once per ignored predictor config, so
        # divide by the number of predictor kinds to express them in programs.
        successful_ran_programs = int(
            self.num_ran_programs
            - self.num_ignore_tests / max(self.num_predictor_kinds, 1)
        )
        logging.info(
            "Number of successfully ran programs approximately equal to {}".format(
                successful_ran_programs
            )
        )
        return successful_ran_programs

    def run_test(self, quant=False, prog_configs=None):
        """Run each program in ``prog_configs`` against every predictor config.

        For each config, a baseline (IR optimization off, same GPU setting) is
        compared with the config's outputs. Unless an ignore case matches,
        the resulting operator list is also checked. Fails if any
        non-ignored config failed or a pass version check failed.
        """
        status = True

        for prog_config in prog_configs:
            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                self.num_invalid_programs += 1
                continue
            self.num_ran_programs += 1
            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            logging.info('RUN program_config: ' + str(prog_config))
            self.num_predictor_kinds = 0
            for (
                pred_config,
                op_list,
                (atol, rtol),
            ) in self.sample_predictor_configs(prog_config):
                self.num_predictor_kinds += 1

                # Only the first matching ignore case is applied.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if teller(prog_config, pred_config):
                        ignore_flag = True
                        self.num_ignore_tests += 1
                        if reason == IgnoreReasons.PASS_ACCURACY_ERROR:
                            self.ignore_log(
                                "[PASS_ACCURACY_ERROR] "
                                + note
                                + ' '
                                + ' vs '
                                + self.inference_config_str(pred_config)
                            )
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                # baseline: no ir_optim run
                base_config = self.create_inference_config(
                    ir_optim=False, use_gpu=pred_config.use_gpu()
                )
                try:
                    base_result = self.run_test_config(
                        model, params, prog_config, base_config, feed_data
                    )
                    self.success_log(
                        'RUN_BASELINE '
                        + self.inference_config_str(base_config)
                        + ' done'
                    )

                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(
                        atol, rtol, pred_result, base_result
                    )
                    if not ignore_flag:
                        self.assert_op_list(op_list)

                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log(
                    'RUN predictor_config '
                    + self.inference_config_str(pred_config)
                    + ' done'
                )

        status = self.check_op_version() and status
        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        """Describe backend, pass and TensorRT settings of ``config``."""
        dic = {}
        dic['use_mkldnn'] = config.mkldnn_enabled()
        dic['use_gpu'] = config.use_gpu()
        dic['use_xpu'] = config.use_xpu()
        # Report the passes under test when there are any (the previous
        # condition was inverted and only ever logged an empty list).
        if self.passes:
            dic['passes'] = self.passes

        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic['use_trt'] = True
            dic['trt_precision'] = trt_precision
            dic['use_dynamic_shape'] = trt_dynamic_shape
        else:
            dic['use_trt'] = False
        return str(dic)

    def create_trt_inference_config(self) -> paddle_infer.Config:
        """Create a GPU config with IR debug output into ``cache_dir``."""
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        config.switch_ir_debug()
        return config


class TrtLayerAutoScanTest(AutoScanTest):
    class TensorRTParam:
        '''
        TensorRT subgraph engine parameters.

        The fields are forwarded to ``Config.enable_tensorrt_engine`` by
        ``create_inference_config``; ``precision`` is passed as
        ``precision_mode``.
        '''

        def __init__(
            self,
            workspace_size,
            max_batch_size,
            min_subgraph_size,
            precision,
            use_static,
            use_calib_mode,
        ):
            # Engine sizing.
            self.workspace_size, self.max_batch_size, self.min_subgraph_size = (
                workspace_size,
                max_batch_size,
                min_subgraph_size,
            )
            # Precision and serialization / calibration switches.
            self.precision, self.use_static, self.use_calib_mode = (
                precision,
                use_static,
                use_calib_mode,
            )

    class DynamicShapeParam:
        '''
        Prepare TensorRT subgraph engine dynamic shape parameters.

        The three shape maps are passed to ``Config.set_trt_dynamic_shape_info``
        by ``create_inference_config``, which only does so when
        ``min_input_shape`` is non-empty and all three maps share the same keys.
        '''

        def __init__(
            self,
            min_input_shape,
            max_input_shape,
            opt_input_shape,
            disable_trt_plugin_fp16,
        ):
            # Shape maps, stored exactly as given.
            self.min_input_shape, self.max_input_shape, self.opt_input_shape = (
                min_input_shape,
                max_input_shape,
                opt_input_shape,
            )
            self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default TensorRT engine settings (FP32, serialized engine, no calibration).
        self.trt_param = self.TensorRTParam(
            workspace_size=1024,
            max_batch_size=4,
            min_subgraph_size=0,
            precision=paddle_infer.PrecisionType.Float32,
            use_static=True,
            use_calib_mode=False,
        )
        # Empty shape maps: create_inference_config skips dynamic shape setup
        # while min_input_shape is empty.
        empty_min, empty_max, empty_opt = {}, {}, {}
        self.dynamic_shape = self.DynamicShapeParam(
            empty_min, empty_max, empty_opt, False
        )
        # Fraction of generated programs to run; 1.0 runs all of them.
        percent_setting = os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')
        self.num_percent_cases = float(percent_setting)

        # Use a separate random generator for skipping tests. It is seeded
        # with the week number of the year ("%W").
        week_number = int(time.strftime("%W"))
        self.skip_rng = np.random.default_rng(week_number)

    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
        """Create a GPU predictor config, optionally with TensorRT enabled.

        With ``use_trt`` the engine is configured from ``self.trt_param``.
        Dynamic shapes from ``self.dynamic_shape`` are applied when its min
        shape map is non-empty and the min/max/opt maps share the same keys.
        """
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        if not use_trt:
            return config

        config.switch_ir_debug()
        engine = self.trt_param
        config.enable_tensorrt_engine(
            max_batch_size=engine.max_batch_size,
            workspace_size=engine.workspace_size,
            min_subgraph_size=engine.min_subgraph_size,
            precision_mode=engine.precision,
            use_static=engine.use_static,
            use_calib_mode=engine.use_calib_mode,
        )

        shapes = self.dynamic_shape
        # Key comparison only happens when min_input_shape is non-empty.
        if shapes.min_input_shape and (
            shapes.min_input_shape.keys()
            == shapes.max_input_shape.keys()
            == shapes.opt_input_shape.keys()
        ):
            config.set_trt_dynamic_shape_info(
                shapes.min_input_shape,
                shapes.max_input_shape,
                shapes.opt_input_shape,
                shapes.disable_trt_plugin_fp16,
            )
        return config

687 688 689 690 691 692 693
    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.array],
        baseline: Dict[str, np.array],
    ):
        """Assert each output in ``tensor`` equals ``baseline`` in shape and is
        within ``atol``/``rtol`` of it elementwise."""
        for key, actual in tensor.items():
            expected = baseline[key]
            shape_message = (
                'The output shapes are not equal, the baseline shape is '
                + str(expected.shape)
                + ', but got '
                + str(actual.shape)
            )
            self.assertEqual(expected.shape, actual.shape, shape_message)
            np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)

W
Wilber 已提交
705 706
    def assert_op_size(self, trt_engine_num, paddle_op_num):
        """Check how many ops of the optimized program became TRT engines.

        Reads ``transpose_flatten_concat_fuse_pass.pdmodel`` from the cache
        dir. It counts ``tensorrt_engine`` ops in block 0 and treats every
        other op (feed/fetch included) as a Paddle op.
        """
        program_path = os.path.join(
            self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel'
        )
        program = paddle.static.deserialize_program(
            paddle.static.load_from_file(program_path)
        )
        block = program.desc.block(0)
        total_ops = block.op_size()
        trt_engine_size = sum(
            1
            for index in range(total_ops)
            if block.op(index).type() == 'tensorrt_engine'
        )
        paddle_op_size = total_ops - trt_engine_size

        self.assertEqual(
            trt_engine_num,
            trt_engine_size,
            'Expected trt_engine_num is {}, but got {}!'.format(
                trt_engine_num, trt_engine_size
            ),
        )
        self.assertEqual(
            paddle_op_num,
            paddle_op_size,
            'Expected paddle_op_num is {}, but got {}!'.format(
                paddle_op_num, paddle_op_size
            ),
        )

W
Wilber 已提交
733
    def inference_config_str(self, config: paddle_infer.Config) -> str:
        """Describe the TensorRT settings of ``config`` as a dict string.

        Precision and dynamic-shape flags are included only when TensorRT is
        enabled.
        """
        trt_on = config.tensorrt_engine_enabled()
        precision = config.tensorrt_precision_mode()
        dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        summary = {'use_trt': True if trt_on else False}
        if trt_on:
            summary['trt_precision'] = precision
            summary['use_dynamic_shape'] = dynamic_shape
        return str(summary)

746
    def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
        """Run every sampled program through each sampled predictor config.

        For each program config yielded by ``sample_program_configs`` (after
        random skipping and the ``is_program_valid`` filter) a fake model is
        built, and quantized when ``quant`` is true. Unless ``skip_baseline``
        is set, the program is first run with a non-TensorRT inference config
        and that output becomes the reference; with ``skip_baseline`` the
        first predictor output produced becomes the reference instead.

        Every predictor config from ``sample_predictor_configs`` is then run
        and compared against the reference with its ``(atol, rtol)``
        threshold. Its expected node counts are checked with
        ``assert_op_size``. When the expected TensorRT engine count is
        positive, the program is run again with a config copied before the
        first run (the "deserialize test").

        Args:
            quant: build a quantized model and run only Int8 predictor
                configs; when false, Int8 predictor configs are skipped.
            skip_baseline: do not run the non-TensorRT baseline.
            *args, **kwargs: forwarded to ``sample_program_configs``.

        Raises:
            NotImplementedError: a threshold is not a float, list or tuple,
                or a matching ignore case has an unsupported reason.
            AssertionError: via ``assertTrue`` if any predictor config failed.
        """
        all_passes = True
        # Log tag for every ignore reason this runner understands; any other
        # reason on a matching ignore case is an error.
        ignore_tags = {
            IgnoreReasons.TRT_NOT_IMPLEMENTED: 'TRT_NOT_IMPLEMENTED',
            IgnoreReasons.TRT_NOT_SUPPORT: 'TRT_NOT_SUPPORT',
        }

        def random_to_skip():
            # Keep a program only when the draw falls below num_percent_cases.
            return not (self.skip_rng.random() < self.num_percent_cases)

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # One skip_rng draw per sampled program, taken before the
            # validity check, so the skip pattern depends only on sample order.
            if random_to_skip():
                continue

            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            # Output every predictor result is compared against. It stays None
            # until the baseline (or, with skip_baseline, the first
            # successful predictor run) produces one.
            reference: Optional[Dict[str, np.ndarray]] = None

            if not skip_baseline:
                # baseline: gpu run
                logging.info('RUN program_config: ' + str(prog_config))
                gpu_config = self.create_inference_config(use_trt=False)
                reference = self.run_test_config(
                    model, params, prog_config, gpu_config, feed_data
                )
                self.success_log('RUN_GPU_BASELINE done')

            for (
                pred_config,
                nodes_num,
                threshold,
            ) in self.sample_predictor_configs(prog_config):
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)

                # A bare float is an absolute tolerance; a pair is (atol, rtol).
                if isinstance(threshold, float):
                    atol, rtol = threshold, 1e-8
                elif isinstance(threshold, (list, tuple)):
                    atol, rtol = threshold[0], threshold[1]
                else:
                    raise NotImplementedError(
                        'unsupported threshold type: {}'.format(
                            type(threshold)
                        )
                    )

                # Quantized models run only Int8 configs, and float models
                # never run Int8 configs.
                is_int8 = (
                    pred_config.tensorrt_precision_mode()
                    == paddle_infer.PrecisionType.Int8
                )
                if is_int8 != bool(quant):
                    continue

                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if teller(prog_config, pred_config):
                        tag = ignore_tags.get(reason)
                        if tag is None:
                            raise NotImplementedError(
                                'unsupported ignore reason: {}'.format(reason)
                            )
                        self.ignore_log(
                            '[{}] {} vs {}'.format(
                                tag, note, self.inference_config_str(pred_config)
                            )
                        )
                        ignore_flag = True
                        break
                if ignore_flag:
                    continue

                try:
                    # Copy of the config taken before it is run; used below
                    # for the deserialize test.
                    pred_config_deserialize = paddle_infer.Config(pred_config)
                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    if reference is None:
                        reference = pred_result
                    self.assert_tensors_near(atol, rtol, pred_result, reference)
                    trt_engine_num, paddle_op_num = nodes_num
                    self.assert_op_size(trt_engine_num, paddle_op_num)

                    # deserialize test: re-run with the pre-run config copy.
                    # NOTE(review): presumably this loads the engine cached in
                    # cache_dir by the run above; its output is not compared.
                    if trt_engine_num > 0:
                        self.run_test_config(
                            model,
                            params,
                            prog_config,
                            pred_config_deserialize,
                            feed_data,
                        )

                    self.success_log(
                        'RUN predictor_config {} done'.format(
                            self.inference_config_str(pred_config)
                        )
                    )
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    all_passes = False

        self.assertTrue(all_passes)
878 879

    # TODO(wilber): just for backward compatible
880 881 882 883 884 885
    def add_skip_case(
        self,
        teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
        reason: IgnoreReasons,
        note: str,
    ) -> None:
        """Register a predicate that marks (program, predictor) pairs as ignored.

        Kept for backward compatibility. It appends ``(teller, reason, note)``
        to ``self.ignore_cases``, which ``run_test`` consults before running
        each predictor config.

        Args:
            teller: called as ``teller(prog_config, pred_config)``; a truthy
                result marks that pair as ignored.
            reason: why the pair is ignored; selects the tag in the ignore log.
            note: free-form explanation included in the ignore log.
        """
        self.ignore_cases.append((teller, reason, note))
887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979


class CutlassAutoScanTest(AutoScanTest):
    """Auto-scan test comparing sampled predictor configs to a GPU baseline.

    The baseline for each program is a GPU run with IR optimization disabled.
    Matching ignore cases are still executed and logged, but their failures
    do not fail the test.
    """

    def run_test(self, quant=False, *args, **kwargs):
        """Run every valid sampled program against each sampled predictor.

        Args:
            quant: accepted for signature compatibility with other runners;
                it is not used here.
            *args, **kwargs: forwarded to ``sample_program_configs``.

        Raises:
            NotImplementedError: a matching ignore case has a reason other
                than ``IgnoreReasons.CUTLASS_ACCURACY_ERROR``.
            AssertionError: via ``assertTrue`` if any non-ignored predictor
                config failed.
        """
        status = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            # baseline: gpu no ir_optim run
            base_config = self.create_inference_config(
                ir_optim=False, use_gpu=True
            )
            logging.info('RUN program_config: ' + str(prog_config))
            baseline_result = self.run_test_config(
                model, params, prog_config, base_config, feed_data
            )
            self.success_log('RUN_GPU_BASELINE done')

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                prog_config
            ):
                # An ignored config is still run; the flag only keeps its
                # failure from failing the whole test.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if teller(prog_config, pred_config):
                        ignore_flag = True
                        if reason == IgnoreReasons.CUTLASS_ACCURACY_ERROR:
                            self.ignore_log(
                                '[CUTLASS_ACCURACY_ERROR] {} vs {}'.format(
                                    note, self.inference_config_str(pred_config)
                                )
                            )
                        else:
                            raise NotImplementedError(
                                'unsupported ignore reason: {}'.format(reason)
                            )
                        break

                # Start every predictor run from an empty cache directory.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                os.makedirs(self.cache_dir, exist_ok=True)

                try:
                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(
                        atol, rtol, pred_result, baseline_result
                    )
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    if not ignore_flag:
                        status = False
                else:
                    self.success_log(
                        'RUN predictor_config '
                        + self.inference_config_str(pred_config)
                        + ' done'
                    )

        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        """Describe ``config`` for logs; only its ``use_gpu`` flag is reported."""
        dic = {}
        enable_gpu = config.use_gpu()
        dic['use_gpu'] = enable_gpu
        return str(dic)