Unverified · commit cf638be9 authored by Jeff Rasley, committed by GitHub

only override forward if using cuda-graph (#2291)

Parent: 95d11517
@@ -40,8 +40,6 @@ jobs:
         run: |
           git clone https://github.com/huggingface/transformers
           cd transformers
-          # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout v4.21.2
           git rev-parse --short HEAD
           pip uninstall --yes transformers
           pip install .
...
@@ -162,10 +162,7 @@ class InferenceEngine(Module):
             torch.cuda.set_rng_state(_rng_state.cpu())
 
         if self.mp_world_size > 1:
-            self.model_orig_fwd = self.module.forward
-            self.module.forward = self.forward
-        else:
-            self.module.register_forward_pre_hook(self._pre_forward_hook)
+            assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
 
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
@@ -475,14 +472,6 @@ class InferenceEngine(Module):
         elif self.dtype == torch.float:
             self.module.float()
 
-    def _pre_forward_hook(self, module, *inputs, **kwargs):
-        for input in inputs:
-            if torch.is_tensor(input):
-                input = input.to(torch.cuda.current_device())
-        for k in kwargs:
-            if torch.is_tensor(kwargs[k]):
-                kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-
     def _create_cuda_graph(self, *inputs, **kwargs):
         # warmup to create the workspace and cublas handle
         cuda_stream = torch.cuda.Stream()
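For reference, _create_cuda_graph (whose first lines appear as context above) follows the standard PyTorch capture pattern: warm the model up on a side stream so workspaces and the cuBLAS handle are allocated, then record one forward pass into a torch.cuda.CUDAGraph and replay it against static buffers. The following is a generic, self-contained sketch of that pattern, not the DeepSpeed implementation itself:

import torch

def capture_and_replay(module, static_input):
    """Generic CUDA-graph capture/replay sketch (illustrative, not DeepSpeed code)."""
    # Warm-up on a side stream so lazily created workspaces / cuBLAS handles
    # exist before capture begins.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(3):
            module(static_input)
    torch.cuda.current_stream().wait_stream(stream)

    # Record a single forward pass; its inputs/outputs become static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = module(static_input)

    def replay(new_input):
        # Copy fresh data into the captured input buffer, then replay the graph.
        static_input.copy_(new_input)
        graph.replay()
        return static_output

    return replay

Because replay reuses the captured input and output storage, new data must be copied into the static input before each replay, which is why the engine only takes this path when enable_cuda_graph is set.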
@@ -519,23 +508,6 @@ class InferenceEngine(Module):
             *inputs: Variable length input list
             **kwargs: variable length keyword arguments
         """
-        if self.mp_world_size > 1:
-            if self.mpu is None:
-                for input in inputs:
-                    if torch.is_tensor(input):
-                        input = input.to(torch.cuda.current_device())
-                        if not input.is_contiguous():
-                            input = input.contiguous()
-                        dist.broadcast(input, 0)
-                for k in kwargs:
-                    if torch.is_tensor(kwargs[k]):
-                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-                        if not kwargs[k].is_contiguous():
-                            kwargs[k] = kwargs[k].contiguous()
-                        dist.broadcast(kwargs[k], 0)
-            outputs = self.model_orig_fwd(*inputs, **kwargs)
-        else:
-            if self.enable_cuda_graph:
+        if self.enable_cuda_graph:
             if self.cuda_graph_created:
                 outputs = self._graph_replay(*inputs, **kwargs)
@@ -544,5 +516,5 @@ class InferenceEngine(Module):
                 outputs = self._graph_replay(*inputs, **kwargs)
         else:
             outputs = self.module(*inputs, **kwargs)
-        #outputs = self.module(*inputs, **kwargs)
         return outputs
...
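Taken together, the engine changes mean forward is only overridden for the CUDA-graph path; otherwise the wrapped module is called directly, and requesting CUDA graphs together with model parallelism now fails fast on the new assertion. A minimal usage sketch of the resulting behavior, assuming a single GPU and a Hugging Face causal LM (the model name and prompt below are illustrative, not taken from the commit):

import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")      # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# enable_cuda_graph=True is now the only case in which the engine wraps the
# module's forward; with it off, engine(...) dispatches straight to the module.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_method="auto",
                                  replace_with_kernel_inject=True,
                                  enable_cuda_graph=True)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")
with torch.no_grad():
    # The first call captures the graph; later calls replay it.
    logits = engine(**inputs).logits

# mp_size > 1 combined with enable_cuda_graph=True now raises:
#   AssertionError: Cuda graph is not supported for model parallelism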
@@ -292,13 +292,13 @@ class TestModelTask(DistributedTest):
 
 @pytest.mark.seq_inference
 @pytest.mark.parametrize("model_w_task",
-                         [("gpt2",
+                         [("EleutherAI/gpt-neo-1.3B",
                            "text-generation"),
                           ("EleutherAI/gpt-neox-20b",
                            "text-generation"),
                           ("bigscience/bloom-3b",
                            "text-generation")],
-                         ids=["gpt2",
+                         ids=["gpt-neo",
                               "gpt-neox",
                               "bloom"])
 class TestMPSize(DistributedTest):
@@ -308,7 +308,6 @@ class TestMPSize(DistributedTest):
         self,
         model_w_task,
         dtype,
-        enable_cuda_graph,
         query,
         inf_kwargs,
         assert_fn,
@@ -325,14 +324,11 @@ class TestMPSize(DistributedTest):
         pipe = pipeline(task, model=model, device=-1, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(
-            pipe.model,
-            mp_size=self.world_size,
-            dtype=dtype,
-            replace_method="auto",
-            replace_with_kernel_inject=True,
-            enable_cuda_graph=enable_cuda_graph,
-        )
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=self.world_size,
+                                              dtype=dtype,
+                                              replace_method="auto",
+                                              replace_with_kernel_inject=True)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)
...