diff --git a/test/auto_parallel/test_tuning_recompute.py b/test/auto_parallel/test_tuning_recompute.py index 7cbda9a75bc0f97be1c20ddebed63c643d31ba6f..ef9a16a2cae72cbd64c92e46f1f9acf8c51b0a7e 100644 --- a/test/auto_parallel/test_tuning_recompute.py +++ b/test/auto_parallel/test_tuning_recompute.py @@ -36,7 +36,7 @@ def generate_model(): gpt = GPTModel( vocab_size=50304, hidden_size=1024, - num_hidden_layers=14, + num_hidden_layers=13, num_attention_heads=16, intermediate_size=1024 * 4, hidden_act="gelu", @@ -95,14 +95,25 @@ class TestRecomputePassTuning(unittest.TestCase): engine = auto.Engine(model, loss, opt, strategy=strategy) engine._tune(self.dataset, 3, batch_size=self.batch_size) - assert ( - len( - engine._dist_contexts[ - 'train' - ].strategy.recompute.no_recompute_segments - ) - > 0 + gpu_memory_size = round( + paddle.device.cuda.get_device_properties(0).total_memory + / 1024 + / 1024 + / 1024 ) + dist_strategy = engine._dist_contexts['train'].strategy + if gpu_memory_size in [16, 32]: + self.assertGreater( + len(dist_strategy.recompute.no_recompute_segments), + 0, + "When GPU memory size is 16G or 32G, the length of no_recompute_segments should be greater than 0.", + ) + elif gpu_memory_size >= 40: + self.assertEqual( + dist_strategy.recompute.enable, + False, + "When GPU memory size is greater than 40GB, the recompute strategy should be disable.", + ) if __name__ == "__main__":