auto_scan_test.py 33.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
16
import enum
W
Wilber 已提交
17
import logging
18
import os
W
Wilber 已提交
19
import shutil
20 21 22 23 24 25 26 27
import time
import unittest
from typing import Any, Callable, Dict, List, Optional

import hypothesis
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings
28 29 30 31 32 33
from program_config import (
    OpConfig,
    ProgramConfig,
    create_fake_model,
    create_quant_model,
)
34

35 36 37
import paddle
import paddle.inference as paddle_infer
from paddle.fluid.core import PassVersionChecker
W
Wilber 已提交
38

W
Wilber 已提交
39 40
# Log bare messages at INFO level so that test progress shows up in the output.
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Two Hypothesis profiles that differ only in the number of examples drawn.
# `list(hypothesis.HealthCheck)` replaces the deprecated
# `hypothesis.HealthCheck.all()`; it is the documented equivalent and works on
# old and new Hypothesis releases alike.
settings.register_profile(
    "ci",
    max_examples=100,
    suppress_health_check=list(hypothesis.HealthCheck),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False,
)
settings.register_profile(
    "dev",
    max_examples=1000,
    suppress_health_check=list(hypothesis.HealthCheck),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False,
)
# The smaller "ci" profile is selected when only a fraction of the cases is
# requested (TEST_NUM_PERCENT_CASES < 1) or when it is asked for explicitly
# through HYPOTHESIS_TEST_PROFILE; otherwise "dev" is loaded.
if (
    float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1
    or os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci'
):
    settings.load_profile("ci")
else:
    settings.load_profile("dev")

67

68
class IgnoreReasons(enum.Enum):
    """Reasons for which a sampled test case is ignored instead of failed."""

    # Paddle not support, but trt support, we need to add the feature.
    TRT_NOT_IMPLEMENTED = 0
    # TRT not support.
    TRT_NOT_SUPPORT = 1
    # Accuracy is abnormal after enabling pass.
    PASS_ACCURACY_ERROR = 2
    # Accuracy is abnormal after enabling mkldnn.
    MKLDNN_ACCURACY_ERROR = 3
    # Accuracy is abnormal after enabling cutlass.
    # The value must be unique: enum.Enum turns a repeated value into an alias
    # of the first member that holds it, which would make this reason
    # indistinguishable from MKLDNN_ACCURACY_ERROR.
    CUTLASS_ACCURACY_ERROR = 4
79 80


81 82 83 84
# TODO(wilber): alias kept only for backward compatibility; `SkipReasons` is
# the very same enum object as `IgnoreReasons`.
SkipReasons = IgnoreReasons


85
class AutoScanTest(unittest.TestCase):
    """Common machinery for the auto-scan inference tests.

    The class runs a serialized model through a Paddle inference predictor,
    compares output tensors against a baseline, keeps the list of "ignore"
    rules and offers small logging helpers. Subclasses provide the sampled
    program configs, the sampled predictor configs and ``run_test``.

    NOTE(review): several helpers with a real body are decorated with
    ``abc.abstractmethod``. The class derives from ``unittest.TestCase`` and
    does not declare ``abc.ABCMeta`` as metaclass, so the decorator presumably
    acts as documentation only -- confirm before relying on enforcement.

    Project-defined types in annotations are written as forward references
    (strings) so that they are not evaluated when the class body executes.
    """

    def __init__(self, *args, **kwargs):
        # Fixed seed for numpy's global RNG; static graph mode is switched on
        # before the TestCase itself is initialised.
        np.random.seed(1024)
        paddle.enable_static()
        super().__init__(*args, **kwargs)
        # (teller, reason, note) triples added by add_ignore_check_case().
        self.ignore_cases = []
        # The cache directory lives next to this file and is named after the
        # module of the concrete test class.
        abs_dir = os.path.abspath(os.path.dirname(__file__))
        self.cache_dir = os.path.join(
            abs_dir, str(self.__module__) + '_cache_dir'
        )
        # Union of all pass names reported by the predictor configs run so far.
        self.available_passes_in_framework = set()
        # Statistics counters; they are updated by the subclasses.
        self.num_ran_programs = 0
        self.num_invalid_programs = 0
        self.num_ignore_tests = 0
        self.num_predictor_kinds = 0

    @abc.abstractmethod
    def sample_program_configs(self):
        '''
        Generate all config with the combination of different Input tensor shape and
        different Attr values.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def sample_predictor_configs(self):
        '''
        Generate the predictor configs to test; must be provided by subclasses.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def add_ignore_check_case(
        self,
        teller: "Callable[[ProgramConfig, paddle_infer.Config], bool]",
        reason: "IgnoreReasons",
        note: str,
    ):
        '''
        Register an ignore rule.

        Args:
            teller: predicate called with (program config, predictor config);
                a true result marks the combination as ignored.
            reason: the IgnoreReasons member describing why it is ignored.
            note: free text that is included in the ignore log message.
        '''
        self.ignore_cases.append((teller, reason, note))

    def is_program_valid(self, program_config: "ProgramConfig") -> bool:
        '''
        Tell whether a sampled program should be run; the default accepts all.
        '''
        return True

    def run_test_config(
        self, model, params, prog_config, pred_config, feed_data
    ) -> Dict[str, np.ndarray]:
        '''
        Test a single case.

        Loads ``model``/``params`` into ``pred_config``, creates a predictor,
        feeds every input named in ``prog_config.inputs`` from ``feed_data``
        (each entry holds 'data' and 'lod'), runs the predictor and returns
        the outputs copied to CPU.

        The returned dict is keyed by the names in ``prog_config.outputs``,
        which are paired positionally with the predictor's output names.
        '''
        pred_config.set_model_buffer(model, len(model), params, len(params))
        predictor = paddle_infer.create_predictor(pred_config)
        # Remember every pass this config knows about (a new set is bound, the
        # previous one is not mutated).
        self.available_passes_in_framework = (
            self.available_passes_in_framework
            | set(pred_config.pass_builder().all_passes())
        )
        for name in prog_config.inputs:
            feed = feed_data[name]
            input_tensor = predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(feed['data'])
            if feed['lod'] is not None:
                input_tensor.set_lod(feed['lod'])
        predictor.run()
        return {
            out_name: predictor.get_output_handle(o_name).copy_to_cpu()
            for out_name, o_name in zip(
                prog_config.outputs, predictor.get_output_names()
            )
        }

    @staticmethod
    def _max_abs_error(expected, actual) -> float:
        '''
        Return the largest element-wise absolute difference (0.0 if empty).
        '''
        expected = np.asarray(expected)
        actual = np.asarray(actual)
        # Promote to at least float64 before subtracting: `-` is not defined
        # for bool arrays and wraps around for unsigned integers.
        common = np.result_type(expected.dtype, actual.dtype, np.float64)
        diff = np.abs(expected.astype(common) - actual.astype(common))
        # np.amax raises on zero-size input, so empty outputs are reported
        # with an error of 0.0.
        return float(np.amax(diff)) if diff.size else 0.0

    @abc.abstractmethod
    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.ndarray],
        baseline: Dict[str, np.ndarray],
    ):
        '''
        Assert that every array in ``tensor`` matches ``baseline[key]``.

        Shapes must be equal, and values must agree within ``atol``/``rtol``
        as checked by ``np.testing.assert_allclose``. The maximum absolute
        error is added to the failure message.

        Raises:
            AssertionError: on a shape or value mismatch.
            KeyError: if ``baseline`` lacks one of the keys of ``tensor``.
        '''
        for key, arr in tensor.items():
            expected = baseline[key]
            self.assertTrue(
                expected.shape == arr.shape,
                "The output shapes are not equal, the baseline shape is "
                + str(expected.shape)
                + ', but got '
                + str(arr.shape),
            )
            np.testing.assert_allclose(
                expected,
                arr,
                rtol=rtol,
                atol=atol,
                err_msg='Output has diff, Maximum absolute error: {}'.format(
                    self._max_abs_error(expected, arr)
                ),
            )

    @abc.abstractmethod
    def run_test(self, quant=False):
        '''
        Run the whole scan; must be provided by subclasses.
        '''
        raise NotImplementedError

    def generate_op_config(
        self, ops_config: List[Dict[str, Any]]
    ) -> "List[OpConfig]":
        '''
        Build one OpConfig per dict in ``ops_config``.

        Every dict must hold 'op_type', 'op_inputs', 'op_outputs' and
        'op_attrs'; 'outputs_dtype' is forwarded only when it is present.
        '''
        ops = []
        for op_config in ops_config:
            op_kwargs = {
                'type': op_config['op_type'],
                'inputs': op_config['op_inputs'],
                'outputs': op_config['op_outputs'],
                'attrs': op_config['op_attrs'],
            }
            if 'outputs_dtype' in op_config:
                op_kwargs['outputs_dtype'] = op_config['outputs_dtype']
            ops.append(OpConfig(**op_kwargs))
        return ops

    @abc.abstractmethod
    def ignore_log(self, msg: str):
        '''Log ``msg`` as a skipped case (warning level).'''
        logging.warning("SKIP: " + msg)

    @abc.abstractmethod
    def fail_log(self, msg: str):
        '''Log ``msg`` as a failed case (error level).'''
        logging.error("FAIL: " + msg)

    @abc.abstractmethod
    def success_log(self, msg: str):
        '''Log ``msg`` as a successful case (info level).'''
        logging.info("SUCCESS: " + msg)

    @abc.abstractmethod
    def create_inference_config(
        self,
        passes: Optional[List[str]] = None,
        use_gpu: bool = False,
        use_mkldnn: bool = False,
        ir_optim: Optional[bool] = None,
    ):
        '''
        Create a paddle_infer.Config with IR debugging enabled and the optim
        cache directory set to ``self.cache_dir``.

        Args:
            passes: when given, replaces the passes of the pass builder and is
                also stored on ``self.passes``.
            use_gpu: enable GPU (called as ``enable_use_gpu(100, 0)``).
            use_mkldnn: enable mkldnn.
            ir_optim: when not None, forwarded to ``switch_ir_optim``.
        '''
        config = paddle_infer.Config()
        config.switch_ir_debug(True)
        config.set_optim_cache_dir(self.cache_dir)
        config.disable_glog_info()
        if ir_optim is not None:
            config.switch_ir_optim(ir_optim)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        if use_mkldnn:
            config.enable_mkldnn()
        if passes is not None:
            config.pass_builder().set_passes(passes)
            self.passes = passes
        return config


class MkldnnAutoScanTest(AutoScanTest):
    """Auto-scan test that checks sampled predictor configs against a CPU
    baseline run with IR optimization switched off."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        """Run every sampled program under every sampled predictor config.

        Extra positional/keyword arguments are forwarded to
        ``sample_program_configs``. A failing predictor config marks the whole
        test as failed unless an ignore rule matched it; the final verdict is
        asserted once all programs have been processed.
        """
        all_ok = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # Invalid programs are skipped without being counted.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            # Baseline: CPU run without IR optimization.
            base_config = self.create_inference_config(ir_optim=False)
            logging.info('RUN program_config: ' + str(prog_config))
            baseline_outputs = self.run_test_config(
                model, params, prog_config, base_config, feed_data
            )
            self.success_log('RUN_CPU_BASELINE done')

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                prog_config
            ):
                # The first ignore rule whose teller accepts this combination
                # decides; later rules are not consulted.
                matched_rule = next(
                    (
                        rule
                        for rule in self.ignore_cases
                        if rule[0](prog_config, pred_config)
                    ),
                    None,
                )
                ignore_flag = matched_rule is not None
                if ignore_flag:
                    if matched_rule[1] == IgnoreReasons.MKLDNN_ACCURACY_ERROR:
                        self.ignore_log(
                            "[MKLDNN_ACCURACY_ERROR] "
                            + matched_rule[2]
                            + ' '
                            + ' vs '
                            + self.inference_config_str(pred_config)
                        )
                    else:
                        # Only mkldnn accuracy errors are handled here.
                        raise NotImplementedError

                # Start each predictor run from a fresh, empty cache directory.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    outputs = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(
                        atol, rtol, outputs, baseline_outputs
                    )
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    # Failures of ignored combinations do not fail the test.
                    if not ignore_flag:
                        all_ok = False
                else:
                    self.success_log(
                        'RUN predictor_config '
                        + self.inference_config_str(pred_config)
                        + ' done'
                    )

        self.assertTrue(all_ok)

    def inference_config_str(self, config) -> str:
        """Return a printable summary of the mkldnn/gpu switches of ``config``."""
        return str(
            {
                'use_mkldnn': config.mkldnn_enabled(),
                'use_gpu': config.use_gpu(),
            }
        )


class PassAutoScanTest(AutoScanTest):
    """Auto-scan test for IR passes.

    Each sampled program is run twice per predictor config: once without IR
    optimization (baseline) and once with the config under test. The outputs
    are compared and, unless the case is ignored, the operator list of the
    program dumped by the last pass is checked.

    ``run_and_statis`` expects subclasses to provide
    ``sample_program_config(draw)`` and ``sample_predictor_configs`` yielding
    ``(pred_config, op_list, (atol, rtol))`` triples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Names of the passes under test; set by run_and_statis().
        self.passes = []

    def check_op_version(self):
        """Check the version compatibility of every pass under test.

        Passes that the framework did not report as available are skipped.

        Returns:
            bool: False if at least one available pass failed the check.
        """
        status = True
        for pass_name in self.passes:
            if pass_name not in self.available_passes_in_framework:
                continue
            if not PassVersionChecker.IsCompatible(pass_name):
                self.fail_log('{} version check failed.'.format(pass_name))
                status = False
        return status

    def add_ignore_pass_case(self):
        """Hook called by run_and_statis(); the default registers nothing."""
        return

    def assert_op_list(self, op_list_after_fusion):
        """Assert the operator types of the program dumped by the last pass.

        The program is read from ``<cache_dir>/<last pass name>.pdmodel``;
        "feed" and "fetch" operators are left out of the comparison.

        Raises:
            ValueError: if no pass was configured or the dump does not exist.
            AssertionError: if the operator lists differ.
        """
        if not self.passes:
            raise ValueError(
                "In PassAutoScan you should give a valid pass name."
            )
        last_passed_program = os.path.join(
            self.cache_dir, self.passes[-1] + ".pdmodel"
        )
        if not os.path.exists(last_passed_program):
            raise ValueError(
                "Cannot find file {}, please make sure that your pass name is correct".format(
                    last_passed_program
                )
            )
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        all_op_types = (
            main_block.op(i).type() for i in range(main_block.op_size())
        )
        after_op_list = [
            op_type
            for op_type in all_op_types
            if op_type not in ("feed", "fetch")
        ]
        self.assertTrue(
            op_list_after_fusion == after_op_list,
            "Expected operator list after fusion is {}, but now it's {}".format(
                op_list_after_fusion, after_op_list
            ),
        )

    def run_and_statis(
        self,
        quant=False,
        max_examples=100,
        reproduce=None,
        min_success_num=25,
        max_duration=180,
        passes=None,
    ):
        """Drive ``run_test`` with Hypothesis and check the run statistics.

        Args:
            quant: forwarded to ``run_test``.
            max_examples: Hypothesis example budget (x10 in the "dev" profile).
            reproduce: optional decorator applied to the Hypothesis test.
            min_success_num: minimum number of programs that must have run
                without being ignored (x10 in the "dev" profile).
            max_duration: time limit in seconds; not enforced when <= 0, which
                is what the "dev" profile sets.
            passes: names of the passes under test; required.

        Raises:
            AssertionError: if ``passes`` is None, too few programs ran
                successfully, or the time limit was exceeded.
        """
        if os.getenv('HYPOTHESIS_TEST_PROFILE', 'ci') == "dev":
            max_examples *= 10
            min_success_num *= 10
            # while at ce phase, there's no limit on time
            max_duration = -1
        start_time = time.time()
        # Re-register "ci" with the requested example budget.
        # `list(hypothesis.HealthCheck)` is the supported replacement for the
        # deprecated `hypothesis.HealthCheck.all()`.
        settings.register_profile(
            "ci",
            max_examples=max_examples,
            suppress_health_check=list(hypothesis.HealthCheck),
            deadline=None,
            print_blob=True,
            derandomize=True,
            report_multiple_bugs=False,
        )
        settings.load_profile("ci")
        # Raised explicitly (instead of `assert`) so that the check is still
        # performed when Python runs with -O.
        if passes is None:
            raise AssertionError(
                "Parameter of passes must be defined in function run_and_statis."
            )
        self.passes = passes

        self.add_ignore_pass_case()

        def program_generator(draw):
            return self.sample_program_config(draw)

        def run_test(prog_config):
            return self.run_test(quant=quant, prog_configs=[prog_config])

        generator = st.composite(program_generator)
        loop_func = given(generator())(run_test)
        if reproduce is not None:
            loop_func = reproduce(loop_func)
        logging.info("Start to running test of {}".format(type(self)))
        loop_func()
        logging.info(
            "===================Statistical Information==================="
        )
        logging.info(
            "Number of Generated Programs: {}".format(
                self.num_ran_programs + self.num_invalid_programs
            )
        )
        logging.info(
            "Number of Invalid Programs: {}".format(self.num_invalid_programs)
        )
        logging.info("Number of Ran Programs: {}".format(self.num_ran_programs))
        logging.info("Number of Ignore Tests: {}".format(self.num_ignore_tests))
        # Ignored tests are counted per predictor config, so they are scaled
        # back to "programs" before being subtracted; max(..., 1) guards the
        # division when no predictor config was seen.
        successful_ran_programs = int(
            self.num_ran_programs
            - self.num_ignore_tests / max(self.num_predictor_kinds, 1)
        )
        logging.info(
            "Number of successfully ran programs approximately equal to {}".format(
                successful_ran_programs
            )
        )
        if successful_ran_programs < min_success_num:
            logging.warning(
                "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
            )
            too_few_msg = "At least {} programs need to ran successfully, but now only about {} programs satisfied.".format(
                min_success_num, successful_ran_programs
            )
            logging.error(too_few_msg)
            raise AssertionError(too_few_msg)
        used_time = time.time() - start_time
        if max_duration > 0 and used_time > max_duration:
            too_slow_msg = "The duration exceeds {} seconds, if this is necessary, try to set a larger number for parameter `max_duration`.".format(
                max_duration
            )
            logging.error(too_slow_msg)
            raise AssertionError(too_slow_msg)

    def run_test(self, quant=False, prog_configs=None):
        """Run the given programs under every sampled predictor config.

        Args:
            quant: convert the fake model with ``create_quant_model`` first.
            prog_configs: iterable of program configs. It must be provided;
                the ``None`` default is not iterable.

        Updates the ``num_*`` statistics counters and asserts at the end that
        no non-ignored case failed and that the op version check passed.
        """
        status = True

        for prog_config in prog_configs:
            # if program is invalid, we should skip that cases.
            if not self.is_program_valid(prog_config):
                self.num_invalid_programs += 1
                continue
            self.num_ran_programs += 1
            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            logging.info('RUN program_config: ' + str(prog_config))
            # Counted afresh for every program.
            self.num_predictor_kinds = 0
            for (
                pred_config,
                op_list,
                (atol, rtol),
            ) in self.sample_predictor_configs(prog_config):
                self.num_predictor_kinds += 1

                # The first matching ignore rule decides.
                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if teller(prog_config, pred_config):
                        ignore_flag = True
                        self.num_ignore_tests += 1
                        if reason == IgnoreReasons.PASS_ACCURACY_ERROR:
                            self.ignore_log(
                                "[PASS_ACCURACY_ERROR] "
                                + note
                                + ' '
                                + ' vs '
                                + self.inference_config_str(pred_config)
                            )
                        else:
                            raise NotImplementedError
                        break

                # Start from a fresh, empty cache directory.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                # baseline: no ir_optim run
                base_config = self.create_inference_config(
                    ir_optim=False, use_gpu=pred_config.use_gpu()
                )
                try:
                    # baseline
                    base_result = self.run_test_config(
                        model, params, prog_config, base_config, feed_data
                    )
                    self.success_log(
                        'RUN_BASELINE '
                        + self.inference_config_str(base_config)
                        + ' done'
                    )

                    # Drop whatever the baseline run left in the cache.
                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(
                        atol, rtol, pred_result, base_result
                    )
                    # The fused operator list is only checked for cases that
                    # are not ignored.
                    if not ignore_flag:
                        self.assert_op_list(op_list)

                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log(
                    'RUN predictor_config '
                    + self.inference_config_str(pred_config)
                    + ' done'
                )

        status = self.check_op_version() and status
        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        """Return a printable summary of ``config``.

        The summary holds the mkldnn/gpu switches, the passes under test (when
        there are any) and the TensorRT settings.
        """
        dic = {
            'use_mkldnn': config.mkldnn_enabled(),
            'use_gpu': config.use_gpu(),
        }
        # Record the passes when there are some to report (the condition used
        # to be inverted, which only ever recorded an empty list).
        if self.passes:
            dic['passes'] = self.passes

        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic['use_trt'] = True
            dic['trt_precision'] = trt_precision
            dic['use_dynamic_shape'] = trt_dynamic_shape
        else:
            dic['use_trt'] = False
        return str(dic)

    def create_trt_inference_config(self) -> "paddle_infer.Config":
        """Create a GPU config with IR debugging on and the optim cache
        directory set to ``self.cache_dir``; TensorRT itself is not enabled
        here."""
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        config.switch_ir_debug()
        return config


class TrtLayerAutoScanTest(AutoScanTest):
    class TensorRTParam:
        '''
        Plain holder for the TensorRT subgraph engine parameters.
        '''

        def __init__(
            self,
            workspace_size,
            max_batch_size,
            min_subgraph_size,
            precision,
            use_static,
            use_calib_mode,
        ):
            # Every argument is stored unchanged under its own name.
            (
                self.workspace_size,
                self.max_batch_size,
                self.min_subgraph_size,
                self.precision,
                self.use_static,
                self.use_calib_mode,
            ) = (
                workspace_size,
                max_batch_size,
                min_subgraph_size,
                precision,
                use_static,
                use_calib_mode,
            )

    class DynamicShapeParam:
        '''
        Plain holder for the TensorRT subgraph engine dynamic shape parameters.
        '''

        def __init__(
            self,
            min_input_shape,
            max_input_shape,
            opt_input_shape,
            disable_trt_plugin_fp16,
        ):
            # Every argument is stored unchanged under its own name.
            (
                self.min_input_shape,
                self.max_input_shape,
                self.opt_input_shape,
                self.disable_trt_plugin_fp16,
            ) = (
                min_input_shape,
                max_input_shape,
                opt_input_shape,
                disable_trt_plugin_fp16,
            )

    def __init__(self, *args, **kwargs):
        """Set up the default TensorRT parameters, empty dynamic shape
        information and the random generator used for skipping cases."""
        super().__init__(*args, **kwargs)
        default_trt_settings = {
            'workspace_size': 1024,
            'max_batch_size': 4,
            'min_subgraph_size': 0,
            'precision': paddle_infer.PrecisionType.Float32,
            'use_static': True,
            'use_calib_mode': False,
        }
        self.trt_param = self.TensorRTParam(**default_trt_settings)
        self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
        # Fraction of the sampled cases to run, read from the environment.
        percent_cases = os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')
        self.num_percent_cases = float(percent_cases)

        # Use a separate random generator for skipping tests; it is seeded
        # with the current week number of the year ("%W").
        week_number = int(time.strftime("%W"))
        self.skip_rng = np.random.default_rng(week_number)
W
Wilber 已提交
653 654 655 656 657 658 659 660 661 662 663 664 665 666

    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
        """Create a GPU inference config, optionally with TensorRT enabled.

        With ``use_trt`` the engine is configured from ``self.trt_param``.
        Dynamic shape information is set only when ``min_input_shape`` is
        non-empty and the min/max/opt shape dicts have the same keys.
        """
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        if not use_trt:
            return config

        config.switch_ir_debug()
        trt = self.trt_param
        config.enable_tensorrt_engine(
            max_batch_size=trt.max_batch_size,
            workspace_size=trt.workspace_size,
            min_subgraph_size=trt.min_subgraph_size,
            precision_mode=trt.precision,
            use_static=trt.use_static,
            use_calib_mode=trt.use_calib_mode,
        )

        shapes = self.dynamic_shape
        if shapes.min_input_shape:
            min_keys = shapes.min_input_shape.keys()
            max_keys = shapes.max_input_shape.keys()
            if min_keys == max_keys and max_keys == shapes.opt_input_shape.keys():
                config.set_trt_dynamic_shape_info(
                    shapes.min_input_shape,
                    shapes.max_input_shape,
                    shapes.opt_input_shape,
                    shapes.disable_trt_plugin_fp16,
                )
        return config

682 683 684 685 686 687 688
    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.array],
        baseline: Dict[str, np.array],
    ):
        """Assert that each array in ``tensor`` has the shape of
        ``baseline[key]`` and matches it within ``atol``/``rtol``."""
        for key, arr in tensor.items():
            expected = baseline[key]
            shape_mismatch_msg = (
                'The output shapes are not equal, the baseline shape is '
                + str(expected.shape)
                + ', but got '
                + str(arr.shape)
            )
            self.assertEqual(expected.shape, arr.shape, shape_mismatch_msg)
            np.testing.assert_allclose(expected, arr, rtol=rtol, atol=atol)

W
Wilber 已提交
700 701
    def assert_op_size(self, trt_engine_num, paddle_op_num):
        """Assert the operator counts of the dumped, optimized program.

        The program is read from
        ``<cache_dir>/transpose_flatten_concat_fuse_pass.pdmodel``. Operators
        of type 'tensorrt_engine' are compared with ``trt_engine_num``, all
        remaining operators with ``paddle_op_num``.
        """
        program_path = os.path.join(
            self.cache_dir, 'transpose_flatten_concat_fuse_pass.pdmodel'
        )
        program = paddle.static.deserialize_program(
            paddle.static.load_from_file(program_path)
        )
        block = program.desc.block(0)
        total_ops = block.op_size()
        trt_engine_size = sum(
            1
            for index in range(total_ops)
            if block.op(index).type() == 'tensorrt_engine'
        )
        paddle_op_size = total_ops - trt_engine_size

        expectations = (
            ('Expected trt_engine_num is {}, but got {}!', trt_engine_num, trt_engine_size),
            ('Expected paddle_op_num is {}, but got {}!', paddle_op_num, paddle_op_size),
        )
        # The engine count is checked first, then the count of the other ops.
        for template, wanted, found in expectations:
            self.assertEqual(wanted, found, template.format(wanted, found))
W
Wilber 已提交
727

W
Wilber 已提交
728
    def inference_config_str(self, config: paddle_infer.Config) -> str:
        """Return a printable summary of the TensorRT settings of ``config``.

        Precision and dynamic shape are reported only when TensorRT is on.
        """
        trt_enabled = config.tensorrt_engine_enabled()
        precision_mode = config.tensorrt_precision_mode()
        dynamic_shape_enabled = config.tensorrt_dynamic_shape_enabled()
        if not trt_enabled:
            return str({'use_trt': False})
        return str(
            {
                'use_trt': True,
                'trt_precision': precision_mode,
                'use_dynamic_shape': dynamic_shape_enabled,
            }
        )

741
    def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
        """Run each sampled program under each sampled predictor config.

        For every program yielded by ``sample_program_configs`` a fake model
        is built, optionally run once with ``use_trt=False`` as the baseline,
        and then run under every config yielded by
        ``sample_predictor_configs``. Each predictor output is compared with
        ``results[0]`` through ``assert_tensors_near`` and the expected node
        counts are checked through ``assert_op_size``.

        Args:
            quant: When truthy, the model is passed through
                ``create_quant_model`` and only predictor configs whose
                TensorRT precision mode is Int8 are run. When falsy, the Int8
                configs are the ones skipped.
            skip_baseline: When truthy, no ``use_trt=False`` baseline is run,
                so ``results[0]`` is the first predictor output that was
                recorded instead of a separate baseline.
            *args: Forwarded to ``sample_program_configs``.
            **kwargs: Forwarded to ``sample_program_configs``.

        Raises:
            NotImplementedError: If a threshold is neither a float nor a
                list/tuple, or if a matching ignore case carries a reason
                other than ``TRT_NOT_IMPLEMENTED`` / ``TRT_NOT_SUPPORT``.

        Errors raised while running a predictor config are logged with
        ``fail_log`` and reported once, at the end, through ``assertTrue``.
        Errors raised by the baseline run are not caught here.
        """
        all_passes = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # Sub-sample the programs: one draw per program, taken before the
            # validity check; the case is kept only when the draw is below
            # num_percent_cases.
            keep_case = self.skip_rng.random() < self.num_percent_cases
            if not keep_case:
                continue

            # Invalid programs are skipped rather than reported as failures.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {
                name: {'data': tensor_config.data, 'lod': tensor_config.lod}
                for name, tensor_config in prog_config.inputs.items()
            }

            results: List[Dict[str, np.ndarray]] = []
            if not skip_baseline:
                # Baseline: the same program run with TensorRT disabled.
                logging.info('RUN program_config: ' + str(prog_config))
                gpu_config = self.create_inference_config(use_trt=False)
                results.append(
                    self.run_test_config(
                        model, params, prog_config, gpu_config, feed_data
                    )
                )
                self.success_log('RUN_GPU_BASELINE done')

            for (
                pred_config,
                nodes_num,
                threshold,
            ) in self.sample_predictor_configs(prog_config):
                # Remove any cache directory left behind by a previous run.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)

                # A bare float is the absolute tolerance; a pair is
                # (atol, rtol).
                if isinstance(threshold, float):
                    atol, rtol = threshold, 1e-8
                elif isinstance(threshold, (list, tuple)):
                    atol, rtol = threshold[0], threshold[1]
                else:
                    raise NotImplementedError(
                        'unsupported threshold type: {}'.format(
                            type(threshold).__name__
                        )
                    )

                # Int8 predictor configs are run only for quantized models,
                # and quantized models only under Int8 configs.
                is_int8 = (
                    pred_config.tensorrt_precision_mode()
                    == paddle_infer.PrecisionType.Int8
                )
                if is_int8 != bool(quant):
                    continue

                ignore_flag = False
                for teller, reason, note in self.ignore_cases:
                    if not teller(prog_config, pred_config):
                        continue
                    # Only the first matching ignore case is reported.
                    ignore_flag = True
                    if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
                        tag = '[TRT_NOT_IMPLEMENTED]'
                    elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
                        tag = '[TRT_NOT_SUPPORT]'
                    else:
                        raise NotImplementedError(
                            'unsupported ignore reason: {}'.format(reason)
                        )
                    self.ignore_log(
                        '{} {} vs {}'.format(
                            tag, note, self.inference_config_str(pred_config)
                        )
                    )
                    break

                if ignore_flag:
                    continue

                try:
                    # Copy the config before the first run; the copy is used
                    # for the deserialize run below.
                    pred_config_deserialize = paddle_infer.Config(pred_config)
                    results.append(
                        self.run_test_config(
                            model, params, prog_config, pred_config, feed_data
                        )
                    )
                    self.assert_tensors_near(
                        atol, rtol, results[-1], results[0]
                    )
                    trt_engine_num, paddle_op_num = nodes_num
                    self.assert_op_size(trt_engine_num, paddle_op_num)

                    # deserialize test
                    if trt_engine_num > 0:
                        self.run_test_config(
                            model,
                            params,
                            prog_config,
                            pred_config_deserialize,
                            feed_data,
                        )

                    self.success_log(
                        'RUN predictor_config {} done'.format(
                            self.inference_config_str(pred_config)
                        )
                    )
                except Exception as e:
                    # Deliberately broad: record the failure and keep going so
                    # that every config is reported before the final assert.
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
                    )
                    all_passes = False

        self.assertTrue(all_passes)
873 874

    # TODO(wilber): kept only for backward compatibility.
875 876 877 878 879 880
    def add_skip_case(
        self,
        teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
        reason: IgnoreReasons,
        note: str,
    ) -> None:
        """Register an ignore case on this test.

        The ``(teller, reason, note)`` triple is appended to
        ``self.ignore_cases``.

        Args:
            teller: Predicate taking a program config and a predictor config;
                a truthy result marks that combination as ignored.
            reason: The ``IgnoreReasons`` member explaining the ignore.
            note: Free-form text stored alongside the reason.
        """
        self.ignore_cases.append((teller, reason, note))
882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974


class CutlassAutoScanTest(AutoScanTest):
    """Auto-scan test that checks predictor configs against a GPU baseline
    created with ``ir_optim=False``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        """Run each valid sampled program under each sampled predictor config.

        Every predictor output is compared with the baseline output through
        ``assert_tensors_near``. A failure fails the test unless the
        (program, predictor) pair matched an ignore case. ``quant`` is
        accepted but not read by this method; ``*args``/``**kwargs`` are
        forwarded to ``sample_program_configs``.

        Raises:
            NotImplementedError: If a matching ignore case carries a reason
                other than ``CUTLASS_ACCURACY_ERROR``.
        """
        passed = True

        for program in self.sample_program_configs(*args, **kwargs):
            # Invalid programs are skipped outright.
            if not self.is_program_valid(program):
                continue

            model, params = create_fake_model(program)
            feed_data = {
                name: {'data': tensor.data, 'lod': tensor.lod}
                for name, tensor in program.inputs.items()
            }
            outputs: List[Dict[str, np.ndarray]] = []

            # Baseline: GPU run with IR optimization switched off.
            baseline_config = self.create_inference_config(
                ir_optim=False, use_gpu=True
            )
            logging.info('RUN program_config: ' + str(program))
            baseline_output = self.run_test_config(
                model, params, program, baseline_config, feed_data
            )
            outputs.append(baseline_output)
            self.success_log('RUN_GPU_BASELINE done')

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                program
            ):
                # Report the first matching ignore case. The config is still
                # run afterwards; a match only makes a failure non-fatal.
                tolerated = False
                for case in self.ignore_cases:
                    if not case[0](program, pred_config):
                        continue
                    if case[1] == IgnoreReasons.CUTLASS_ACCURACY_ERROR:
                        self.ignore_log(
                            "[CUTLASS_ACCURACY_ERROR] "
                            + case[2]
                            + ' '
                            + ' vs '
                            + self.inference_config_str(pred_config)
                        )
                        tolerated = True
                        break
                    raise NotImplementedError

                # Recreate the cache directory so each run starts from an
                # empty one.
                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    outputs.append(
                        self.run_test_config(
                            model, params, program, pred_config, feed_data
                        )
                    )
                    self.assert_tensors_near(
                        atol, rtol, outputs[-1], outputs[0]
                    )
                except Exception as e:
                    self.fail_log(
                        self.inference_config_str(pred_config)
                        + f'\033[1;31m \nERROR INFO: {e!s}\033[0m'
                    )
                    if not tolerated:
                        passed = False
                else:
                    config_text = self.inference_config_str(pred_config)
                    self.success_log(
                        f'RUN predictor_config {config_text} done'
                    )

        self.assertTrue(passed)

    def inference_config_str(self, config) -> str:
        """Return the string form of a dict holding ``config.use_gpu()``."""
        return str({'use_gpu': config.use_gpu()})