未验证 提交 d0dbc95a 编写于 作者: 郭叶军 提交者: GitHub

call empty_cache to really free up GPU memory as described in comment (#2620)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 b08cf416
......@@ -814,6 +814,10 @@ def get_ma_status():
return torch.cuda.memory_allocated()
def empty_cache():
    """Release all memory the CUDA caching allocator holds but is not using
    back to the device, so external tools report the true free GPU memory.

    No-op when CUDA is unavailable or has not been initialized.
    """
    release = torch.cuda.empty_cache
    release()
def see_memory_usage(message, force=False):
if not force:
return
......
......@@ -13,6 +13,7 @@ from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank,
get_global_norm,
empty_cache,
see_memory_usage,
is_model_parallel_parameter,
align_dense_tensors,
......@@ -295,6 +296,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.bit16_groups[i])
empty_cache()
see_memory_usage(f"After moving param group {i} to CPU", force=False)
# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册