Unverified · Commit b2a724e2 authored by lekurile, committed by GitHub

Add TestInjectionPolicy inference unittest class for testing custom injection policies (#2426)

This PR adds a TestInjectionPolicy inference unittest class for testing custom injection policies.

This test differs from the existing tests in that the injection_policy dictionary is explicitly specified when calling the DeepSpeed init_inference API.

The google/t5-v1_1-small text2text-generation model and the roberta-large fill-mask model are added as test cases with the injection policy explicitly specified.

This is done to expand our unit test coverage and exercise the path where the replace_wo_policy function is invoked (see GH-2387).
Co-authored-by: Lev Kurilenko <lekurile@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent 1b7c6791
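For context, the call pattern this PR exercises looks roughly like the sketch below. It is a minimal example based on the T5 test case added in this diff (not part of the change itself) and assumes transformers, DeepSpeed, and a CUDA device are available.

```python
# Minimal sketch of passing an explicit injection_policy to deepspeed.init_inference,
# mirroring the google/t5-v1_1-small case added in this PR (not part of the diff).
import torch
import deepspeed
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block

# Load the model on CPU first; init_inference handles the move/partitioning.
pipe = pipeline("text2text-generation",
                model="google/t5-v1_1-small",
                device=-1,
                framework="pt")
pipe.model = deepspeed.init_inference(
    pipe.model,
    mp_size=1,  # tensor-parallel degree; the test below also runs with 2
    dtype=torch.float,
    # Map T5Block to the output projections of its attention/MLP submodules,
    # which routes injection through the generic replace_wo_policy path.
    injection_policy={
        T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')
    },
)
pipe.device = torch.device("cuda:0")  # ensure inputs are placed on the GPU
print(pipe("Is this review positive or negative? "
           "Review: this is the best cast iron skillet you will ever buy"))
```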
google
lm-eval>=0.2.0
protobuf
transformers
transformers[sentencepiece]
@@ -2,10 +2,14 @@
import sys
import pytest
import os
from os.path import abspath, dirname, join
import torch
import warnings
# Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
......
@@ -9,6 +9,8 @@ from unit.common import DistributedTest
from packaging import version as pkg_version
from deepspeed.ops.op_builder import OpBuilder
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.roberta.modeling_roberta import RobertaLayer
from huggingface_hub import HfApi
rocm_version = OpBuilder.installed_rocm_version()
@@ -55,6 +57,7 @@ test_tasks = [
"text-classification",
"token-classification",
"text-generation",
"text2text-generation",
]
pytest.all_models = {
task: [m.modelId for m in _all_models if m.pipeline_tag == task]
@@ -150,6 +153,8 @@ def query(model_w_task):
return "My name is jean-baptiste and I live in montreal."
elif task == "text-generation":
return "DeepSpeed is the greatest"
elif task == "text2text-generation":
return "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
else:
raise NotImplementedError(f'query for task "{task}" is not implemented')
@@ -187,6 +192,11 @@ def text_generation_assert(x, y):
for res in y)
def text2text_generation_assert(x, y):
return set(res["generated_text"] for res in x) == set(res["generated_text"]
for res in y)
@pytest.fixture
def assert_fn(model_w_task):
model, task = model_w_task
@@ -196,6 +206,7 @@ def assert_fn(model_w_task):
"text-classification": text_classification_assert,
"token-classification": token_classification_assert,
"text-generation": text_generation_assert,
"text2text-generation": text2text_generation_assert,
}
assert_fn = assert_fn_dict.get(task, None)
if assert_fn is None:
@@ -323,6 +334,67 @@ class TestMPSize(DistributedTest):
assert assert_fn(bs_output, ds_output)
@pytest.mark.seq_inference
@pytest.mark.parametrize(
"model_w_task, injection_policy",
[
(("google/t5-v1_1-small",
"text2text-generation"),
{
T5Block: ('SelfAttention.o',
'EncDecAttention.o',
'DenseReluDense.wo')
}),
(("roberta-large",
"fill-mask"),
{
RobertaLayer: ('output.dense')
}),
],
ids=["t5",
"roberta"],
)
@pytest.mark.parametrize("dtype", [torch.float], ids=["fp32"])
@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"])
class TestInjectionPolicy(DistributedTest):
world_size = [1, 2]
def test(
self,
model_w_task,
injection_policy,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "2"))
# We have to load these large models on CPU with pipeline because not
# enough GPU memory
pipe = pipeline(task, model=model, device=-1, framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
injection_policy=injection_policy)
# Switch device to GPU so that input tensors are not on CPU
pipe.device = torch.device(f"cuda:{local_rank}")
ds_output = pipe(query, **inf_kwargs)
print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
@pytest.mark.nightly
@pytest.mark.parametrize(
"model_family, model_name",
......
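For completeness, one way to run only the new tests locally might look like the sketch below; the file path is an assumption based on DeepSpeed's test layout, and the DistributedTest harness is expected to spawn the per-rank processes for the parametrized world sizes.

```python
# Hedged example: invoke only the TestInjectionPolicy cases via pytest's Python API.
# The path below is assumed, not confirmed by this diff.
import pytest

pytest.main([
    "tests/unit/inference/test_inference.py",
    "-k", "TestInjectionPolicy",
    "-q",
])
```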