未验证 提交 5f7e4a21 编写于 作者: S sneaxiy 提交者: GitHub

refine CUDA Graph (#38401)

上级 89d38f55
...@@ -17,9 +17,14 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA ...@@ -17,9 +17,14 @@ from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDA
if is_compiled_with_cuda() and not is_compiled_with_rocm(): if is_compiled_with_cuda() and not is_compiled_with_rocm():
from paddle.fluid.core import CUDAGraph as CoreCUDAGraph from paddle.fluid.core import CUDAGraph as CoreCUDAGraph
else:
CoreCUDAGraph = None
class CUDAGraph: class CUDAGraph:
def __init__(self, place=None, mode="thread_local"): def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
ALL_MODES = ["global", "thread_local", "relaxed"] ALL_MODES = ["global", "thread_local", "relaxed"]
self._graph = None self._graph = None
if place is None: if place is None:
...@@ -50,23 +55,3 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm(): ...@@ -50,23 +55,3 @@ if is_compiled_with_cuda() and not is_compiled_with_rocm():
if flags is None: if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048) flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags) self._graph.print_to_dot_files(dirname, flags)
else:
class CUDAGraph:
def __init__(self, place=None, mode="thread_local"):
raise NotImplementedError()
def capture_begin(self):
raise NotImplementedError()
def capture_end(self):
raise NotImplementedError()
def replay(self):
raise NotImplementedError()
def reset(self):
raise NotImplementedError()
def print_to_dot_files(self, dirname, flags=None):
raise NotImplementedError()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册