# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
# TODO: define the functions to manipulate devices
16
import re
T
taixiurong 已提交
17
import os
18 19
import ctypes
import paddle
20 21
from paddle.fluid import core
from paddle.fluid import framework
22
from paddle.fluid.dygraph.parallel import ParallelEnv
23
from paddle.fluid.framework import is_compiled_with_cinn  # noqa: F401
24 25
from paddle.fluid.framework import is_compiled_with_cuda  # noqa: F401
from paddle.fluid.framework import is_compiled_with_rocm  # noqa: F401
26
from . import cuda
J
james 已提交
27
from . import xpu
28

29
__all__ = [  # noqa
    'get_cudnn_version',
    'set_device',
    'get_device',
    'XPUPlace',
    'IPUPlace',
    'MLUPlace',
    'is_compiled_with_xpu',
    'is_compiled_with_ipu',
    'is_compiled_with_cinn',
    'is_compiled_with_cuda',
    'is_compiled_with_rocm',
    'is_compiled_with_npu',
    'is_compiled_with_mlu',
    'is_compiled_with_custom_device',
    'get_all_device_type',
    'get_all_custom_device_type',
    'get_available_device',
    'get_available_custom_device',
    'Stream',
    'Event',
    'current_stream',
    'set_stream',
    'stream_guard',
    'synchronize',
]

# Lazily filled by get_cudnn_version(); None means the core has not been
# queried yet.
_cudnn_version = None


59 60
# TODO: WITH_ASCEND_CL may be changed to WITH_NPU or another name in the
# future for consistency.
61 62
def is_compiled_with_npu():
    """
    Check whether paddle was built with WITH_ASCEND_CL=ON, i.e. with
    support for Ascend NPU.

    Return:
        bool, ``True`` if NPU is supported, otherwise ``False``.

    Examples:
        .. code-block:: python

            import paddle
            support_npu = paddle.device.is_compiled_with_npu()
    """
    npu_enabled = core.is_compiled_with_npu()
    return npu_enabled


77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
def is_compiled_with_custom_device(device_type):
    """
    Check whether paddle was built with support for the given custom device.

    Args:
        device_type (str): The registered custom device type, like "npu".

    Return:
        bool, ``True`` if the custom device is supported, otherwise ``False``.

    Examples:
        .. code-block:: python

            import paddle
            support_npu = paddle.device.is_compiled_with_custom_device("npu")
    """
    custom_supported = core.is_compiled_with_custom_device(device_type)
    return custom_supported


J
jianghaicheng 已提交
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
def is_compiled_with_ipu():
    """
    Check whether paddle was built with WITH_IPU=ON to support Graphcore IPU.

    Returns (bool): ``True`` if IPU is supported, otherwise ``False``.

    Examples:
        .. code-block:: python

            import paddle
            support_ipu = paddle.is_compiled_with_ipu()
    """
    ipu_enabled = core.is_compiled_with_ipu()
    return ipu_enabled


def IPUPlace():
    """
    Create and return a Graphcore IPU place.

    Examples:
        .. code-block:: python

            # required: ipu

            import paddle
            place = paddle.device.IPUPlace()
    """
    ipu_place = core.IPUPlace()
    return ipu_place


125 126 127 128 129 130 131 132 133 134
def is_compiled_with_xpu():
    """
    Check whether paddle was built with WITH_XPU=ON to support Baidu Kunlun.

    Returns (bool): whether paddle was built with WITH_XPU=ON

    Examples:
        .. code-block:: python

            import paddle
            support_xpu = paddle.device.is_compiled_with_xpu()
    """
    xpu_enabled = core.is_compiled_with_xpu()
    return xpu_enabled


def XPUPlace(dev_id):
    """
    Create and return a Baidu Kunlun (XPU) place.

    Parameters:
        dev_id(int): Baidu Kunlun device id

    Examples:
        .. code-block:: python

            # required: xpu

            import paddle
            place = paddle.device.XPUPlace(0)
    """
    xpu_place = core.XPUPlace(dev_id)
    return xpu_place


158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
def is_compiled_with_mlu():
    """
    Check whether paddle was built with WITH_MLU=ON to support Cambricon MLU.

    Returns (bool): whether paddle was built with WITH_MLU=ON

    Examples:
        .. code-block:: python

            # required: mlu

            import paddle
            support_mlu = paddle.device.is_compiled_with_mlu()
    """
    mlu_enabled = core.is_compiled_with_mlu()
    return mlu_enabled


def MLUPlace(dev_id):
    """
    Create and return a Cambricon MLU place.

    Parameters:
        dev_id(int): MLU device id

    Examples:
        .. code-block:: python

            # required: mlu

            import paddle
            place = paddle.device.MLUPlace(0)
    """
    mlu_place = core.MLUPlace(dev_id)
    return mlu_place


193 194
def get_cudnn_version():
    """
    Return the version of cuDNN as an int.

    For example, a return value of 7600 means cuDNN 7.6. The value is queried
    from the core once and cached in the module-level ``_cudnn_version``.

    Returns:
        int|None: The cuDNN version. Returns None if paddle was not compiled
        with CUDA, or if the core reports a negative version (cuDNN is not
        installed).

    Examples:
        .. code-block:: python

            import paddle

            cudnn_version = paddle.device.get_cudnn_version()
    """
    global _cudnn_version
    if not core.is_compiled_with_cuda():
        return None
    if _cudnn_version is None:
        _cudnn_version = int(core.cudnn_version())
    # A negative value means cuDNN is unavailable. Check it on every call, not
    # only the first, so the cached sentinel is never leaked to callers.
    if _cudnn_version < 0:
        return None
    return _cudnn_version

224

C
chentianyu03 已提交
225
def _device_id_from_str(device):
    """Return the integer index that follows the first ':' in 'type:index'."""
    return int(device.split(':', 1)[1])


def _first_selected_device_id(env_name):
    """Return the first id in the comma-separated env var ``env_name`` (default "0")."""
    return int(os.getenv(env_name, "0").split(",")[0])


def _convert_to_place(device):
    """
    Convert a device string into the corresponding core Place object.

    Accepted forms:
      * a registered custom device type (matched as given). Uses the first id in
        ``FLAGS_selected_<type>s``, defaulting to 0.
      * 'cpu' or 'ipu' (case-insensitive).
      * 'gpu' (case-insensitive). Uses ``ParallelEnv().dev_id``.
      * 'xpu', 'npu' or 'mlu' (case-insensitive). Uses the first id in
        ``FLAGS_selected_<type>s``, defaulting to 0.
      * 'gpu:x', 'xpu:x', 'npu:x', 'mlu:x', or '<custom_type>:x', where x is
        the device index. When paddle is not compiled with NPU, 'npu:x' is
        accepted only if 'npu' is a registered custom device type.

    Args:
        device (str): The device specification.

    Returns:
        The core Place for ``device``.

    Raises:
        ValueError: If paddle was not compiled with the required backend, or
            ``device`` does not name a known device.
    """
    lower_device = device.lower()

    if device in core.get_all_custom_device_type():
        device_id = _first_selected_device_id(
            "FLAGS_selected_{}s".format(device)
        )
        return core.CustomPlace(device, device_id)
    if lower_device == 'cpu':
        return core.CPUPlace()
    if lower_device == 'gpu':
        if not core.is_compiled_with_cuda():
            raise ValueError(
                "The device should not be 'gpu', "
                "since PaddlePaddle is not compiled with CUDA"
            )
        return core.CUDAPlace(ParallelEnv().dev_id)
    if lower_device == 'xpu':
        if not core.is_compiled_with_xpu():
            raise ValueError(
                "The device should not be 'xpu', "
                "since PaddlePaddle is not compiled with XPU"
            )
        return core.XPUPlace(_first_selected_device_id("FLAGS_selected_xpus"))
    if lower_device == 'npu':
        if not core.is_compiled_with_npu():
            raise ValueError(
                "The device should not be 'npu', "
                "since PaddlePaddle is not compiled with NPU"
            )
        return core.NPUPlace(_first_selected_device_id("FLAGS_selected_npus"))
    if lower_device == 'ipu':
        if not core.is_compiled_with_ipu():
            raise ValueError(
                "The device should not be 'ipu', "
                "since PaddlePaddle is not compiled with IPU"
            )
        return core.IPUPlace()
    if lower_device == 'mlu':
        if not core.is_compiled_with_mlu():
            raise ValueError(
                "The device should not be 'mlu', "
                "since PaddlePaddle is not compiled with MLU"
            )
        return core.MLUPlace(_first_selected_device_id("FLAGS_selected_mlus"))

    # 'type:index' forms. The prefixes are distinct, so at most one matches.
    # Error messages show the user's device string rather than the re.Match
    # object.
    if re.match(r'gpu:\d+', lower_device):
        if not core.is_compiled_with_cuda():
            raise ValueError(
                "The device should not be {}, since PaddlePaddle is "
                "not compiled with CUDA".format(device)
            )
        return core.CUDAPlace(_device_id_from_str(device))
    if re.match(r'xpu:\d+', lower_device):
        if not core.is_compiled_with_xpu():
            raise ValueError(
                "The device should not be {}, since PaddlePaddle is "
                "not compiled with XPU".format(device)
            )
        return core.XPUPlace(_device_id_from_str(device))
    if re.match(r'npu:\d+', lower_device):
        if not core.is_compiled_with_npu():
            # Without native NPU support, 'npu:x' may still name a registered
            # custom device type.
            device_type = device.split(':', 1)[0]
            if device_type in core.get_all_custom_device_type():
                return core.CustomPlace(
                    device_type, _device_id_from_str(device)
                )
            raise ValueError(
                "The device should not be {}, since PaddlePaddle is "
                "not compiled with NPU or compiled with custom device".format(
                    device
                )
            )
        return core.NPUPlace(_device_id_from_str(device))
    if re.match(r'mlu:\d+', lower_device):
        if not core.is_compiled_with_mlu():
            raise ValueError(
                "The device should not be {}, since PaddlePaddle is "
                "not compiled with mlu".format(device)
            )
        return core.MLUPlace(_device_id_from_str(device))

    # Remaining option: '<custom_type>:index'.
    device_type = device.split(':', 1)[0]
    if device_type in core.get_all_custom_device_type():
        return core.CustomPlace(device_type, _device_id_from_str(device))
    raise ValueError(
        "The device must be a string which is like 'cpu', {}".format(
            ', '.join(
                "'{}', '{}:x'".format(x, x)
                for x in ['gpu', 'xpu', 'npu', 'mlu']
                + core.get_all_custom_device_type()
            )
        )
    )
354

C
chentianyu03 已提交
355 356 357

def set_device(device):
    """
    Set the global device on which subsequent OPs will run.

    Paddle supports running calculations on various types of devices,
    including CPU, GPU, XPU, NPU, MLU and IPU. They are represented by string
    identifiers.

    Parameters:
        device(str): This parameter determines the specific running device.
            It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``mlu``, ``gpu:x``, ``xpu:x``, ``npu:x``, ``mlu:x`` and ``ipu``,
            where ``x`` is the index of the GPUs, XPUs, NPUs or MLUs.

    Returns:
        The Place object corresponding to ``device``.

    Examples:
        .. code-block:: python

            import paddle

            paddle.device.set_device("cpu")
            x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
            x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
            data = paddle.stack([x1,x2], axis=1)
    """
    target_place = _convert_to_place(device)
    framework._set_expected_place(target_place)
    return target_place
381 382 383 384 385


def get_device():
    """
    Return the current global device as a string.

    The result is built from the current expected place. It is 'cpu',
    'gpu:x', 'xpu:x', 'npu:x' or 'mlu:x' (x is the device index). For IPU it
    is 'ipus:{0-N}', where N is the IPU device count minus one. For a custom
    device it is '<device_type>:x'. If the global device has not been set, it
    returns 'gpu:x' when CUDA is available and 'cpu' otherwise.

    Raises:
        ValueError: If the current place is of an unrecognized type.

    Examples:
        .. code-block:: python

            import paddle
            device = paddle.device.get_device()
    """
    place = framework._current_expected_place()

    if isinstance(place, core.CPUPlace):
        return 'cpu'
    if isinstance(place, core.CUDAPlace):
        return 'gpu:' + str(place.get_device_id())
    if isinstance(place, core.XPUPlace):
        return 'xpu:' + str(place.get_device_id())
    if isinstance(place, core.NPUPlace):
        return 'npu:' + str(place.get_device_id())
    if isinstance(place, core.IPUPlace):
        ipu_count = core.get_ipu_device_count()
        return "ipus:{{0-{}}}".format(ipu_count - 1)
    if isinstance(place, core.MLUPlace):
        return 'mlu:' + str(place.get_device_id())
    if isinstance(place, core.CustomPlace):
        custom_id = place.get_device_id()
        custom_type = place.get_device_type()
        return custom_type + ':' + str(custom_id)

    raise ValueError("The device specification {} is invalid".format(place))
425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458


def get_all_device_type():
    """
    Get all available device types.

    Returns:
        A list of all available device types.

    Examples:
        .. code-block:: python

            import paddle
            paddle.device.get_all_device_type()

            # Case 1: paddlepaddle-cpu package installed, and no custom device registered.
            # Output: ['cpu']

            # Case 2: paddlepaddle-gpu package installed, and no custom device registered.
            # Output: ['cpu', 'gpu']

            # Case 3: paddlepaddle-cpu package installed, and custom device 'CustomCPU' is registered.
            # Output: ['cpu', 'CustomCPU']

            # Case 4: paddlepaddle-gpu package installed, and custom devices 'CustomCPU' and 'CustomGPU' are registered.
            # Output: ['cpu', 'gpu', 'CustomCPU', 'CustomGPU']
    """
    device_types = core.get_all_device_type()
    return device_types


def get_all_custom_device_type():
    """
    Get all available custom device types.

    Returns:
        A list of all available custom device types.

    Examples:
        .. code-block:: python

            import paddle
            paddle.device.get_all_custom_device_type()

            # Case 1: paddlepaddle-gpu package installed, and no custom device registered.
            # Output: None

            # Case 2: paddlepaddle-gpu package installed, and custom devices 'CustomCPU' and 'CustomGPU' are registered.
            # Output: ['CustomCPU', 'CustomGPU']
    """
    custom_types = core.get_all_custom_device_type()
    return custom_types


def get_available_device():
    """
    Get all available devices.

    Returns:
        A list of all available devices.

    Examples:
        .. code-block:: python

            import paddle
            paddle.device.get_available_device()

            # Case 1: paddlepaddle-cpu package installed, and no custom device registered.
            # Output: ['cpu']

            # Case 2: paddlepaddle-gpu package installed, and no custom device registered.
            # Output: ['cpu', 'gpu:0', 'gpu:1']

            # Case 3: paddlepaddle-cpu package installed, and custom device 'CustomCPU' is registered.
            # Output: ['cpu', 'CustomCPU']

            # Case 4: paddlepaddle-gpu package installed, and custom devices 'CustomCPU' and 'CustomGPU' are registered.
            # Output: ['cpu', 'gpu:0', 'gpu:1', 'CustomCPU', 'CustomGPU:0', 'CustomGPU:1']
    """
    devices = core.get_available_device()
    return devices


def get_available_custom_device():
    """
    Get all available custom devices.

    Returns:
       A list of all available custom devices.

    Examples:
        .. code-block:: python

            import paddle
            paddle.device.get_available_custom_device()

            # Case 1: paddlepaddle-gpu package installed, and no custom device registered.
            # Output: None

            # Case 2: paddlepaddle-gpu package installed, and custom devices 'CustomCPU' and 'CustomGPU' are registered.
            # Output: ['CustomCPU', 'CustomGPU:0', 'CustomGPU:1']
    """
    custom_devices = core.get_available_custom_device()
    return custom_devices
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996


class Event(object):
    '''
    A device event wrapper around the backend event object
    (``core.CUDAEvent`` or ``core.CustomDeviceEvent``).

    Parameters:
        device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n), optional): The device the event belongs to.
            If device is None, the current expected place is used. Default: None.
            A string may be ``gpu``, ``gpu:x``, ``custom_device`` or ``custom_device:x``, where ``custom_device``
            is the name of a registered custom device and ``x`` is the device index.
        enable_timing (bool, optional): indicates if the event should measure time, default is False
        blocking (bool, optional): if True, ``wait`` will be blocking, default is False
        interprocess (bool, optional): if True, the event can be shared between processes, default is False

    Returns:
        Event: The event.

    Raises:
        TypeError: If the resolved place is neither a CUDA place (in a CUDA
            build) nor a custom device place.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            e1 = paddle.device.Event()
            e2 = paddle.device.Event('custom_cpu')
            e3 = paddle.device.Event('custom_cpu:0')
            e4 = paddle.device.Event(paddle.CustomPlace('custom_cpu', 0))
    '''

    def __init__(
        self,
        device=None,
        enable_timing=False,
        blocking=False,
        interprocess=False,
    ):
        if device is None:
            self.device = paddle.framework._current_expected_place()
        elif isinstance(device, str):
            self.device = paddle.device._convert_to_place(device)
        else:
            self.device = device

        if paddle.is_compiled_with_cuda() and isinstance(
            self.device, paddle.CUDAPlace
        ):
            self.event_base = core.CUDAEvent(
                enable_timing, blocking, interprocess
            )
        elif isinstance(self.device, paddle.CustomPlace):
            self.event_base = core.CustomDeviceEvent(
                self.device.get_device_type(),
                self.device.get_device_id(),
                enable_timing,
                blocking,
                interprocess,
            )
        else:
            raise TypeError(
                "device should be gpu, xpu, {}".format(
                    ",".join(paddle.device.get_all_custom_device_type())
                )
            )

    def record(self, stream=None):
        '''
        Records the event in a given stream.

        Parameters:
            stream(Stream, optional): The given stream. By default, stream is None,
                and the event is recorded in the current stream of the event's device.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                e = paddle.device.Event()
                e.record()

                s = paddle.device.Stream()
                e.record(s)
        '''
        if stream is None:
            stream = current_stream(self.device)

        self.event_base.record(stream.stream_base)

    def query(self):
        '''
        Checks if all work currently captured by event has completed.

        Returns:
            bool: Whether all work currently captured by event has completed.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                e = paddle.device.Event()
                e.query()
        '''
        return self.event_base.query()

    def elapsed_time(self, end_event):
        '''
        Intended to return the time elapsed in milliseconds between this
        event and ``end_event``.

        NOTE: not implemented yet. This method always returns 0, regardless
        of ``end_event``.

        Returns:
            int: The time (currently always 0).

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                e1 = paddle.device.Event()
                e2 = paddle.device.Event()
                e1.elapsed_time(e2)
        '''
        return 0

    def synchronize(self):
        '''
        Waits for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                e = paddle.device.Event()
                e.synchronize()
        '''
        self.event_base.synchronize()

    def __repr__(self):
        # __repr__ must return a str. Returning the raw core event object (as
        # before) made repr(event) raise TypeError.
        return '<paddle.device.Event device={0} event={1!r}>'.format(
            self.device, self.event_base
        )


class Stream(object):
    '''
    A device stream wrapper around the backend stream object
    (``core.CUDAStream`` or ``core.CustomDeviceStream``).

    Parameters:
        device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n), optional): Which device the stream runs on.
            If device is None, the current expected place is used. Default: None.
            A string may be ``gpu``, ``gpu:x``, ``custom_device`` or ``custom_device:x``, where ``custom_device``
            is the name of a registered custom device and ``x`` is the device index.
        priority(int, optional): priority of the stream. Can be either
            1 (high priority) or 2 (low priority). By default, streams have
            priority 2.
        stream_base(core.CUDAStream|core.CustomDeviceStream, optional): An existing
            backend stream to wrap. When given, ``device`` and ``priority`` are ignored.

    Returns:
        Stream: The stream.

    Raises:
        TypeError: If ``stream_base`` has an unsupported type, or the resolved
            place is neither a CUDA place (in a CUDA build) nor a custom device place.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            s1 = paddle.device.Stream()
            s2 = paddle.device.Stream('custom_cpu')
            s3 = paddle.device.Stream('custom_cpu:0')
            s4 = paddle.device.Stream(paddle.CustomPlace('custom_cpu', 0))
    '''

    def __init__(self, device=None, priority=2, stream_base=None):
        if stream_base is not None:
            if isinstance(
                stream_base, (core.CUDAStream, core.CustomDeviceStream)
            ):
                self.stream_base = stream_base
                self.device = stream_base.place
            else:
                raise TypeError(
                    "stream_base should be CUDAStream, CustomDeviceStream"
                )
            return

        if device is None:
            self.device = paddle.framework._current_expected_place()
        elif isinstance(device, str):
            self.device = paddle.device._convert_to_place(device)
        else:
            self.device = device

        if paddle.is_compiled_with_cuda() and isinstance(
            self.device, paddle.CUDAPlace
        ):
            self.stream_base = core.CUDAStream(
                self.device.get_device_id(), priority
            )
        elif isinstance(self.device, paddle.CustomPlace):
            self.stream_base = core.CustomDeviceStream(
                self.device.get_device_type(),
                self.device.get_device_id(),
                priority,
                blocking=False,
            )
        else:
            raise TypeError(
                "device should be gpu, xpu, {}".format(
                    ",".join(paddle.device.get_all_custom_device_type())
                )
            )

    def wait_event(self, event):
        '''
        Makes all future work submitted to the stream wait for an event.

        Parameters:
            event (Event): an event to wait for.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                s = paddle.device.Stream()
                e = paddle.device.Event()
                s.wait_event(e)
        '''
        self.stream_base.wait_event(event.event_base)

    def wait_stream(self, stream):
        '''
        Synchronizes with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Parameters:
            stream (Stream): a stream to synchronize.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                s1 = paddle.device.Stream()
                s2 = paddle.device.Stream()
                s1.wait_stream(s2)
        '''
        self.stream_base.wait_stream(stream.stream_base)

    def record_event(self, event=None):
        '''
        Records an event in this stream.

        Parameters:
            event (Event, optional): event to record. If not given, a new one
                is created on this stream's device.

        Returns:
            Event: Recorded event.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                s = paddle.device.Stream()
                e1 = s.record_event()

                e2 = paddle.device.Event()
                s.record_event(e2)
        '''
        if event is None:
            event = Event(self.device)
        event.record(self)
        return event

    def query(self):
        '''
        Checks if all the work submitted has been completed.

        Returns:
            bool: Whether all kernels in this stream are completed.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                s = paddle.device.Stream()
                s.query()
        '''
        return self.stream_base.query()

    def synchronize(self):
        '''
        Wait for all the kernels in this stream to complete.

        Returns:
            None.

        Examples:
            .. code-block:: python

                # required: custom_device
                import paddle
                s = paddle.device.Stream()
                s.synchronize()
        '''
        self.stream_base.synchronize()

    @property
    def _as_parameter_(self):
        # Raw stream handle as ctypes.c_void_p (ctypes `_as_parameter_`
        # protocol), taken from `cuda_stream` for CUDA streams and
        # `raw_stream` otherwise.
        if isinstance(self.stream_base, core.CUDAStream):
            return ctypes.c_void_p(self.stream_base.cuda_stream)
        else:
            return ctypes.c_void_p(self.stream_base.raw_stream)

    def __eq__(self, o):
        # Identity comparison: object.__eq__ yields NotImplemented for
        # distinct objects, so Python falls back to `is`.
        if isinstance(o, Stream):
            return super(Stream, self).__eq__(o)
        return False

    def __hash__(self):
        return hash((self.stream_base, self.device))

    def __repr__(self):
        # c_void_p(0).value is None for a NULL handle, and None cannot be
        # formatted with '#x'. Show such a handle as 0x0 instead of raising.
        handle = self._as_parameter_.value or 0
        return '<paddle.device.Stream device={0} stream={1:#x}>'.format(
            self.device, handle
        )


def current_stream(device=None):
    '''
    Return the current stream of the given device.

    Parameters:
        device(str|paddle.CUDAPlace(n)|paddle.CustomPlace(n), optional): The device whose current stream is returned.
            It can be ``gpu``, ``gpu:x``, ``custom_device`` or ``custom_device:x``, where ``custom_device`` is the name
            of a CustomDevice and ``x`` is the index of the GPU or CustomDevice. A paddle.CUDAPlace(n) or
            paddle.CustomPlace(n) is also accepted. If device is None, the current expected place is used. Default: None.

    Returns:
        Stream: The current stream of the device.

    Raises:
        TypeError: If the device is neither a GPU place (in a CUDA build) nor a custom-device place.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            s1 = paddle.device.current_stream()
            s2 = paddle.device.current_stream("gpu:0")
            place = paddle.CustomPlace('custom_cpu', 0)
            s3 = paddle.device.current_stream(place)
    '''
    if device is None:
        place = paddle.framework._current_expected_place()
    elif isinstance(device, str):
        place = paddle.device._convert_to_place(device)
    else:
        place = device

    if paddle.is_compiled_with_cuda() and isinstance(place, paddle.CUDAPlace):
        return Stream(
            stream_base=core._get_current_stream(place.get_device_id())
        )
    elif isinstance(place, paddle.CustomPlace):
        return Stream(
            stream_base=core._get_current_custom_device_stream(
                place.get_device_type(), place.get_device_id()
            )
        )
    else:
        # Only the place kinds handled above are accepted; XPU places (and
        # any other kind) end up here, so they are not listed as valid.
        supported = ['gpu']
        supported += paddle.device.get_all_custom_device_type()
        raise TypeError(
            "device should be {}, but got {}".format(
                ", ".join(supported), place
            )
        )


def set_stream(stream):
    '''
    Make ``stream`` the current stream of the device it belongs to.

    Parameters:
        stream(Stream): The stream to make current. Its device is read from
            ``stream.stream_base.place``.

    Returns:
        Stream: The stream that was current on that device before this call.

    Raises:
        TypeError: If the stream's place is neither a GPU place (in a CUDA
            build) nor a custom-device place.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            s = paddle.device.Stream()
            paddle.device.set_stream(s)
    '''
    place = stream.stream_base.place
    # Capture the previous stream first so it can be returned to the caller.
    prev_stream = current_stream(place)

    if paddle.is_compiled_with_cuda() and isinstance(place, paddle.CUDAPlace):
        core._set_current_stream(stream.stream_base)
    elif isinstance(place, paddle.CustomPlace):
        core._set_current_custom_device_stream(
            place.get_device_type(),
            place.get_device_id(),
            stream.stream_base,
        )
    else:
        # current_stream() above already rejects the same place kinds, so
        # this branch is defensive. List only the kinds handled here.
        supported = ['gpu']
        supported += paddle.device.get_all_custom_device_type()
        raise TypeError(
            "device should be {}, but got {}".format(
                ", ".join(supported), place
            )
        )

    return prev_stream


class stream_guard(object):
    '''
    Notes:
        This API only supports dynamic graph mode currently.

    A context manager that makes the given stream current while the context is
    active and restores the previously current stream(s) when it exits.

    Parameters:
        stream(Stream, optional): The stream to make current inside the context.
            If stream is None, entering and exiting the context does nothing.
            Default: None.

    Returns:
        None.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            s = paddle.device.Stream()
            data1 = paddle.ones(shape=[20])
            data2 = paddle.ones(shape=[20])
            data3 = data1 + data2
            with paddle.device.stream_guard(s):
                s.wait_stream(paddle.device.default_stream())
                data4 = data1 + data3
    '''

    def __init__(self, stream=None):
        self.stream = stream

    def __enter__(self):
        target = self.stream
        if target is None:
            return

        # Stream that was current on the target's device; restored on exit.
        self.src_prev_stream = current_stream(target.device)
        # NOTE(review): src_prev_stream is queried on target.device itself, so
        # this branch only runs when the two device objects compare unequal
        # (e.g. if places compare by identity rather than value) — TODO
        # confirm the intended semantics.
        if self.src_prev_stream.device != target.device:
            # Switch the expected place to the target's device, remembering
            # the old place and that device's previous stream for __exit__.
            self.tmp_place = paddle.fluid.framework._current_expected_place()
            paddle.fluid.framework._set_expected_place(target.device)
            self.dst_prev_stream = current_stream(target.device)
        set_stream(target)

    def __exit__(self, *args):
        target = self.stream
        if target is None:
            return

        if self.src_prev_stream.device != target.device:
            # Undo the place switch performed in __enter__.
            set_stream(self.dst_prev_stream)
            paddle.fluid.framework._set_expected_place(self.tmp_place)
        set_stream(self.src_prev_stream)


def synchronize(device=None):
    '''
    Block until all computation on the given device has finished.

    Parameters:
        device(str|paddle.CUDAPlace(n)|paddle.XPUPlace(n)|paddle.CustomPlace(n), optional): The device to wait for.
            It can be ``gpu``, ``gpu:x``, ``xpu``, ``xpu:x``, ``custom_device`` or ``custom_device:x``, where
            ``custom_device`` is the name of a CustomDevice and ``x`` is the index of the GPU, XPU or CustomDevice.
            A paddle.CUDAPlace(n), paddle.XPUPlace(n) or paddle.CustomPlace(n) is also accepted.
            If device is None, the current expected place is used. Default: None.

    Raises:
        TypeError: If the device is not a GPU place (CUDA build), an XPU place
            (XPU build) or a custom-device place.

    Examples:
        .. code-block:: python

            # required: custom_device
            import paddle
            paddle.device.synchronize()
            paddle.device.synchronize("gpu:0")
            place = paddle.CustomPlace('custom_cpu', 0)
            paddle.device.synchronize(place)
    '''
    place = device
    if place is None:
        place = paddle.framework._current_expected_place()
    elif isinstance(place, str):
        place = paddle.device._convert_to_place(place)

    if paddle.is_compiled_with_cuda() and isinstance(place, paddle.CUDAPlace):
        core._device_synchronize(place.get_device_id())
        return
    if paddle.is_compiled_with_xpu() and isinstance(place, paddle.XPUPlace):
        core._xpu_device_synchronize(place.get_device_id())
        return
    if isinstance(place, paddle.CustomPlace):
        core._synchronize_custom_device(
            place.get_device_type(), place.get_device_id()
        )
        return

    custom_types = ",".join(paddle.device.get_all_custom_device_type())
    raise TypeError("device should be gpu, xpu, {}".format(custom_types))