未验证 提交 4b6d7c15 编写于 作者: M Michael Wyatt 提交者: GitHub

Disable Stage 1&2 CPUAdam pathways (#3097)

* disable CPUAdam pathways in optimizer copy/step

* Update stage_1_and_2.py

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 9726bd46
......@@ -1622,11 +1622,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
def _optimizer_step(self, group_no):
original_param_groups = self.optimizer.param_groups
self.optimizer.param_groups = [original_param_groups[group_no]]
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
else:
self.optimizer.step()
# Disabling this as the C++ side copy & synchronize is not working correctly
#from deepspeed.ops.adam import DeepSpeedCPUAdam
#if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
# self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
#else:
# self.optimizer.step()
self.optimizer.step()
self.optimizer.param_groups = original_param_groups
def step(self, closure=None):
......@@ -1677,11 +1679,15 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.start_timers([OPTIMIZER_STEP])
self._optimizer_step(i)
from deepspeed.ops.adam import DeepSpeedCPUAdam
if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
# Disabled, this is not currently working
#from deepspeed.ops.adam import DeepSpeedCPUAdam
#if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
# bit16_partitions = self.parallel_partitioned_bit16_groups[i]
# fp32_partition = self.single_partition_of_fp32_groups[i]
# bit16_partitions[partition_id].data.copy_(fp32_partition.data)
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
self.stop_timers([OPTIMIZER_STEP])
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册