Unverified commit 1d295ff5 authored by Jeff Rasley, committed by GitHub

Refactor ZeRO naming to reduce confusion (#1607)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 07887f66
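
For downstream code that imports these optimizer classes directly, here is a minimal sketch of the rename, inferred from the diff below (not an official migration guide):

```python
# Before this commit (DeepSpeed < 0.5.9):
#   from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1   # legacy stage 1, removed
#   from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer          # stages 1 and 2
#   from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3   # stage 3

# After this commit:
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer      # ZeRO stages 1 and 2
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3      # ZeRO stage 3
```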
......@@ -38,7 +38,7 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
nv-torch18-v100:
runs-on: [self-hosted, nvidia, torch18, v100]
......@@ -65,7 +65,7 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
nv-transformers-v100:
runs-on: [self-hosted, nvidia, torch18, v100]
......@@ -99,4 +99,4 @@ jobs:
pip install .[testing]
# find reqs used in ds integration tests
find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec pip install -r {} \;
TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --durations=0 --verbose tests/deepspeed
TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed
......@@ -22,8 +22,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
from typing import Callable, Dict, Optional, Union, Iterable
from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import (
is_zero_supported_optimizer,
......@@ -1326,28 +1325,12 @@ class DeepSpeedEngine(Module):
if optimizer is None:
optimizer = DummyOptim(list(self.module.parameters()))
if self.zero_legacy_stage1(
) and zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert not self.has_moe_layers, "MoE not supported with Stage 1"
assert not isinstance(optimizer, DummyOptim), "zero stage 1 requires an optimizer"
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
optimizer,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
all_gather_partitions=self.zero_allgather_partitions(),
allgather_size=self.zero_allgather_bucket_size(),
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
elastic_checkpoint=self.zero_elastic_checkpoint(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_predivide=self.gradient_predivide,
if self.zero_legacy_stage1():
raise Exception(
"The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
)
elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
if zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
overlap_comm = self.zero_overlap_comm()
contiguous_gradients = self.zero_contiguous_gradients()
round_robin_gradients = self.zero_round_robin_gradients()
......@@ -1366,7 +1349,7 @@ class DeepSpeedEngine(Module):
)
overlap_comm = False
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer = DeepSpeedZeroOptimizer(
optimizer,
timers=timers,
static_loss_scale=self.loss_scale(),
......@@ -1399,9 +1382,9 @@ class DeepSpeedEngine(Module):
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
assert not self.has_moe_layers, "MoE not supported with Stage 3"
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
optimizer = DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=timers,
......
This file's diff is collapsed.
......@@ -592,7 +592,7 @@ class PostBackwardFunction(torch.autograd.Function):
INITIAL_MICRO_STEP_ID = -1
class FP16_DeepSpeedZeroOptimizer_Stage3(object):
class DeepSpeedZeroOptimizer_Stage3(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
......
......@@ -70,7 +70,7 @@ def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")
class FP16_DeepSpeedZeroOptimizer(object):
class DeepSpeedZeroOptimizer(object):
"""
DeepSpeedZeroOptimizer designed to reduce the memory footprint
required for training large deep learning models.
......@@ -2135,8 +2135,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
ckpt_version = state_dict_list[0].get("ds_version", False)
error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
"with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
"please set 'legacy_stage1': true in your zero config json. This old version of " \
"stage 1 will be removed in v0.4.0."
"please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json."
assert ckpt_version, f"Empty ds_version! {error_str}"
assert required_version <= pkg_version.parse(ckpt_version), f"Old version: {ckpt_version} {error_str}"
......
......@@ -241,7 +241,7 @@ Example of <i>**scheduler**</i>
**Note:** this mode cannot be combined with the `fp16` mode described above.
{: .notice--warning}
**Note:** this mode is only compatible with ZeRO stage 2.
**Note:** this mode is only compatible with ZeRO stages 1 and 2.
{: .notice--warning}
<i>**bfloat16**</i>: [dictionary]
......
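
As a usage note for the documentation change above, here is a hypothetical config sketch enabling bfloat16 together with ZeRO stage 1 (key names follow the docs above; the batch size and stage values are illustrative assumptions):

```python
# Illustrative DeepSpeed config dict (assumption: passed to deepspeed.initialize
# via your usual config path; adjust values for your setup).
ds_config = {
    "train_batch_size": 8,               # illustrative value
    "bfloat16": {"enabled": True},       # bf16 mode; cannot be combined with fp16 (see note above)
    "zero_optimization": {"stage": 1},   # per this change, bf16 works with ZeRO stages 1 and 2
}
```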
......@@ -3,8 +3,7 @@ import torch
import torch.distributed as dist
import deepspeed
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.utils import groups
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
......@@ -15,7 +14,7 @@ PipeTopo = PipeDataParallelTopology
from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from util import required_torch_version
import argparse
......@@ -60,23 +59,17 @@ def compare_model_states(saved_model,
if not compare_optimizer:
return
if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
if DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
saved_model.optimizer,
FP16_DeepSpeedZeroOptimizer_Stage3):
DeepSpeedZeroOptimizer_Stage3):
for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
elif isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer):
for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
for p0, p1 in zip(partition0, partition1):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
elif isinstance(saved_model.optimizer, FP16_Optimizer):
for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
......@@ -444,8 +437,8 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
hidden_dim,
load_optimizer_states):
if zero_stage == 3:
global FP16_DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
global DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
with deepspeed.zero.Init():
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
else:
......@@ -525,8 +518,8 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
load_optimizer_states,
load_lr_scheduler_states):
if zero_stage == 3:
global FP16_DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
global DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
with deepspeed.zero.Init():
models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
else:
......