# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import enum
import logging
import os
import shutil
import time
import unittest
from typing import Any, Callable, Dict, List, Optional

import hypothesis
import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings
from program_config import (
    OpConfig,
    ProgramConfig,
    create_fake_model,
    create_quant_model,
)

import paddle
import paddle.inference as paddle_infer
from paddle.fluid.core import PassVersionChecker

LOGLEVEL = os.environ.get("PADDLE_TEST_LOGLEVEL", "INFO").upper()
logging.basicConfig(level=LOGLEVEL, format="%(message)s")
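# Per-case SUCCESS/SKIP details are logged at DEBUG level; a hypothetical
# invocation such as PADDLE_TEST_LOGLEVEL=DEBUG makes them visible.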

settings.register_profile(
    "ci",
    max_examples=100,
    suppress_health_check=hypothesis.HealthCheck.all(),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False,
)
settings.register_profile(
    "dev",
    max_examples=1000,
    suppress_health_check=hypothesis.HealthCheck.all(),
    deadline=None,
    print_blob=True,
    derandomize=True,
    report_multiple_bugs=False,
)
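# The Hypothesis profile is chosen from the environment. Example (hypothetical
# command, not part of this module):
#   HYPOTHESIS_TEST_PROFILE=ci python test_some_pass.py
# selects the lighter "ci" profile (100 examples); the default is "dev" (1000).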
if (
    float(os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")) < 1
    or os.getenv("HYPOTHESIS_TEST_PROFILE", "dev") == "ci"
):
    settings.load_profile("ci")
else:
    settings.load_profile("dev")


class IgnoreReasons(enum.Enum):
    # Paddle-TRT does not implement the op yet, though TRT supports it;
    # the feature needs to be added.
    TRT_NOT_IMPLEMENTED = 0
    # TRT does not support the op.
    TRT_NOT_SUPPORT = 1
    # Accuracy is abnormal after enabling the pass.
    PASS_ACCURACY_ERROR = 2
    # Accuracy is abnormal after enabling mkldnn.
    MKLDNN_ACCURACY_ERROR = 3
    # Accuracy is abnormal after enabling cutlass.
    # Note: this must differ from MKLDNN_ACCURACY_ERROR; a duplicate value
    # would make it an enum alias and the two reasons indistinguishable.
    CUTLASS_ACCURACY_ERROR = 4


# TODO(wilber): kept only for backward compatibility.
SkipReasons = IgnoreReasons


class AutoScanTest(unittest.TestCase):
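    """Base class for auto-scan inference tests.

    Subclasses generate ProgramConfigs (driven by Hypothesis), pair them with
    predictor configs, run each combination, and compare outputs against a
    baseline run.
    """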
    def __init__(self, *args, **kwargs):
        np.random.seed(1024)
        paddle.enable_static()
        super().__init__(*args, **kwargs)
        self.ignore_cases = []
        abs_dir = os.path.abspath(os.path.dirname(__file__))
        self.cache_dir = os.path.join(
            abs_dir, str(self.__module__) + '_cache_dir'
        )
        self.available_passes_in_framework = set()
        self.num_ran_programs = 0
        self.num_invalid_programs = 0
        self.num_ignore_tests = 0
        self.num_predictor_kinds = 0

    @abc.abstractmethod
    def sample_program_configs(self):
        """
        Generate program configs by combining different input tensor shapes
        and attribute values.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sample_predictor_configs(self):
        """
        Yield predictor configs, each paired with tolerance information.
        The exact tuple layout depends on the subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def add_ignore_check_case(
        self,
        teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
        reason: IgnoreReasons,
        note: str,
    ):
        self.ignore_cases.append((teller, reason, note))

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Subclasses may override this to filter out invalid program configs."""
        return True

    def run_test_config(
        self, model, params, prog_config, pred_config, feed_data
    ) -> Dict[str, np.ndarray]:
        """
        Run a single case and return its output tensors by name.
        """
        pred_config.set_model_buffer(model, len(model), params, len(params))
        predictor = paddle_infer.create_predictor(pred_config)
        self.available_passes_in_framework = (
            self.available_passes_in_framework
            | set(pred_config.pass_builder().all_passes())
        )
        for name, _ in prog_config.inputs.items():
            input_tensor = predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(feed_data[name]["data"])
            if feed_data[name]["lod"] is not None:
                input_tensor.set_lod(feed_data[name]["lod"])
        predictor.run()
        result = {}
        for out_name, o_name in zip(
            prog_config.outputs, predictor.get_output_names()
        ):
            result[out_name] = predictor.get_output_handle(o_name).copy_to_cpu()
        return result

    @abc.abstractmethod
    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.ndarray],
        baseline: Dict[str, np.ndarray],
    ):
        for key, arr in tensor.items():
            self.assertTrue(
                baseline[key].shape == arr.shape,
                f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {arr.shape}",
            )
            diff = abs(baseline[key] - arr)
            np.testing.assert_allclose(
                baseline[key],
                arr,
                rtol=rtol,
                atol=atol,
                err_msg=f"Output has diff, maximum absolute error: {np.amax(diff)}",
            )

    @abc.abstractmethod
    def run_test(self, quant=False):
        raise NotImplementedError

    def generate_op_config(
        self, ops_config: List[Dict[str, Any]]
    ) -> List[OpConfig]:
        ops = []
        for op_config in ops_config:
            # 'outputs_dtype' is optional in the raw op config dicts.
            extra = (
                {'outputs_dtype': op_config['outputs_dtype']}
                if 'outputs_dtype' in op_config
                else {}
            )
            ops.append(
                OpConfig(
                    type=op_config['op_type'],
                    inputs=op_config['op_inputs'],
                    outputs=op_config['op_outputs'],
                    attrs=op_config['op_attrs'],
                    **extra,
                )
            )
        return ops

    @abc.abstractmethod
    def ignore_log(self, msg: str):
        logging.debug(f"SKIP: {msg}")

    @abc.abstractmethod
    def fail_log(self, msg: str):
        logging.error(f"FAIL: {msg}")

    @abc.abstractmethod
    def info_log(self, msg: str):
        logging.info(f"INFO: {msg}")

    @abc.abstractmethod
    def success_log(self, msg: str):
        logging.debug(f"SUCCESS: {msg}")

    @abc.abstractmethod
    def create_inference_config(
        self,
        passes: Optional[List[str]] = None,
        use_gpu: bool = False,
        use_mkldnn: bool = False,
        use_xpu: bool = False,
        ir_optim: Optional[bool] = None,
    ):
        config = paddle_infer.Config()
        config.switch_ir_debug(True)
        config.set_optim_cache_dir(self.cache_dir)
        config.disable_glog_info()
        if ir_optim is not None:
            config.switch_ir_optim(ir_optim)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        if use_mkldnn:
            config.enable_mkldnn()
        if use_xpu:
            config.enable_xpu()
        if passes is not None:
            config.pass_builder().set_passes(passes)
            self.passes = passes
        return config


class MkldnnAutoScanTest(AutoScanTest):
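    """Auto-scan test comparing MKLDNN-enabled predictor configs against a
    CPU baseline that runs without IR optimization."""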
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        status = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # If the program is invalid, skip the case.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    "data": tensor_config.data,
                    "lod": tensor_config.lod,
                }
            results: List[Dict[str, np.ndarray]] = []

            # Baseline: CPU run without IR optimization.
            base_config = self.create_inference_config(ir_optim=False)
            results.append(
                self.run_test_config(
                    model, params, prog_config, base_config, feed_data
                )
            )
            self.success_log(f"baseline program_config: {prog_config}")
            self.success_log(
                f"baseline predictor_config: {self.inference_config_str(base_config)}"
            )

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                prog_config
            ):
                # Check whether this case should be ignored.
                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        if (
                            ignore_info[1]
                            == IgnoreReasons.MKLDNN_ACCURACY_ERROR
                        ):
                            self.ignore_log(
                                f"[MKLDNN_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
                            )
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    results.append(
                        self.run_test_config(
                            model, params, prog_config, pred_config, feed_data
                        )
                    )
                    self.assert_tensors_near(
                        atol, rtol, results[-1], results[0]
                    )

                    self.success_log(f"program_config: {prog_config}")
                    self.success_log(
                        f"predictor_config: {self.inference_config_str(pred_config)}"
                    )
                except Exception as e:
                    self.fail_log(f"program_config: {prog_config}")
                    self.fail_log(
                        f"predictor_config: {self.inference_config_str(pred_config)}"
                    )
                    self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
                    if not ignore_flag:
                        status = False
                    continue

        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        dic = {}
        enable_mkldnn = config.mkldnn_enabled()
        dic["use_mkldnn"] = enable_mkldnn
        enable_gpu = config.use_gpu()
        dic["use_gpu"] = enable_gpu
        return str(dic)


class PassAutoScanTest(AutoScanTest):
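    """Auto-scan test for IR passes: checks output accuracy against a
    no-IR-optimization baseline and verifies the operator list produced by
    the pass under test."""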
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.passes = []

    def check_op_version(self):
        status = True
        for pass_name in self.passes:
            if pass_name not in self.available_passes_in_framework:
                continue
            if not PassVersionChecker.IsCompatible(pass_name):
                self.fail_log(f"{pass_name} version check failed.")
                status = False
        return status

    def add_ignore_pass_case(self):
        return

    def assert_op_list(self, op_list_after_fusion):
        if not self.passes:
            raise ValueError(
                "In PassAutoScanTest you must provide a valid pass name."
            )
        last_passed_program = os.path.join(
            self.cache_dir, self.passes[-1] + ".pdmodel"
        )
        if not os.path.exists(last_passed_program):
            raise ValueError(
                f"Cannot find file {last_passed_program}, please make sure that your pass name is correct."
            )
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        after_op_list = []
        for i in range(main_block.op_size()):
            if main_block.op(i).type() in ["feed", "fetch"]:
                continue
            after_op_list.append(main_block.op(i).type())
        self.assertTrue(
            op_list_after_fusion == after_op_list,
            f"Expected operator list after fusion is {op_list_after_fusion}, but now it's {after_op_list}",
        )

    def run_and_statis(
        self,
        quant=False,
        max_examples=100,
        reproduce=None,
        min_success_num=25,
        max_duration=180,
        passes=None,
    ):
        if os.getenv("HYPOTHESIS_TEST_PROFILE", "ci") == "dev":
            max_examples *= 10
            min_success_num *= 10
            # At the CE phase there is no time limit.
            max_duration = -1
        start_time = time.time()
        settings.register_profile(
            "ci",
            max_examples=max_examples,
            suppress_health_check=hypothesis.HealthCheck.all(),
            deadline=None,
            print_blob=True,
            derandomize=True,
            report_multiple_bugs=False,
        )
        settings.load_profile("ci")
        assert (
            passes is not None
        ), "Parameter `passes` must be provided to run_and_statis."
        self.passes = passes

        self.add_ignore_pass_case()

        def program_generator(draw):
            return self.sample_program_config(draw)

        def run_test(prog_config):
            return self.run_test(quant=quant, prog_configs=[prog_config])

        generator = st.composite(program_generator)
        loop_func = given(generator())(run_test)
        if reproduce is not None:
            loop_func = reproduce(loop_func)
        self.info_log(f"Start running tests of {type(self)}")
        loop_func()
        self.info_log(
            "===================Statistical Information==================="
        )
        self.info_log(
            f"Number of Generated Programs: {self.num_ran_programs + self.num_invalid_programs}"
        )
        self.info_log(f"Number of Invalid Programs: {self.num_invalid_programs}")
        self.info_log(f"Number of Ran Programs: {self.num_ran_programs}")
        self.info_log(f"Number of Ignored Tests: {self.num_ignore_tests}")
        successful_ran_programs = int(
            self.num_ran_programs
            - self.num_ignore_tests / max(self.num_predictor_kinds, 1)
        )
        self.info_log(
            f"Number of successfully ran programs is approximately {successful_ran_programs}"
        )
        if successful_ran_programs < min_success_num:
            self.fail_log(
                "satisfied_programs = ran_programs - num_ignore_tests / num_predictor_kinds"
            )
            self.fail_log(
                f"At least {min_success_num} programs need to run successfully, but now only about {successful_ran_programs} programs satisfied."
            )
            raise AssertionError()
        used_time = time.time() - start_time
        if max_duration > 0 and used_time > max_duration:
            self.fail_log(
                f"The duration exceeds {max_duration} seconds; if this is expected, set a larger `max_duration`."
            )
            raise AssertionError()

    def run_test(self, quant=False, prog_configs=None):
        status = True
        for prog_config in prog_configs:
            # If the program is invalid, skip the case.
            if not self.is_program_valid(prog_config):
                self.num_invalid_programs += 1
                continue
            self.num_ran_programs += 1
            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    "data": tensor_config.data,
                    "lod": tensor_config.lod,
                }

            self.num_predictor_kinds = 0
            for (
                pred_config,
                op_list,
                (atol, rtol),
            ) in self.sample_predictor_configs(prog_config):
                self.num_predictor_kinds += 1

                # Check whether this case should be ignored.
                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        self.num_ignore_tests += 1
                        if ignore_info[1] == IgnoreReasons.PASS_ACCURACY_ERROR:
                            self.ignore_log(
                                f"[PASS_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
                            )
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                # Baseline: run without IR optimization.
                base_config = self.create_inference_config(
                    ir_optim=False, use_gpu=pred_config.use_gpu()
                )
                try:
                    # Run the baseline first.
                    base_result = self.run_test_config(
                        model, params, prog_config, base_config, feed_data
                    )
                    self.success_log(
                        f"baseline predictor_config: {self.inference_config_str(base_config)}"
                    )

                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    pred_result = self.run_test_config(
                        model, params, prog_config, pred_config, feed_data
                    )
                    self.assert_tensors_near(
                        atol, rtol, pred_result, base_result
                    )
                    if not ignore_flag:
                        self.assert_op_list(op_list)
                    self.success_log(f"program_config: {prog_config}")
                    self.success_log(
                        f"predictor_config: {self.inference_config_str(pred_config)}"
                    )
                except Exception as e:
                    self.fail_log(f"program_config: {prog_config}")
                    self.fail_log(
                        f"predictor_config: {self.inference_config_str(pred_config)}"
                    )
                    self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
                    if not ignore_flag:
                        status = False
                    continue

        status = self.check_op_version() and status
        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        dic = {}
        enable_mkldnn = config.mkldnn_enabled()
        dic["use_mkldnn"] = enable_mkldnn
        enable_gpu = config.use_gpu()
        dic["use_gpu"] = enable_gpu
        enable_xpu = config.use_xpu()
        dic["use_xpu"] = enable_xpu
        if self.passes:
            dic["passes"] = self.passes

        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic["use_trt"] = True
            dic["trt_precision"] = trt_precision
            dic["use_dynamic_shape"] = trt_dynamic_shape
        else:
            dic["use_trt"] = False
        return str(dic)

    def create_trt_inference_config(self) -> paddle_infer.Config:
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        config.switch_ir_debug()
        return config


class TrtLayerAutoScanTest(AutoScanTest):
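    """Auto-scan test for TensorRT subgraphs: compares TRT-enabled predictors
    against a GPU baseline and checks how many operators were folded into
    tensorrt_engine ops."""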
    class TensorRTParam:
        """
        TensorRT subgraph engine parameters.
        """
        def __init__(
            self,
            workspace_size,
            max_batch_size,
            min_subgraph_size,
            precision,
            use_static,
            use_calib_mode,
        ):
            self.workspace_size = workspace_size
            self.max_batch_size = max_batch_size
            self.min_subgraph_size = min_subgraph_size
            self.precision = precision
            self.use_static = use_static
            self.use_calib_mode = use_calib_mode

    class DynamicShapeParam:
        """
        Dynamic shape parameters for the TensorRT subgraph engine.
        """
        def __init__(
            self,
            min_input_shape,
            max_input_shape,
            opt_input_shape,
            disable_trt_plugin_fp16,
        ):
            self.min_input_shape = min_input_shape
            self.max_input_shape = max_input_shape
            self.opt_input_shape = opt_input_shape
            self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.trt_param = self.TensorRTParam(
            workspace_size=1024,
            max_batch_size=4,
            min_subgraph_size=0,
            precision=paddle_infer.PrecisionType.Float32,
            use_static=True,
            use_calib_mode=False,
        )
        self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False)
        self.num_percent_cases = float(
            os.getenv("TEST_NUM_PERCENT_CASES", default="1.0")
        )

        # Use a separate random generator for skipping tests.
        self.skip_rng = np.random.default_rng(int(time.strftime("%W")))

    def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
        config = paddle_infer.Config()
        config.disable_glog_info()
        config.enable_use_gpu(100, 0)
        config.set_optim_cache_dir(self.cache_dir)
        if use_trt:
            config.switch_ir_debug()
            config.enable_tensorrt_engine(
                max_batch_size=self.trt_param.max_batch_size,
                workspace_size=self.trt_param.workspace_size,
                min_subgraph_size=self.trt_param.min_subgraph_size,
                precision_mode=self.trt_param.precision,
                use_static=self.trt_param.use_static,
                use_calib_mode=self.trt_param.use_calib_mode,
            )
            if self.dynamic_shape.min_input_shape and (
                self.dynamic_shape.min_input_shape.keys()
                == self.dynamic_shape.max_input_shape.keys()
                == self.dynamic_shape.opt_input_shape.keys()
            ):
                config.set_trt_dynamic_shape_info(
                    self.dynamic_shape.min_input_shape,
                    self.dynamic_shape.max_input_shape,
                    self.dynamic_shape.opt_input_shape,
                    self.dynamic_shape.disable_trt_plugin_fp16,
                )
        return config

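    # Candidate input dtypes for the generated programs; subclasses may
    # override this (the method name is kept as-is for compatibility).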
    def get_avalible_input_type(self) -> List[np.dtype]:
        return [np.float32]

    def assert_tensors_near(
        self,
        atol: float,
        rtol: float,
        tensor: Dict[str, np.ndarray],
        baseline: Dict[str, np.ndarray],
    ):
        for key, arr in tensor.items():
            self.assertEqual(
                baseline[key].shape,
                arr.shape,
                f"The output shapes are not equal, the baseline shape is {baseline[key].shape}, but got {arr.shape}",
            )
            np.testing.assert_allclose(arr, baseline[key], rtol=rtol, atol=atol)

    def assert_op_size(self, trt_engine_num, paddle_op_num):
        last_passed_program = os.path.join(
            self.cache_dir, "transpose_flatten_concat_fuse_pass.pdmodel"
        )
        model_bytes = paddle.static.load_from_file(last_passed_program)
        pg = paddle.static.deserialize_program(model_bytes)
        main_block = pg.desc.block(0)
        op_size = main_block.op_size()
        op_types = [
            main_block.op(i).type() == "tensorrt_engine" for i in range(op_size)
        ]
        trt_engine_size = sum(op_types)
        paddle_op_size = op_size - trt_engine_size
        self.assertEqual(
            trt_engine_num,
            trt_engine_size,
            f"Expected trt_engine_num is {trt_engine_num}, but got {trt_engine_size}!",
        )
        self.assertEqual(
            paddle_op_num,
            paddle_op_size,
            f"Expected paddle_op_num is {paddle_op_num}, but got {paddle_op_size}!",
        )

    def inference_config_str(self, config: paddle_infer.Config) -> str:
        dic = {}
        enable_trt = config.tensorrt_engine_enabled()
        trt_precision = config.tensorrt_precision_mode()
        trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled()
        if enable_trt:
            dic["use_trt"] = True
            dic["trt_precision"] = trt_precision
            dic["use_dynamic_shape"] = trt_dynamic_shape
        else:
            dic["use_trt"] = False
        return str(dic)

    def run_test(self, quant=False, skip_baseline=False, *args, **kwargs):
        all_passes = True

        def random_to_skip():
            # Run only a TEST_NUM_PERCENT_CASES fraction of the sampled cases.
            if self.skip_rng.random() < self.num_percent_cases:
                return False
            return True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            if random_to_skip():
                continue

            # If the program is invalid, skip the case.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            if quant:
                model, params = create_quant_model(model, params)

            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    "data": tensor_config.data,
                    "lod": tensor_config.lod,
                }

            if not skip_baseline:
                # Baseline: GPU run; only float32 is tested for the baseline.
                gpu_config = self.create_inference_config(use_trt=False)
                baseline_result = self.run_test_config(
                    model,
                    params,
                    prog_config.set_input_type(np.float32),
                    gpu_config,
                    feed_data,
                )
                self.success_log(f"baseline program_config: {prog_config}")

            for (
                pred_config,
                nodes_num,
                threshold,
            ) in self.sample_predictor_configs(prog_config):
                for input_type in self.get_avalible_input_type():
                    prog_config = prog_config.set_input_type(input_type)
                    if os.path.exists(self.cache_dir):
                        shutil.rmtree(self.cache_dir)

                    if isinstance(threshold, float):
                        atol = threshold
                        rtol = 1e-8
                    elif isinstance(threshold, (list, tuple)):
                        atol = threshold[0]
                        rtol = threshold[1]
                    else:
                        raise NotImplementedError

                    # Int8 predictors are only paired with quantized models.
                    is_int8 = (
                        pred_config.tensorrt_precision_mode()
                        == paddle_infer.PrecisionType.Int8
                    )
                    if (not is_int8 and quant) or (is_int8 and not quant):
                        continue

                    ignore_flag = False
                    for teller, reason, note in self.ignore_cases:
                        if teller(prog_config, pred_config):
                            ignore_flag = True
                            if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
                                self.ignore_log(
                                    f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}"
                                )
                            elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
                                self.ignore_log(
                                    f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}"
                                )
                            else:
                                raise NotImplementedError
                            break

                    if ignore_flag:
                        continue

                    try:
                        pred_config_deserialize = paddle_infer.Config(
                            pred_config
                        )
                        trt_result = self.run_test_config(
                            model, params, prog_config, pred_config, feed_data
                        )
                        self.assert_tensors_near(
                            atol, rtol, trt_result, baseline_result
                        )
                        trt_engine_num, paddle_op_num = nodes_num
                        self.assert_op_size(trt_engine_num, paddle_op_num)

                        # deserialize test
                        if trt_engine_num > 0:
                            self.run_test_config(
                                model,
                                params,
                                prog_config,
                                pred_config_deserialize,
                                feed_data,
                            )

                        self.success_log(f"program_config: {prog_config}")
                        self.success_log(
                            f"predictor_config: {self.inference_config_str(pred_config)}"
                        )
                    except Exception as e:
                        self.fail_log(f"program_config: {prog_config}")
                        self.fail_log(
                            f"predictor_config: {self.inference_config_str(pred_config)}"
                        )
                        self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
                        all_passes = False

        self.assertTrue(all_passes)

    # TODO(wilber): kept only for backward compatibility.
    def add_skip_case(
        self,
        teller: Callable[[ProgramConfig, paddle_infer.Config], bool],
        reason: IgnoreReasons,
        note: str,
    ):
        self.ignore_cases.append((teller, reason, note))


class CutlassAutoScanTest(AutoScanTest):
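    """Auto-scan test comparing CUTLASS-backed predictor configs against a
    GPU baseline that runs without IR optimization."""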
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_test(self, quant=False, *args, **kwargs):
        status = True

        for prog_config in self.sample_program_configs(*args, **kwargs):
            # If the program is invalid, skip the case.
            if not self.is_program_valid(prog_config):
                continue

            model, params = create_fake_model(prog_config)
            feed_data = {}
            for name, tensor_config in prog_config.inputs.items():
                feed_data[name] = {
                    "data": tensor_config.data,
                    "lod": tensor_config.lod,
                }
            results: List[Dict[str, np.ndarray]] = []

            # Baseline: GPU run without IR optimization.
            base_config = self.create_inference_config(
                ir_optim=False, use_gpu=True
            )
            logging.info(f"RUN program_config: {prog_config}")
            results.append(
                self.run_test_config(
                    model, params, prog_config, base_config, feed_data
                )
            )
            self.success_log('RUN_GPU_BASELINE done')

            for pred_config, (atol, rtol) in self.sample_predictor_configs(
                prog_config
            ):
                # Check whether this case should be ignored.
                ignore_flag = False
                for ignore_info in self.ignore_cases:
                    if ignore_info[0](prog_config, pred_config):
                        ignore_flag = True
                        if (
                            ignore_info[1]
                            == IgnoreReasons.CUTLASS_ACCURACY_ERROR
                        ):
                            self.ignore_log(
                                f"[CUTLASS_ACCURACY_ERROR] {ignore_info[2]} vs {self.inference_config_str(pred_config)}"
                            )
                        else:
                            raise NotImplementedError
                        break

                if os.path.exists(self.cache_dir):
                    shutil.rmtree(self.cache_dir)
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)

                try:
                    results.append(
                        self.run_test_config(
                            model, params, prog_config, pred_config, feed_data
                        )
                    )
                    self.assert_tensors_near(
                        atol, rtol, results[-1], results[0]
                    )
                except Exception as e:
                    self.fail_log(
                        f"{self.inference_config_str(pred_config)}\033[1;31m \nERROR INFO: {e}\033[0m"
                    )
                    if not ignore_flag:
                        status = False
                    continue
                self.success_log(
                    f"RUN predictor_config {self.inference_config_str(pred_config)} done"
                )

        self.assertTrue(status)

    def inference_config_str(self, config) -> str:
        dic = {}
        enable_gpu = config.use_gpu()
        dic['use_gpu'] = enable_gpu
        return str(dic)
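
# A minimal usage sketch (hypothetical test, not part of this module). A pass
# test subclasses PassAutoScanTest, draws a ProgramConfig with Hypothesis, and
# yields predictor configs with the expected fused op list and tolerances:
#
#     class TestMyFusePass(PassAutoScanTest):
#         def sample_program_config(self, draw):
#             ...  # build and return a ProgramConfig from draw(...)
#
#         def sample_predictor_configs(self, program_config):
#             config = self.create_inference_config(use_gpu=True)
#             yield config, ["my_fused_op"], (1e-5, 1e-5)
#
#         def test(self):
#             self.run_and_statis(quant=False, passes=["my_fuse_pass"])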