From 384f17b05bdf2c73fa6b5da66399d8c848779e89 Mon Sep 17 00:00:00 2001 From: Rahil Bathwal <87332510+rahilbathwal5@users.noreply.github.com> Date: Tue, 13 Dec 2022 18:44:17 -0800 Subject: [PATCH] [deepspeed/autotuner] Bug fix for binary search for batch size (#2162) * bug fix for binary search for batch size * fix binary search termination condition --- deepspeed/autotuning/autotuner.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index 98be83db..971aa257 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -980,11 +980,10 @@ class Autotuner: low = min_micro_batch_size high = max_micro_batch_size - while low < high: + # binary search until low is the smallest micro batch size that OOMs. + while low <= high: mid = int((low + high) // 2) logger.debug(f"trying mbs = {mid}, low = {low}, high = {high}") - if mid == low: - break if mid not in used_micro_batch_sizes: ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mid ds_config[TRAIN_BATCH_SIZE] = mid * gas * \ @@ -992,7 +991,7 @@ class Autotuner: exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mid) exp, metric_val = self.run_ds_config(ds_config, exp_name) if metric_val: - low = mid + low = mid + 1 self.update_records(tuning_space_name, exp, metric_val, 1) used_micro_batch_sizes.append(mid) if prev_metric_val and ((metric_val - prev_metric_val) / @@ -1004,8 +1003,8 @@ class Autotuner: self.update_records(tuning_space_name, exp, 0, 1) high = mid - 1 else: - low = mid - max_micro_batch_size = low + low = mid + 1 + max_micro_batch_size = low - 1 logger.info( f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}." -- GitLab