From d9a889d55921e1300ca21b26c85d63e4de8e3a74 Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Tue, 22 Aug 2023 09:46:08 -0700
Subject: [PATCH] Fix nv-nightly workflow (#4163)

* Disable nv-nightly workflow since it doesn't work
* Run on PRs to debug
* fix for nv-nightly
* fix
* OOM fix?
* Update nv-nightly.yml

---------

Co-authored-by: Logan Adams
---
 tests/unit/common.py                   |  6 ++----
 tests/unit/inference/test_inference.py | 19 ++++++++++++++++++-
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/tests/unit/common.py b/tests/unit/common.py
index 7214a2a9..3fb33531 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -22,9 +22,6 @@ import pytest
 from _pytest.outcomes import Skipped
 from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
 
-# Worker timeout *after* the first worker has completed.
-DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
-
 # Worker timeout for tests that hang
 DEEPSPEED_TEST_TIMEOUT = 600
 
@@ -114,6 +111,7 @@ class DistributedExec(ABC):
     requires_cuda_env = True
     reuse_dist_env = False
     _pool_cache = {}
+    exec_timeout = DEEPSPEED_TEST_TIMEOUT
 
     @abstractmethod
     def run(self):
@@ -170,7 +168,7 @@ class DistributedExec(ABC):
         skip_msgs_async = pool.starmap_async(self._dist_run, args)
 
         try:
-            skip_msgs = skip_msgs_async.get(DEEPSPEED_TEST_TIMEOUT)
+            skip_msgs = skip_msgs_async.get(self.exec_timeout)
         except mp.TimeoutError:
             # Shortcut to exit pytest in the case of a hanged test. This
             # usually means an environment error and the rest of tests will
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 49fea03e..21407f01 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -521,13 +521,14 @@ class TestAutoTensorParallelism(DistributedTest):
     "model_family, model_name",
     (
         ["gpt2", "EleutherAI/gpt-neo-2.7B"],
-        ["gpt2", "EleutherAI/gpt-j-6b"],
+        #["gpt2", "EleutherAI/gpt-j-6b"],  # Causing OOM for this test
         ["gpt2", "gpt2-xl"],
     ),
 )
 @pytest.mark.parametrize("task", ["lambada_standard"])
 class TestLMCorrectness(DistributedTest):
     world_size = 1
+    exec_timeout = 1200  # Give these tests longer to complete
 
     def test(self, model_family, model_name, task):
         # imports here to avoid import errors when pytest collects tests
@@ -536,6 +537,21 @@ class TestLMCorrectness(DistributedTest):
         import lm_eval.tasks
         import lm_eval.evaluator
 
+        # The bootstrap_stderr function in lm_eval.metrics uses a
+        # multiprocessing Pool to increase performance. Since we use a Pool for
+        # our distributed tests and cannot nest Pools, we must redefine and
+        # patch this function with a version that does not use Pool.
+        def no_pool_bootstrap_stderr(f, xs, iters):
+            from lm_eval.metrics import _bootstrap_internal
+            from lm_eval.metrics import sample_stddev
+            res = []
+            chunk_size = min(1000, iters)
+            for i in range(iters // chunk_size):
+                res.extend(_bootstrap_internal(f, chunk_size)((i, xs)))
+            return sample_stddev(res)
+
+        lm_eval.metrics.bootstrap_stderr = no_pool_bootstrap_stderr
+
         local_rank = os.getenv("LOCAL_RANK", "0")
         device = torch.device(get_accelerator().device_name(local_rank))
         dtype = torch.float
@@ -557,6 +573,7 @@ class TestLMCorrectness(DistributedTest):
         get_accelerator().synchronize()
         bs_time = time.time() - start
 
+        getattr(lm, model_family).to("cpu")
         ds_model = deepspeed.init_inference(
             getattr(lm, model_family),
             mp_size=1,
--
GitLab
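
For context on the common.py change: the new exec_timeout class attribute lets an individual test class raise the worker-pool timeout instead of every test sharing DEEPSPEED_TEST_TIMEOUT, which is how TestLMCorrectness gets its 1200-second budget. Below is a minimal sketch, assuming the patched tests/unit/common.py above and the suite's usual unit.common import path; the class name, test body, and timeout value are hypothetical and are not part of this patch.

# Illustrative only; not part of the patch. Assumes the patched
# tests/unit/common.py above. The class and test names are hypothetical.
import torch

from unit.common import DistributedTest, DEEPSPEED_TEST_TIMEOUT


class TestLongRunningExample(DistributedTest):
    world_size = 1
    # Per-class override: DistributedExec now passes this value to
    # skip_msgs_async.get() instead of the module-level DEEPSPEED_TEST_TIMEOUT.
    exec_timeout = 2 * DEEPSPEED_TEST_TIMEOUT  # 1200 s, like TestLMCorrectness

    def test(self):
        # Any slow body works here; the launcher now waits up to exec_timeout
        # seconds for the worker pool before treating the test as hung.
        x = torch.randn(2048, 2048)
        for _ in range(8):
            x = (x @ x.T) / x.norm()
        assert torch.isfinite(x).all()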