Unverified commit e772f166, authored by Michael Wyatt, committed by GitHub

Use CUDA events for inference model profiling (#2371)

* use cuda event timers for model profiling
Parent 8da0238b
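This change replaces host-side timing of the inference forward pass (a torch.cuda.synchronize() followed by time.time()) with timers driven by CUDA events, which record timestamps on the GPU stream itself. For orientation only, event-based timing of a forward pass typically looks like the minimal sketch below; the commit itself routes timing through DeepSpeed's SynchronizedWallClockTimer, and the model / inputs names here are placeholders, not code from this commit.

# Minimal sketch of CUDA-event timing (placeholder model/inputs, not the commit's code).
import torch

def time_forward_ms(model, inputs):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()              # recorded on the current CUDA stream
    output = model(inputs)
    end.record()
    torch.cuda.synchronize()    # make sure both events have completed
    return start.elapsed_time(end), output  # elapsed_time() reports milliseconds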
@@ -11,6 +11,7 @@ from deepspeed.utils.logging import log_dist
 from torch.nn.modules import Module
 from packaging import version as pkg_version
 from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
+from deepspeed.utils.timer import SynchronizedWallClockTimer
 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
@@ -25,6 +26,8 @@ from ..module_inject.replace_policy import DSPolicy
 DS_INFERENCE_ENABLED = False
 from torch import nn

+INFERENCE_MODEL_TIMER = "model-forward-inference"
+
 class InferenceEngine(Module):
     inference_mp_group = None
@@ -168,11 +171,14 @@ class InferenceEngine(Module):
         if self.mp_world_size > 1:
             assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"

-    def profile_model_time(self):
+    def profile_model_time(self, use_cuda_events=True):
         if not self.model_profile_enabled and not self.enable_cuda_graph:
             self.module.register_forward_pre_hook(self._pre_forward_hook)
             self.module.register_forward_hook(self._post_forward_hook)
         self.model_profile_enabled = True
+        self.use_cuda_events = use_cuda_events
+        if self.use_cuda_events:
+            self.timers = SynchronizedWallClockTimer()

     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
@@ -184,13 +190,21 @@ class InferenceEngine(Module):
                 self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask

     def _pre_forward_hook(self, module, *inputs, **kwargs):
-        torch.cuda.synchronize()
-        self._start = time.time()
+        if self.use_cuda_events:
+            self.timers(INFERENCE_MODEL_TIMER).start()
+        else:
+            torch.cuda.synchronize()
+            self._start = time.time()

     def _post_forward_hook(self, module, input, output):
-        torch.cuda.synchronize()
-        self._end = time.time()
-        self._model_times.append(self._end - self._start)
+        if self.use_cuda_events:
+            self.timers(INFERENCE_MODEL_TIMER).stop()
+            elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True)
+        else:
+            torch.cuda.synchronize()
+            self._end = time.time()
+            elapsed_time = self._end - self._start
+        self._model_times.append(elapsed_time)

     def _create_model_parallel_group(self):
         # Call the init process
......
import os
import time
import pytest
import torch
import deepspeed
from transformers import pipeline
from unit.common import DistributedTest


@pytest.fixture
def query(model, task):
    if task == "text-generation":
        return "DeepSpeed is"
    elif task == "fill-mask":
        if "roberta" in model:
            return "I am a <mask> model"
        else:
            return "I am a [MASK] model"
    else:
        raise NotImplementedError


@pytest.fixture
def inf_kwargs(task):
    if task == "text-generation":
        return {"do_sample": False, "min_length": 50, "max_length": 50}
    else:
        return {}


@pytest.mark.inference
@pytest.mark.parametrize("model,task",
                         [("bert-base-cased", "fill-mask"),
                          ("roberta-base", "fill-mask"),
                          ("gpt2", "text-generation"),
                          ("facebook/opt-125m", "text-generation"),
                          ("bigscience/bloom-560m", "text-generation")])
@pytest.mark.parametrize("cuda_graphs", [True, False])
@pytest.mark.parametrize("use_cuda_events", [True, False])
class TestModelProfiling(DistributedTest):
    world_size = 1

    def test(self,
             model,
             task,
             query,
             inf_kwargs,
             cuda_graphs,
             use_cuda_events,
             dtype=torch.float16):
        if cuda_graphs and "bert" not in model:
            pytest.skip(f"CUDA Graph not supported for {model}")

        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))

        pipe = pipeline(task, model, framework="pt", device=local_rank)
        pipe.model = deepspeed.init_inference(pipe.model,
                                              dtype=dtype,
                                              mp_size=world_size,
                                              replace_with_kernel_inject=True,
                                              replace_method="auto",
                                              enable_cuda_graph=cuda_graphs)
        pipe.model.profile_model_time(use_cuda_events=use_cuda_events)

        e2e_times = []
        model_times = []
        for _ in range(10):
            torch.cuda.synchronize()
            start = time.perf_counter_ns()
            r = pipe(query, **inf_kwargs)
            torch.cuda.synchronize()
            end = time.perf_counter_ns()

            e2e_times.append((end - start) / 1e6)  # convert ns to ms
            model_times.extend(pipe.model.model_times())

        for e2e_t, model_t in zip(e2e_times, model_times):
            assert e2e_t >= model_t
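The test above doubles as a usage recipe: it reduces to the condensed sketch below (single process on one GPU; the model name, prompt, and generation arguments are simply the examples used in the test).

# Condensed usage sketch derived from the test above (one process, one GPU).
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline("text-generation", "gpt2", framework="pt", device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      mp_size=1,
                                      replace_with_kernel_inject=True,
                                      replace_method="auto")
pipe.model.profile_model_time(use_cuda_events=True)  # CUDA-event path (the new default)

pipe("DeepSpeed is", do_sample=False, min_length=50, max_length=50)
print(pipe.model.model_times())  # per-forward model times recorded by the hooks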