未验证 提交 5f7e4a21 编写于 作者: S sneaxiy 提交者: GitHub

refine CUDA Graph (#38401)

上级 89d38f55
......@@ -17,56 +17,41 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA
if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = CUDAPlace(device_id)
self._place = place
assert mode in ALL_MODES
self._mode = ALL_MODES.index(mode)
def capture_begin(self):
CoreCUDAGraph.begin_capture(self._place, self._mode)
def capture_end(self):
self._graph = CoreCUDAGraph.end_capture()
def replay(self):
self._graph.replay()
def reset(self):
self._graph.reset()
def print_to_dot_files(self, dirname, flags=None):
if not isinstance(dirname, (str, bytes)):
dirname = dirname.name
os.makedirs(name=dirname, exist_ok=True)
assert os.path.isdir(
dirname), "The dirname {} should be a directory".format(dirname)
if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
else:
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
raise NotImplementedError()
def capture_begin(self):
raise NotImplementedError()
def capture_end(self):
raise NotImplementedError()
def replay(self):
raise NotImplementedError()
def reset(self):
raise NotImplementedError()
def print_to_dot_files(self, dirname, flags=None):
raise NotImplementedError()
CoreCUDAGraph = None
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = CUDAPlace(device_id)
self._place = place
assert mode in ALL_MODES
self._mode = ALL_MODES.index(mode)
def capture_begin(self):
CoreCUDAGraph.begin_capture(self._place, self._mode)
def capture_end(self):
self._graph = CoreCUDAGraph.end_capture()
def replay(self):
self._graph.replay()
def reset(self):
self._graph.reset()
def print_to_dot_files(self, dirname, flags=None):
if not isinstance(dirname, (str, bytes)):
dirname = dirname.name
os.makedirs(name=dirname, exist_ok=True)
assert os.path.isdir(
dirname), "The dirname {} should be a directory".format(dirname)
if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册