未验证 提交 5f7e4a21 编写于 作者: S sneaxiy 提交者: GitHub

refine CUDA Graph (#38401)

上级 89d38f55
......@@ -17,9 +17,14 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA
if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
else:
CoreCUDAGraph = None
class CUDAGraph:
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None
if place is None:
......@@ -50,23 +55,3 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags)
else:
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
raise NotImplementedError()
def capture_begin(self):
raise NotImplementedError()
def capture_end(self):
raise NotImplementedError()
def replay(self):
raise NotImplementedError()
def reset(self):
raise NotImplementedError()
def print_to_dot_files(self, dirname, flags=None):
raise NotImplementedError()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册