diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a59e7a514ab2a887b345f132765f042c8292d919..3447e68bda615cdb1fe9c71ef7cc15af2ed07b94 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1497,6 +1497,10 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_dist", IsCompiledWithDIST); + m.def("_cuda_synchronize", [](const platform::CUDAPlace &place) { + platform::DeviceContextPool::Instance().Get(place)->Wait(); + }); + m.def("run_cmd", [](const std::string &cmd, int time_out = -1, int sleep_inter = -1) -> const std::string { diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 6894b9d66e02f945e4478591bddb984ab3f946f1..932a54a8cb673d2ed17476a0d73444baf62477f5 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -88,6 +88,7 @@ from .dygraph.base import enable_dygraph, disable_dygraph, enable_imperative, di from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.varbase_patch_methods import monkey_patch_varbase +from .core import _cuda_synchronize Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + \ @@ -128,7 +129,8 @@ __all__ = framework.__all__ + executor.__all__ + \ 'install_check', 'save', 'load', - 'VarBase' + 'VarBase', + '_cuda_synchronize' ] diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 7722a2a96a0b360529e30b403efa61226a158ff0..686ce225d0829298c13f300c474715f8ad026c7a 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -260,6 +260,7 @@ if avx_supported(): from .core_avx import _save_dygraph_dict from .core_avx import _load_dygraph_dict from .core_avx import _create_loaded_parameter + from .core_avx import _cuda_synchronize if sys.platform != 'win32': from .core_avx import _set_process_pids from .core_avx import _erase_process_pids @@ -304,6 +305,7 @@ if load_noavx: from .core_noavx import _save_dygraph_dict from .core_noavx import _load_dygraph_dict from .core_noavx import _create_loaded_parameter + from .core_noavx import _cuda_synchronize if sys.platform != 'win32': from .core_noavx import _set_process_pids from .core_noavx import _erase_process_pids