Unverified · Commit cf638be9 authored by Jeff Rasley, committed by GitHub

only override forward if using cuda-graph (#2291)

Parent 95d11517
@@ -40,8 +40,6 @@ jobs:
run: |
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
git checkout v4.21.2
git rev-parse --short HEAD
pip uninstall --yes transformers
pip install .
@@ -162,10 +162,7 @@ class InferenceEngine(Module):
torch.cuda.set_rng_state(_rng_state.cpu())
if self.mp_world_size > 1:
self.model_orig_fwd = self.module.forward
self.module.forward = self.forward
else:
self.module.register_forward_pre_hook(self._pre_forward_hook)
assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
def _get_model_config_generate(self, config):
self.config = getattr(self.module, 'config', None) if config is None else config
@@ -475,14 +472,6 @@ class InferenceEngine(Module):
elif self.dtype == torch.float:
self.module.float()
def _pre_forward_hook(self, module, *inputs, **kwargs):
for input in inputs:
if torch.is_tensor(input):
input = input.to(torch.cuda.current_device())
for k in kwargs:
if torch.is_tensor(kwargs[k]):
kwargs[k] = kwargs[k].to(torch.cuda.current_device())
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
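The surviving `_create_cuda_graph` path starts with a warm-up on a side stream (the comment above notes it creates the workspace and cuBLAS handle) before capturing the forward pass. Below is a minimal, self-contained sketch of that warm-up, capture, and replay recipe using the public `torch.cuda.CUDAGraph` API; the model, shapes, warm-up count, and helper name are illustrative assumptions, not DeepSpeed's implementation.

```python
# Sketch of the warm-up-then-capture pattern that _create_cuda_graph /
# _graph_replay follow, written against the public torch.cuda.CUDAGraph API.
# Model, shapes, and the graph_replay helper are illustrative assumptions.
import torch

model = torch.nn.Linear(1024, 1024).cuda().eval()
static_input = torch.randn(8, 1024, device="cuda")

# 1) Warm up on a side stream so cuBLAS handles and workspaces exist
#    before capture (captured graphs cannot create them lazily).
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# 2) Capture one forward pass; tensors touched during capture become the
#    graph's fixed input/output buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_output = model(static_input)

# 3) Replay: copy fresh data into the static input buffer, launch the
#    recorded kernels, and read the static output buffer.
def graph_replay(new_input: torch.Tensor) -> torch.Tensor:
    static_input.copy_(new_input)
    graph.replay()
    return static_output.clone()

out = graph_replay(torch.randn(8, 1024, device="cuda"))
```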
@@ -519,30 +508,13 @@ class InferenceEngine(Module):
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
if self.mp_world_size > 1:
if self.mpu is None:
for input in inputs:
if torch.is_tensor(input):
input = input.to(torch.cuda.current_device())
if not input.is_contiguous():
input = input.contiguous()
dist.broadcast(input, 0)
for k in kwargs:
if torch.is_tensor(kwargs[k]):
kwargs[k] = kwargs[k].to(torch.cuda.current_device())
if not kwargs[k].is_contiguous():
kwargs[k] = kwargs[k].contiguous()
dist.broadcast(kwargs[k], 0)
outputs = self.model_orig_fwd(*inputs, **kwargs)
else:
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
outputs = self.module(*inputs, **kwargs)
#outputs = self.module(*inputs, **kwargs)
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
else:
outputs = self.module(*inputs, **kwargs)
return outputs
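After this change the engine's `forward` reduces to the dispatch above: replay (or first capture) a CUDA graph when `enable_cuda_graph` is set, otherwise call the wrapped module directly; the model-parallel broadcast path is removed, consistent with the new assertion that CUDA graphs are not supported with model parallelism. A rough sketch of the same "only wrap forward when graphs are requested" rule, using PyTorch's public `make_graphed_callables` helper rather than DeepSpeed internals, is below; the module and shapes are assumptions.

```python
# A rough sketch (not DeepSpeed's code) of the commit's rule: leave the
# module's forward untouched unless CUDA graphs were requested.
# torch.cuda.make_graphed_callables handles the warm-up and capture shown
# in the earlier sketch and returns a callable that replays the graph.
import torch

def wrap_for_inference(module: torch.nn.Module,
                       sample_input: torch.Tensor,
                       enable_cuda_graph: bool = False):
    if not enable_cuda_graph:
        return module  # no override: plain eager forward
    return torch.cuda.make_graphed_callables(module, (sample_input,))

model = torch.nn.Linear(1024, 1024).cuda().eval()
runner = wrap_for_inference(model,
                            torch.randn(8, 1024, device="cuda"),
                            enable_cuda_graph=True)
with torch.no_grad():
    out = runner(torch.randn(8, 1024, device="cuda"))
```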
@@ -292,13 +292,13 @@ class TestModelTask(DistributedTest):
@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task",
[("gpt2",
[("EleutherAI/gpt-neo-1.3B",
"text-generation"),
("EleutherAI/gpt-neox-20b",
"text-generation"),
("bigscience/bloom-3b",
"text-generation")],
ids=["gpt2",
ids=["gpt-neo",
"gpt-neox",
"bloom"])
class TestMPSize(DistributedTest):
@@ -308,7 +308,6 @@ class TestMPSize(DistributedTest):
self,
model_w_task,
dtype,
enable_cuda_graph,
query,
inf_kwargs,
assert_fn,
@@ -325,14 +324,11 @@ class TestMPSize(DistributedTest):
pipe = pipeline(task, model=model, device=-1, framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=self.world_size,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
enable_cuda_graph=enable_cuda_graph,
)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=self.world_size,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True)
# Switch device to GPU so that input tensors are not on CPU
pipe.device = torch.device(f"cuda:{local_rank}")
ds_output = pipe(query, **inf_kwargs)
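The updated test drops `enable_cuda_graph` from the tensor-parallel path (matching the engine's new assertion) and swaps gpt2 for larger models. A hedged end-to-end sketch of the same call pattern follows; the model choice and launcher command are assumptions, not part of this commit.

```python
# Usage sketch of the call pattern in the updated test: kernel injection with
# tensor parallelism (mp_size) and no enable_cuda_graph, which the engine now
# rejects when mp_size > 1. Launch under the DeepSpeed launcher, e.g.:
#   deepspeed --num_gpus 2 run_inference.py
import os
import torch
import deepspeed
from transformers import pipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))

pipe = pipeline("text-generation",
                model="EleutherAI/gpt-neo-1.3B",  # illustrative model choice
                device=-1,          # build on CPU; DeepSpeed shards onto GPUs
                framework="pt")
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=2,
                                      dtype=torch.float16,
                                      replace_method="auto",
                                      replace_with_kernel_inject=True)
# Switch device to GPU so the pipeline places input tensors there.
pipe.device = torch.device(f"cuda:{local_rank}")
print(pipe("DeepSpeed is", do_sample=False, max_length=30))
```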