From 5f7e4a213742404264006d1e14029f8b276f8539 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 27 Dec 2021 14:15:40 +0800 Subject: [PATCH] refine CUDA Graph (#38401) --- python/paddle/device/cuda/graphs.py | 89 ++++++++++++----------------- 1 file changed, 37 insertions(+), 52 deletions(-) diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py index 2a60aad2fd2..29e1b2694a6 100644 --- a/python/paddle/device/cuda/graphs.py +++ b/python/paddle/device/cuda/graphs.py @@ -17,56 +17,41 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA if is_compiled_with_cuda() and not is_compiled_with_rocm(): from paddle.fluid.core import CUDAGraph as CoreCUDAGraph - - class CUDAGraph: - def __init__(self, place=None, mode="thread_local"): - ALL_MODES = ["global", "thread_local", "relaxed"] - self._graph = None - if place is None: - device_id = int(os.environ.get('FLAGS_selected_gpus', 0)) - place = CUDAPlace(device_id) - self._place = place - assert mode in ALL_MODES - self._mode = ALL_MODES.index(mode) - - def capture_begin(self): - CoreCUDAGraph.begin_capture(self._place, self._mode) - - def capture_end(self): - self._graph = CoreCUDAGraph.end_capture() - - def replay(self): - self._graph.replay() - - def reset(self): - self._graph.reset() - - def print_to_dot_files(self, dirname, flags=None): - if not isinstance(dirname, (str, bytes)): - dirname = dirname.name - os.makedirs(name=dirname, exist_ok=True) - assert os.path.isdir( - dirname), "The dirname {} should be a directory".format(dirname) - if flags is None: - flags = 2047 # only all information. It can be any integer inside [1, 2048) - self._graph.print_to_dot_files(dirname, flags) else: - - class CUDAGraph: - def __init__(self, place=None, mode="thread_local"): - raise NotImplementedError() - - def capture_begin(self): - raise NotImplementedError() - - def capture_end(self): - raise NotImplementedError() - - def replay(self): - raise NotImplementedError() - - def reset(self): - raise NotImplementedError() - - def print_to_dot_files(self, dirname, flags=None): - raise NotImplementedError() + CoreCUDAGraph = None + + +class CUDAGraph: + def __init__(self, place=None, mode="thread_local"): + assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU." + + ALL_MODES = ["global", "thread_local", "relaxed"] + self._graph = None + if place is None: + device_id = int(os.environ.get('FLAGS_selected_gpus', 0)) + place = CUDAPlace(device_id) + self._place = place + assert mode in ALL_MODES + self._mode = ALL_MODES.index(mode) + + def capture_begin(self): + CoreCUDAGraph.begin_capture(self._place, self._mode) + + def capture_end(self): + self._graph = CoreCUDAGraph.end_capture() + + def replay(self): + self._graph.replay() + + def reset(self): + self._graph.reset() + + def print_to_dot_files(self, dirname, flags=None): + if not isinstance(dirname, (str, bytes)): + dirname = dirname.name + os.makedirs(name=dirname, exist_ok=True) + assert os.path.isdir( + dirname), "The dirname {} should be a directory".format(dirname) + if flags is None: + flags = 2047 # only all information. It can be any integer inside [1, 2048) + self._graph.print_to_dot_files(dirname, flags) -- GitLab