diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index ffb0b6a23d7499c4fe69f493a09d5f20f05910f1..911b43bfc9bda6f0ebe6b4c1f51246a11bf7729a 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -490,10 +490,11 @@ class DeepSpeedZeRoOffload(object):
         # post backward hook
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
-    @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
-
+        prev_grad_state = torch.is_grad_enabled(
+        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
+        torch.set_grad_enabled(False)
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
@@ -501,8 +502,8 @@ class DeepSpeedZeRoOffload(object):
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module, forward=True)
-
+        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
+        torch.set_grad_enabled(prev_grad_state)
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
 
     @torch.no_grad()
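
For context, the change replaces the blanket `@torch.no_grad()` decorator with an explicit save/disable/restore of the autograd mode, so that parameter fetching still runs without building a graph while the coordinator can be told (via the `forward` argument) whether the surrounding forward pass actually has grad enabled. Below is a minimal standalone sketch of that pattern, not DeepSpeed's actual code; `fetch_fn` is a hypothetical stand-in for `param_coordinator.fetch_sub_module`.

```python
import torch

def fetch_with_grad_disabled(fetch_fn, sub_module):
    # Remember the caller's autograd mode so it can be restored afterwards
    # and passed through to the fetch routine, which needs to know it.
    prev_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)  # parameter fetching should not build a graph
    try:
        fetch_fn(sub_module, forward=prev_grad_state)
    finally:
        torch.set_grad_enabled(prev_grad_state)  # restore the caller's mode
```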