Unverified commit 462def45, authored by Heyang Qin, committed by GitHub

Enable hpz when running with torch.no_grad (#4232)

* enable hpz when running with torch.no_grad

* change the way to detect no_grad

* fix format

---------
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Parent: 6cbf6661
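The substance of the change: `pre_sub_module_forward_function` previously ran under an unconditional `@torch.no_grad()` decorator, so by the time the parameter coordinator was asked to fetch a submodule's partitions, the caller's real autograd state was already masked and `fetch_sub_module` was always told `forward=True`. The fix records the caller's grad state first, disables grad only for the fetch itself, and passes the recorded state down. A minimal sketch of that save/disable/restore pattern (the `fetch_sub_module` callable here is a stand-in for the coordinator, not the actual DeepSpeed API surface):

```python
import torch

def pre_forward_fetch(fetch_sub_module, sub_module):
    # Record whether the surrounding forward pass runs with grad enabled,
    # before we touch the autograd state ourselves.
    prev_grad_state = torch.is_grad_enabled()
    # The fetch itself must not record autograd history, so disable grad
    # for its duration (what the removed decorator used to do globally).
    torch.set_grad_enabled(False)
    try:
        # Hand the caller's grad state down so the coordinator can tell a
        # torch.no_grad() (inference-like) pass from a training pass.
        fetch_sub_module(sub_module, forward=prev_grad_state)
    finally:
        # Restore whatever autograd state the caller had.
        torch.set_grad_enabled(prev_grad_state)
```

The commit itself restores the state with a plain `torch.set_grad_enabled(prev_grad_state)` call at the end of the method; the `try/finally` above is only to make the restore unconditional in the sketch.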
@@ -490,10 +490,11 @@ class DeepSpeedZeRoOffload(object):
         # post backward hook
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
-    @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
-
+        prev_grad_state = torch.is_grad_enabled(
+        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
+        torch.set_grad_enabled(False)
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
@@ -501,8 +502,8 @@ class DeepSpeedZeRoOffload(object):
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module, forward=True)
-
+        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
+        torch.set_grad_enabled(prev_grad_state)
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
 
     @torch.no_grad()
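With the decorator gone, a forward pass wrapped in `torch.no_grad()` is now visible to the parameter coordinator, which is what allows hpZ (ZeRO++ hierarchical partitioning) to engage during evaluation. A hypothetical end-to-end usage under stated assumptions: ZeRO stage 3 with `zero_hpz_partition_size` set and no optimizer (the optimizer-less stage-3 path goes through the `DeepSpeedZeRoOffload` class patched above), launched with the `deepspeed` launcher; the model and sizes are illustrative, not taken from this commit:

```python
import torch
import deepspeed

# Illustrative ZeRO-3 + hpZ config. zero_hpz_partition_size enables ZeRO++
# hierarchical partitioning (secondary parameter shards within a node).
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "zero_hpz_partition_size": 8,
    },
}

model = torch.nn.Linear(1024, 1024)  # stand-in for a real model
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

engine.eval()
with torch.no_grad():
    # Before this commit the decorator masked the no_grad state and every
    # fetch was issued as forward=True; now the coordinator sees that grad
    # is disabled for this pass.
    out = engine(torch.randn(1, 1024, device=engine.device))
```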