未验证 提交 384f17b0 编写于 作者: Rahil Bathwal 提交者: GitHub

[deepspeed/autotuner] Bug fix for binary search for batch size (#2162)

* bug fix for binary search for batch size

* fix binary search termination condition
上级 3a3dfe66
......@@ -980,11 +980,10 @@ class Autotuner:
low = min_micro_batch_size
high = max_micro_batch_size
-while low < high:
+# binary search until low is the smallest micro batch size that OOMs.
+while low <= high:
mid = int((low + high) // 2)
logger.debug(f"trying mbs = {mid}, low = {low}, high = {high}")
-if mid == low:
-break
if mid not in used_micro_batch_sizes:
ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mid
ds_config[TRAIN_BATCH_SIZE] = mid * gas * \
......@@ -992,7 +991,7 @@ class Autotuner:
exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(mid)
exp, metric_val = self.run_ds_config(ds_config, exp_name)
if metric_val:
-low = mid
+low = mid + 1
self.update_records(tuning_space_name, exp, metric_val, 1)
used_micro_batch_sizes.append(mid)
if prev_metric_val and ((metric_val - prev_metric_val) /
......@@ -1004,8 +1003,8 @@ class Autotuner:
self.update_records(tuning_space_name, exp, 0, 1)
high = mid - 1
else:
-low = mid
-max_micro_batch_size = low
+low = mid + 1
+max_micro_batch_size = low - 1
logger.info(
f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册