From 6b6d40902ed12f105795fca9d954638f2c7348f8 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Thu, 6 Jul 2023 14:37:54 +0800 Subject: [PATCH] update recompute tuning unittest support other devices (#55165) --- test/auto_parallel/test_tuning_recompute.py | 27 +++++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/test/auto_parallel/test_tuning_recompute.py b/test/auto_parallel/test_tuning_recompute.py index 7cbda9a75bc..ef9a16a2cae 100644 --- a/test/auto_parallel/test_tuning_recompute.py +++ b/test/auto_parallel/test_tuning_recompute.py @@ -36,7 +36,7 @@ def generate_model(): gpt = GPTModel( vocab_size=50304, hidden_size=1024, - num_hidden_layers=14, + num_hidden_layers=13, num_attention_heads=16, intermediate_size=1024 * 4, hidden_act="gelu", @@ -95,14 +95,25 @@ class TestRecomputePassTuning(unittest.TestCase): engine = auto.Engine(model, loss, opt, strategy=strategy) engine._tune(self.dataset, 3, batch_size=self.batch_size) - assert ( - len( - engine._dist_contexts[ - 'train' - ].strategy.recompute.no_recompute_segments - ) - > 0 + gpu_memory_size = round( + paddle.device.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024 ) + dist_strategy = engine._dist_contexts['train'].strategy + if gpu_memory_size in [16, 32]: + self.assertGreater( + len(dist_strategy.recompute.no_recompute_segments), + 0, + "When GPU memory size is 16G or 32G, the length of no_recompute_segments should be greater than 0.", + ) + elif gpu_memory_size >= 40: + self.assertEqual( + dist_strategy.recompute.enable, + False, + "When GPU memory size is greater than 40GB, the recompute strategy should be disable.", + ) if __name__ == "__main__": -- GitLab