Unverified · Commit 87eaf8f9 · authored by Lev Kurilenko, committed by GitHub

Check for local CUDA graphs when enable_cuda_graph=True (#2941)

Parent commit: 2ede0d94
......@@ -24,6 +24,8 @@ from deepspeed.accelerator import get_accelerator
from ..module_inject.policy import TransformerPolicy
from ..module_inject.auto_tp import AutoTP
from ..module_inject.replace_policy import generic_policies
DS_INFERENCE_ENABLED = False
from torch import nn
......@@ -155,6 +157,9 @@ class InferenceEngine(Module):
if config.tensor_parallel.tp_size > 1:
assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
# Check if local CUDA graphs can be created in replacement modules
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
def profile_model_time(self, use_cuda_events=True):
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
self.module.register_forward_pre_hook(self._pre_forward_hook)
......@@ -512,6 +517,27 @@ class InferenceEngine(Module):
self._model_times = []
return model_times
def _module_match(self, module):
    """Return True if ``module`` was produced by one of the generic
    replacement policies (i.e. it is a replaced inference module)."""
    # Instantiate each policy and ask whether it recognizes the module
    # as its own replacement type.
    return any(policy_cls().match_replaced(module) for policy_cls in generic_policies)
def _local_cuda_graph_used(self, module):
if isinstance(module, torch.nn.Module):
return False
else:
sub_module_cuda_graph = False
for name in module.__dict__.keys():
sub_module = getattr(module, name)
if self._module_match(sub_module) and hasattr(sub_module,
"enable_cuda_graph"):
sub_module_cuda_graph = True
return sub_module_cuda_graph
def forward(self, *inputs, **kwargs):
"""Execute forward propagation
......@@ -525,7 +551,8 @@ class InferenceEngine(Module):
get_accelerator().synchronize()
start = time.time()
if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
if get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
......
......@@ -2,11 +2,12 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from ..features.cuda_graph import CUDAGraph
class DSUNet(torch.nn.Module):
class DSUNet(CUDAGraph, torch.nn.Module):
def __init__(self, unet, enable_cuda_graph=True):
super().__init__()
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.unet = unet
# SD pipeline accesses this attribute
self.in_channels = unet.in_channels
......@@ -17,7 +18,6 @@ class DSUNet(torch.nn.Module):
self.unet.requires_grad_(requires_grad=False)
self.unet.to(memory_format=torch.channels_last)
self.cuda_graph_created = False
self.enable_cuda_graph = enable_cuda_graph
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
......
......@@ -2,11 +2,12 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from ..features.cuda_graph import CUDAGraph
class DSVAE(torch.nn.Module):
class DSVAE(CUDAGraph, torch.nn.Module):
def __init__(self, vae, enable_cuda_graph=True):
super().__init__()
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.vae = vae
self.device = self.vae.device
self.dtype = self.vae.dtype
......@@ -14,7 +15,6 @@ class DSVAE(torch.nn.Module):
self.decoder_cuda_graph_created = False
self.encoder_cuda_graph_created = False
self.all_cuda_graph_created = False
self.enable_cuda_graph = enable_cuda_graph
def _graph_replay_decoder(self, *inputs, **kwargs):
for i in range(len(inputs)):
......@@ -104,7 +104,7 @@ class DSVAE(torch.nn.Module):
else:
return self._encode(*inputs, **kwargs)
def _graph_replay_all(self, *inputs, **kwargs):
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
......@@ -117,10 +117,10 @@ class DSVAE(torch.nn.Module):
def forward(self, *inputs, **kwargs):
    """Run the VAE forward pass, using CUDA graph replay when enabled.

    When ``enable_cuda_graph`` is set, the first call captures a CUDA
    graph of the forward pass and every call (including the first)
    replays it; otherwise the eager ``_forward`` path is used.
    """
    if self.enable_cuda_graph:
        if not self.cuda_graph_created:
            # First invocation: capture the graph, then replay it below.
            self._create_cuda_graph(*inputs, **kwargs)
        # NOTE(review): the merged diff left both the old
        # `_graph_replay_all` call and the renamed `_graph_replay` call
        # in each branch; only the renamed method is kept so the graph
        # is replayed exactly once per call.
        return self._graph_replay(*inputs, **kwargs)
    else:
        return self._forward(*inputs, **kwargs)
......
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2023 The Microsoft DeepSpeed Team
'''
from abc import ABC, abstractmethod
class CUDAGraph(ABC):
    """Abstract mixin for modules that capture and replay CUDA graphs.

    Subclasses provide graph capture (``_create_cuda_graph``) and replay
    (``_graph_replay``); the ``enable_cuda_graph`` flag records whether
    the feature is active for this instance.
    """

    def __init__(self, enable_cuda_graph=False):
        # Cooperative init so this mixin can precede torch.nn.Module in
        # a multiple-inheritance chain (e.g. DSUNet, DSVAE).
        super().__init__()
        self.enable_cuda_graph = enable_cuda_graph

    @abstractmethod
    def _create_cuda_graph(self):
        """Capture the CUDA graph(s) for this module's forward pass."""
        raise NotImplementedError

    @abstractmethod
    def _graph_replay(self):
        """Replay the previously captured CUDA graph(s)."""
        raise NotImplementedError
......@@ -3,11 +3,12 @@ Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph
class DSClipEncoder(torch.nn.Module):
class DSClipEncoder(CUDAGraph, torch.nn.Module):
def __init__(self, enc, enable_cuda_graph=False):
super().__init__()
super().__init__(enable_cuda_graph=enable_cuda_graph)
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
self.enc = enc
self.device = self.enc.device
......@@ -18,7 +19,6 @@ class DSClipEncoder(torch.nn.Module):
self.static_output = [None, None]
self._cuda_graphs = [None, None]
self.iter = 0
self.enable_cuda_graph = enable_cuda_graph
self.config = self.enc.config
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
......
......@@ -5,6 +5,7 @@ import torch
from torch.nn.parameter import Parameter
from ..policy import DSPolicy
from ...model_implementations.diffusers.unet import DSUNet
class UNetPolicy(DSPolicy):
......@@ -19,9 +20,11 @@ class UNetPolicy(DSPolicy):
def match(self, module):
    """Return True if ``module`` is an original (pre-replacement) UNet."""
    original_cls = self._orig_layer_class
    return isinstance(module, original_cls)
def match_replaced(self, module):
    """Return True if ``module`` has already been replaced with DSUNet."""
    return isinstance(module, DSUNet)
def apply(self, module, enable_cuda_graph=True):
    """Wrap the original UNet ``module`` in DeepSpeed's DSUNet replacement.

    Args:
        module: the diffusers UNet module to replace.
        enable_cuda_graph: whether the replacement captures/replays CUDA
            graphs for its forward pass.
    """
    # TODO(cmikeh2): Enable cuda graph should be an inference configuration
    # Uses the module-level DSUNet import (added alongside
    # match_replaced); the redundant function-local import that shadowed
    # it has been removed.
    return DSUNet(module, enable_cuda_graph=enable_cuda_graph)
def attention(self, client_module):
......
......@@ -2,6 +2,7 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
from ..policy import DSPolicy
from ...model_implementations.diffusers.vae import DSVAE
class VAEPolicy(DSPolicy):
......@@ -20,9 +21,11 @@ class VAEPolicy(DSPolicy):
def match(self, module):
    """Return True if ``module`` is an original (pre-replacement) VAE."""
    original_cls = self._orig_layer_class
    return isinstance(module, original_cls)
def match_replaced(self, module):
    """Return True if ``module`` has already been replaced with DSVAE."""
    return isinstance(module, DSVAE)
def apply(self, module, enable_cuda_graph=True):
    """Wrap the original VAE ``module`` in DeepSpeed's DSVAE replacement.

    Args:
        module: the diffusers VAE module to replace.
        enable_cuda_graph: whether the replacement captures/replays CUDA
            graphs for its forward pass.
    """
    # TODO(cmikeh2): Enable cuda graph should be an inference configuration
    # Uses the module-level DSVAE import (added alongside match_replaced);
    # the redundant function-local import that shadowed it has been removed.
    return DSVAE(module, enable_cuda_graph=enable_cuda_graph)
# NOTE (lekurile): Should we have a diffusers policy class?
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.