diff --git a/tests/unit/common.py b/tests/unit/common.py
index 7214a2a9d8fcf18bef1d660112864df79fc048d4..3fb335318fde702ec045b8bb367ef8e760f30caa 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -22,9 +22,6 @@ import pytest
 from _pytest.outcomes import Skipped
 from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
 
-# Worker timeout *after* the first worker has completed.
-DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
-
 # Worker timeout for tests that hang
 DEEPSPEED_TEST_TIMEOUT = 600
 
@@ -114,6 +111,7 @@ class DistributedExec(ABC):
     requires_cuda_env = True
     reuse_dist_env = False
     _pool_cache = {}
+    exec_timeout = DEEPSPEED_TEST_TIMEOUT
 
     @abstractmethod
     def run(self):
@@ -170,7 +168,7 @@ class DistributedExec(ABC):
         skip_msgs_async = pool.starmap_async(self._dist_run, args)
 
         try:
-            skip_msgs = skip_msgs_async.get(DEEPSPEED_TEST_TIMEOUT)
+            skip_msgs = skip_msgs_async.get(self.exec_timeout)
         except mp.TimeoutError:
             # Shortcut to exit pytest in the case of a hanged test. This
             # usually means an environment error and the rest of tests will
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 49fea03e11f60c8bef0ad999d657cd3d9c2c8702..21407f0163ff74ede80a880d94affc5d7c66a654 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -521,13 +521,14 @@ class TestAutoTensorParallelism(DistributedTest):
     "model_family, model_name",
     (
         ["gpt2", "EleutherAI/gpt-neo-2.7B"],
-        ["gpt2", "EleutherAI/gpt-j-6b"],
+        #["gpt2", "EleutherAI/gpt-j-6b"],  # Causing OOM for this test
         ["gpt2", "gpt2-xl"],
     ),
 )
 @pytest.mark.parametrize("task", ["lambada_standard"])
 class TestLMCorrectness(DistributedTest):
     world_size = 1
+    exec_timeout = 1200  # Give these tests longer to complete
 
     def test(self, model_family, model_name, task):
         # imports here to avoid import errors when pytest collects tests
@@ -536,6 +537,21 @@ class TestLMCorrectness(DistributedTest):
         import lm_eval.tasks
         import lm_eval.evaluator
 
+        # The bootstrap_stderr function in lm_eval.metrics uses a
+        # multiprocessing Pool to increase performance. Since we use a Pool for
+        # our distributed tests and cannot nest Pools, we must redefine and
+        # patch this function with a version that does not use Pool.
+        def no_pool_bootstrap_stderr(f, xs, iters):
+            from lm_eval.metrics import _bootstrap_internal
+            from lm_eval.metrics import sample_stddev
+            res = []
+            chunk_size = min(1000, iters)
+            for i in range(iters // chunk_size):
+                res.extend(_bootstrap_internal(f, chunk_size)((i, xs)))
+            return sample_stddev(res)
+
+        lm_eval.metrics.bootstrap_stderr = no_pool_bootstrap_stderr
+
         local_rank = os.getenv("LOCAL_RANK", "0")
         device = torch.device(get_accelerator().device_name(local_rank))
         dtype = torch.float
@@ -557,6 +573,7 @@ class TestLMCorrectness(DistributedTest):
         get_accelerator().synchronize()
         bs_time = time.time() - start
 
+        getattr(lm, model_family).to("cpu")
         ds_model = deepspeed.init_inference(
             getattr(lm, model_family),
             mp_size=1,
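
Note (not part of the diff above): a minimal sketch of how a test class would opt into the new per-class timeout introduced in tests/unit/common.py. The class name, test body, and the `unit.common` import path are assumptions for illustration, not code from this change.

    # Sketch only: class/test names are hypothetical; the import path is assumed
    # to match how other test files import DistributedTest.
    from unit.common import DistributedTest


    class TestVeryLongInference(DistributedTest):
        world_size = 1
        # Override the new class attribute to allow 30 minutes for this test.
        # Classes that do not set it inherit exec_timeout = DEEPSPEED_TEST_TIMEOUT (600 s),
        # since the worker pool now reads self.exec_timeout instead of the global constant.
        exec_timeout = 1800

        def test(self):
            pass  # long-running test body goes here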