# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import site
import sys
import os
import warnings
import platform
import logging

has_paddle_dy_lib = False

dy_lib_name = 'libpaddle'
dy_lib_suffix = 'so'
if os.name == 'nt':
    dy_lib_suffix = 'pyd'

current_path = os.path.abspath(os.path.dirname(__file__))
if os.path.exists(current_path + os.sep + dy_lib_name + '.' + dy_lib_suffix):
    has_paddle_dy_lib = True

try:
    if os.name == 'nt':
        third_lib_path = current_path + os.sep + '..' + os.sep + 'libs'
        # Shared libraries will be loaded from the directories on 'path' on Windows
        os.environ['path'] = (
            current_path + ';' + third_lib_path + ';' + os.environ['path']
        )
        sys.path.insert(0, third_lib_path)
        # Note: since Python 3.8, PATH no longer takes effect for DLL resolution
        # https://github.com/python/cpython/pull/12302
        # Use add_dll_directory to specify the DLL resolution path
        if sys.version_info[:2] >= (3, 8):
            os.add_dll_directory(third_lib_path)

except ImportError as e:
    if os.name == 'nt':
        executable_path = os.path.abspath(os.path.dirname(sys.executable))
        raise ImportError(
            """NOTE: You may need to run \"set PATH=%s;%%PATH%%\"
        if you encounter \"DLL load failed\" errors. If you have python
        installed in another directory, replace \"%s\" with your own
        directory. The original error is: \n %s"""
            % (executable_path, executable_path, str(e))
        )
    else:
        raise ImportError(
            """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
        if you encounter \"libmkldnn.so not found\" errors. If you have python
        installed in another directory, replace \"/usr/local/lib\" with your own
        directory. The original error is: \n"""
            + str(e)
        )
except Exception as e:
    raise e
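
# A minimal sketch (hypothetical directory) of the Windows DLL-search pattern used
# above, for reference only; os.add_dll_directory exists on Windows since Python 3.8:
#
#     import os
#     if os.name == 'nt' and hasattr(os, 'add_dll_directory'):
#         os.add_dll_directory(r'C:\path\to\extra\dlls')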


def avx_supported():
    """
    Whether the current system (Linux, MacOS, or Windows) supports AVX.
    """
    sysstr = platform.system().lower()
    has_avx = False
    if sysstr == 'linux':
        try:
            pipe = os.popen('cat /proc/cpuinfo | grep -i avx')
            has_avx = pipe.read() != ''
            pipe.close()
        except Exception as e:
            sys.stderr.write(
                'Can not get the AVX flag from /proc/cpuinfo.\n'
                'The original error is: %s\n' % str(e)
            )
        return has_avx
    elif sysstr == 'darwin':
        try:
            pipe = os.popen('sysctl machdep.cpu.features | grep -i avx')
            has_avx = pipe.read() != ''
            pipe.close()
        except Exception as e:
            sys.stderr.write(
                'Can not get the AVX flag from machdep.cpu.features.\n'
                'The original error is: %s\n' % str(e)
            )
        if not has_avx:
            import subprocess

            pipe = subprocess.Popen(
                'sysctl machdep.cpu.leaf7_features | grep -i avx',
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            _ = pipe.communicate()
            has_avx = pipe.returncode == 0
        return has_avx
    elif sysstr == 'windows':
        import ctypes

        ONE_PAGE = ctypes.c_size_t(0x1000)

        def asm_func(code_str, restype=ctypes.c_uint32, argtypes=()):
            # Execute code_str as a native function
            # Allocate one page so the code gets its own memory protection region
            pfnVirtualAlloc = ctypes.windll.kernel32.VirtualAlloc
            pfnVirtualAlloc.restype = ctypes.c_void_p
            MEM_COMMIT = ctypes.c_ulong(0x1000)
            PAGE_READWRITE = ctypes.c_ulong(0x4)
            address = pfnVirtualAlloc(
                None, ONE_PAGE, MEM_COMMIT, PAGE_READWRITE
            )
            if not address:
                raise Exception("Failed to VirtualAlloc")

            # Copy the code into the memory segment
            memmove = ctypes.CFUNCTYPE(
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_size_t,
            )(ctypes._memmove_addr)
            if memmove(address, code_str, len(code_str)) < 0:
                raise Exception("Failed to memmove")

            # Enable execute permissions
            PAGE_EXECUTE = ctypes.c_ulong(0x10)
            pfnVirtualProtect = ctypes.windll.kernel32.VirtualProtect
            res = pfnVirtualProtect(
                ctypes.c_void_p(address),
                ONE_PAGE,
                PAGE_EXECUTE,
                ctypes.byref(ctypes.c_ulong(0)),
            )
            if not res:
                raise Exception("Failed VirtualProtect")

            # Flush instruction cache
            pfnGetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
            pfnGetCurrentProcess.restype = ctypes.c_void_p
            prochandle = ctypes.c_void_p(pfnGetCurrentProcess())
            res = ctypes.windll.kernel32.FlushInstructionCache(
                prochandle, ctypes.c_void_p(address), ONE_PAGE
            )
            if not res:
                raise Exception("Failed FlushInstructionCache")

            # Cast the memory to function
            functype = ctypes.CFUNCTYPE(restype, *argtypes)
            func = functype(address)
            return func, address

        # http://en.wikipedia.org/wiki/CPUID#EAX.3D1:_Processor_Info_and_Feature_Bits
        # mov eax,0x1; cpuid; mov eax,ecx; ret
        code_str = b"\xB8\x01\x00\x00\x00\x0f\xa2\x89\xC8\xC3"
        avx_bit = 28
        retval = 0
        try:
            # Convert the code_str into a function that returns uint
            func, address = asm_func(code_str)
            retval = func()
            # MEM_RELEASE must be used with dwSize == 0 to free the whole region
            MEM_RELEASE = ctypes.c_ulong(0x8000)
            ctypes.windll.kernel32.VirtualFree(
                ctypes.c_void_p(address), ctypes.c_size_t(0), MEM_RELEASE
            )
        except Exception as e:
            sys.stderr.write(
                'Failed getting the AVX flag on Windows.\n'
                'The original error is: %s\n' % str(e)
            )
        return (retval & (1 << avx_bit)) > 0
    else:
        sys.stderr.write('Can not get the AVX flag on %s\n' % sysstr)
        return False
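
# Illustrative note: CPUID leaf 1 reports AVX in bit 28 of ECX, which is why the
# Windows stub above returns ECX and the caller masks it with (1 << avx_bit).
# A rough cross-check on Linux (hypothetical, for illustration only):
#
#     with open('/proc/cpuinfo') as f:
#         assert ('avx' in f.read().lower()) == avx_supported()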


def run_shell_command(cmd):
    import subprocess

    out, err = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
    ).communicate()
    if err:
        return None
    else:
        return out.decode('utf-8').strip()
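
# Example (illustrative): run_shell_command('echo hello') returns 'hello', while a
# command that writes anything to stderr yields None.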


def get_dso_path(core_so, dso_name):
    if core_so and dso_name:
        return run_shell_command(
            "ldd %s|grep %s|awk '{print $3}'" % (core_so, dso_name)
        )
    else:
        return None


def load_dso(dso_absolute_path):
    if dso_absolute_path:
        try:
            from ctypes import cdll

            cdll.LoadLibrary(dso_absolute_path)
        except Exception:
            warnings.warn("Load {} failed".format(dso_absolute_path))


def pre_load(dso_name):
    if has_paddle_dy_lib:
        core_so = current_path + os.sep + dy_lib_name + '.' + dy_lib_suffix
    else:
        core_so = None
    dso_path = get_dso_path(core_so, dso_name)
    load_dso(dso_path)
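
# Illustrative example (Linux only, paths depend on the install): pre_load('libgomp')
# resolves the dependency roughly like
#     ldd <path to libpaddle.so> | grep libgomp | awk '{print $3}'
# and then loads the resulting path with ctypes.cdll.LoadLibrary(), so libgomp is
# already mapped before libpaddle itself is imported below.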


def get_libc_ver():
    ldd_glibc = run_shell_command("ldd --version | awk '/ldd/{print $NF}'")
    if ldd_glibc is not None:
        return ("glibc", ldd_glibc)

    ldd_musl = run_shell_command("ldd 2>&1 | awk '/Version/{print $NF}'")
    if ldd_musl is not None:
        return ("musl", ldd_musl)
    return (None, None)
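
# Example outputs (illustrative): ('glibc', '2.31') on a typical Ubuntu 20.04 system,
# ('musl', '1.2.2') on Alpine, or (None, None) when neither probe yields a version.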


def less_than_ver(a, b):
    if a is None or b is None:
        return False

    import re
    import operator

    def to_list(s):
        s = re.sub(r'(\.0+)+$', '', s)
        return [int(x) for x in s.split('.')]

    return operator.lt(to_list(a), to_list(b))
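
# Worked examples: less_than_ver('2.17', '2.23') is True, while
# less_than_ver('2.23.0', '2.23') is False because trailing '.0' groups are stripped
# before the numeric comparison.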


# NOTE(zhiqiu): An error may occur when importing paddle on a linux platform with glibc < 2.22,
# the error message of which is "dlopen: cannot load any more object with static TLS".
# This happens when:
# (1) the number of dynamic shared libraries (DSOs) loaded is > 14,
# (2) after that, a dynamic shared library (DSO) with static TLS is loaded.
# For paddle, the problem is that 'libgomp' is a DSO with static TLS, and it is loaded after 14 DSOs.
# So, here is a tricky way to work around the problem: pre-load 'libgomp' before 'libpaddle.so'.
# The proper fix is to upgrade glibc to > 2.22 on the target system.
if platform.system().lower() == 'linux':
    libc_type, libc_ver = get_libc_ver()
    if libc_type == 'glibc' and less_than_ver(libc_ver, '2.23'):
        try:
            pre_load('libgomp')
        except Exception as e:
            # NOTE(zhiqiu): do not abort if the preload fails, since importing libpaddle.so may still succeed
            sys.stderr.write('Error: Can not preload libgomp.so\n')
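
# Illustrative shell-level equivalent of the preload above (library path is hypothetical):
#     LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libgomp.so.1 python -c "import paddle"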

try:
    from . import libpaddle

    if avx_supported() and not libpaddle.is_compiled_with_avx():
        sys.stderr.write(
            "Hint: Your machine supports AVX, but the installed paddlepaddle doesn't have the avx core. "
            "Hence, the no-avx core with worse performance will be imported.\nIf you like, you could "
            "reinstall paddlepaddle by 'python -m pip install --force-reinstall paddlepaddle-gpu[==version]' "
            "to get better performance.\n"
        )

    # assign tensor alias
    libpaddle.LoDTensor = libpaddle.Tensor

    from .libpaddle import *
    from .libpaddle import __doc__, __file__, __name__, __package__
    from .libpaddle import __unittest_throw_exception__
    from .libpaddle import _append_python_callable_object_and_return_id
    from .libpaddle import _cleanup, _Scope
    from .libpaddle import _get_use_default_grad_op_desc_maker_ops
    from .libpaddle import _get_all_register_op_kernels
    from .libpaddle import _is_program_version_supported
    from .libpaddle import _set_eager_deletion_mode
    from .libpaddle import _get_eager_deletion_vars
    from .libpaddle import _set_fuse_parameter_group_size
    from .libpaddle import _set_fuse_parameter_memory_size
    from .libpaddle import _is_dygraph_debug_enabled
    from .libpaddle import _dygraph_debug_level
    from .libpaddle import _switch_tracer
    from .libpaddle import _set_paddle_lib_path
    from .libpaddle import _create_loaded_parameter
    from .libpaddle import _cuda_synchronize
    from .libpaddle import _is_compiled_with_heterps
    from .libpaddle import _promote_types_if_complex_exists
    from .libpaddle import _set_cached_executor_build_strategy
    from .libpaddle import _device_synchronize
    from .libpaddle import _xpu_device_synchronize
    from .libpaddle import _get_current_stream
    from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
    from .libpaddle import _set_current_stream
    from .libpaddle import _get_phi_kernel_name
    from .libpaddle import _add_skip_comp_ops
    from .libpaddle import _remove_skip_comp_ops

    # prim controller flags
    from .libpaddle import __set_bwd_prim_enabled
    from .libpaddle import _is_bwd_prim_enabled
    from .libpaddle import __set_fwd_prim_enabled
    from .libpaddle import _is_fwd_prim_enabled
    from .libpaddle import __set_all_prim_enabled
    from .libpaddle import _set_prim_target_grad_name

    # custom device
    from .libpaddle import _get_current_custom_device_stream
    from .libpaddle import _set_current_custom_device_stream
    from .libpaddle import _synchronize_custom_device
    from .libpaddle import CustomDeviceStream
    from .libpaddle import CustomDeviceEvent

    if sys.platform != 'win32':
        from .libpaddle import _set_process_pids
        from .libpaddle import _erase_process_pids
        from .libpaddle import _set_process_signal_handler
        from .libpaddle import _throw_error_if_process_failed
        from .libpaddle import _convert_to_tensor_list
        from .libpaddle import _array_to_share_memory_tensor
        from .libpaddle import _cleanup_mmap_fds
        from .libpaddle import _remove_tensor_list_mmap_fds
        from .libpaddle import _set_max_memory_map_allocation_pool_size
except Exception as e:
    if has_paddle_dy_lib:
        sys.stderr.write(
            'Error: Can not import paddle core while this file exists: '
            + current_path
            + os.sep
            + 'libpaddle.'
            + dy_lib_suffix
            + '\n'
        )
    if not avx_supported() and libpaddle.is_compiled_with_avx():
        sys.stderr.write(
            "Error: Your machine doesn't support AVX, but the installed PaddlePaddle was built with the avx core; "
            "you should reinstall paddlepaddle with the no-avx core.\n"
        )
    raise e


def set_paddle_custom_device_lib_path(lib_path):
    if os.environ.get('CUSTOM_DEVICE_ROOT', None) is not None:
        # use the environment value that is already set
        return
    if os.path.exists(lib_path):
        # set CUSTOM_DEVICE_ROOT default path
        os.environ['CUSTOM_DEVICE_ROOT'] = os.path.normpath(lib_path)
    else:
        os.environ['CUSTOM_DEVICE_ROOT'] = ''


# set paddle lib path
def set_paddle_lib_path():
    site_dirs = (
        site.getsitepackages()
        if hasattr(site, 'getsitepackages')
        else [x for x in sys.path if 'site-packages' in x]
    )
    for site_dir in site_dirs:
        lib_dir = os.path.sep.join([site_dir, 'paddle', 'libs'])
        if os.path.exists(lib_dir):
            _set_paddle_lib_path(lib_dir)
            set_paddle_custom_device_lib_path(
                os.path.sep.join([lib_dir, '..', '..', 'paddle-plugins'])
            )
            return
    if hasattr(site, 'USER_SITE'):
        lib_dir = os.path.sep.join([site.USER_SITE, 'paddle', 'libs'])
        if os.path.exists(lib_dir):
            _set_paddle_lib_path(lib_dir)
            set_paddle_custom_device_lib_path(
                os.path.sep.join([lib_dir, '..', '..', 'paddle-plugins'])
            )


set_paddle_lib_path()
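
# Illustrative layout that set_paddle_lib_path() resolves against (paths are examples):
#     <site-packages>/paddle/libs        -> passed to _set_paddle_lib_path()
#     <site-packages>/paddle-plugins     -> default value of CUSTOM_DEVICE_ROOT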

# We have 3 FLAGS to judge whether prim is enabled
# FLAGS_prim_forward: Open or close the forward prim strategy
# FLAGS_prim_backward: Open or close the backward prim strategy
# FLAGS_prim_all: Open or close all prim strategies
#
# Priorities:
# if with CINN and Dy2St:
# # # _set_prim_all_enabled > FLAGS_prim_all > check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
# else:
# # # _set_prim_all_enabled > FLAGS_prim_all == check_and_set_prim_all_enabled == _set_prim_forward_enabled == _set_prim_backward_enabled > FLAGS_prim_forward == FLAGS_prim_backward
def __sync_stat_with_flag(flag):
    if flag == "FLAGS_prim_forward":
        flag_value = os.getenv("FLAGS_prim_forward")
        assert flag_value is not None
        flag_value = flag_value.lower()
        if flag_value == "false":
            __set_fwd_prim_enabled(False)
        elif flag_value == "true":
            __set_fwd_prim_enabled(True)
        else:
            raise TypeError(f"flag {flag} should be true or false.")
        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
    elif flag == "FLAGS_prim_backward":
        flag_value = os.getenv("FLAGS_prim_backward")
        assert flag_value is not None
        flag_value = flag_value.lower()
        if flag_value == "false":
            __set_bwd_prim_enabled(False)
        elif flag_value == "true":
            __set_bwd_prim_enabled(True)
        else:
            raise TypeError(f"flag {flag} should be true or false.")
        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
    elif flag == "FLAGS_prim_all":
        flag_value = os.getenv("FLAGS_prim_all")
        assert flag_value is not None
        flag_value = flag_value.lower()
        if flag_value == "false":
            __set_all_prim_enabled(False)
        elif flag_value == "true":
            __set_all_prim_enabled(True)
        else:
            raise TypeError(f"flag {flag} should be true or false.")
        print(
            "all prim enabled: ",
            bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
        )
    else:
        raise TypeError(
            f"We only support FLAGS_prim_forward/FLAGS_prim_backward/FLAGS_prim_all but we got {flag}."
        )
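
# Illustrative environment-driven usage (flag values and script name are examples):
#     FLAGS_prim_all=true python train.py
# check_and_set_prim_all_enabled() (defined below) syncs from FLAGS_prim_all when it
# is set, and falls back to FLAGS_prim_forward / FLAGS_prim_backward otherwise.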


# Alert!!! This method is only for test coverage; users should never call it directly, as it may cause serious system errors.
def _test_use_sync(value):
    __sync_stat_with_flag(value)


# ops in forward_blacklist will not be replaced by composite ops.
prim_config = {"forward_blacklist": set(), "composite_ops_record": set()}


def _set_prim_forward_blacklist(ops=None):
    if ops is None:
        prim_config["forward_blacklist"] = set()
    elif isinstance(ops, str):
        prim_config["forward_blacklist"].add(ops)
    elif isinstance(ops, (list, tuple)):
        for item in ops:
            if not isinstance(item, str):
                raise TypeError(
                    "ops set in forward_blacklist must be a str, or a list/tuple of str"
                )
            else:
                prim_config["forward_blacklist"].add(item)
    else:
        raise TypeError(
            "ops set in forward_blacklist must be a str, or a list/tuple of str"
        )
    return
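
# Illustrative usage (op names are hypothetical):
#     _set_prim_forward_blacklist("softmax")            # skip compositing a single op
#     _set_prim_forward_blacklist(["softmax", "gelu"])  # skip compositing several ops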


def _set_prim_backward_enabled(value):
    __set_bwd_prim_enabled(bool(value))
    print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))


def _set_prim_forward_enabled(value):
    __set_fwd_prim_enabled(bool(value))
    print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))


def _set_prim_all_enabled(value):
    __set_all_prim_enabled(bool(value))
    print(
        "all prim enabled: ",
        bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
    )


def __sync_prim_backward_status():
    flag_value = os.getenv("FLAGS_prim_backward")
    if flag_value is None:
        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
    else:
        __sync_stat_with_flag("FLAGS_prim_backward")


def __sync_prim_forward_status():
    flag_value = os.getenv("FLAGS_prim_forward")
    if flag_value is None:
        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
    else:
        __sync_stat_with_flag("FLAGS_prim_forward")


def check_and_set_prim_all_enabled():
    flag_value = os.getenv("FLAGS_prim_all")
    if flag_value is None:
        __sync_prim_backward_status()
        __sync_prim_forward_status()
    else:
        __sync_stat_with_flag("FLAGS_prim_all")