Unverified commit 1d295ff5 authored by Jeff Rasley, committed by GitHub

Refactor ZeRO naming to reduce confusion (#1607)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent commit: 07887f66
@@ -38,7 +38,7 @@ jobs:
      run: |
        if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
        cd tests
-        TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose unit/
+        TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
  nv-torch18-v100:
    runs-on: [self-hosted, nvidia, torch18, v100]
@@ -65,7 +65,7 @@ jobs:
        unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
        if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
        cd tests
-        TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose unit/
+        TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
  nv-transformers-v100:
    runs-on: [self-hosted, nvidia, torch18, v100]
@@ -99,4 +99,4 @@ jobs:
        pip install .[testing]
        # find reqs used in ds integration tests
        find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec pip install -r {} \;
-        TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --durations=0 --verbose tests/deepspeed
+        TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed
@@ -22,8 +22,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
 from typing import Callable, Dict, Optional, Union, Iterable
 from deepspeed.runtime.utils import see_memory_usage, get_ma_status, DummyOptim
-from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
-from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import (
     is_zero_supported_optimizer,
@@ -1326,28 +1325,12 @@ class DeepSpeedEngine(Module):
         if optimizer is None:
             optimizer = DummyOptim(list(self.module.parameters()))
-        if self.zero_legacy_stage1(
-        ) and zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
-            assert not self.has_moe_layers, "MoE not supported with Stage 1"
-            assert not isinstance(optimizer, DummyOptim), "zero stage 1 requires an optimizer"
-            optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
-                optimizer,
-                static_loss_scale=self.loss_scale(),
-                dynamic_loss_scale=self.dynamic_loss_scale(),
-                dynamic_loss_args=self.dynamic_loss_scale_args(),
-                clip_grad=self.gradient_clipping(),
-                all_gather_partitions=self.zero_allgather_partitions(),
-                allgather_size=self.zero_allgather_bucket_size(),
-                max_elements_per_comm=self.zero_reduce_bucket_size(),
-                dp_process_group=self.data_parallel_group,
-                elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu,
-                postscale_gradients=self.postscale_gradients(),
-                gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_predivide=self.gradient_predivide,
-            )
-        elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
+        if self.zero_legacy_stage1():
+            raise Exception(
+                "The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
+            )
+        if zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
             overlap_comm = self.zero_overlap_comm()
             contiguous_gradients = self.zero_contiguous_gradients()
             round_robin_gradients = self.zero_round_robin_gradients()
@@ -1366,7 +1349,7 @@ class DeepSpeedEngine(Module):
             )
             overlap_comm = False
-            optimizer = FP16_DeepSpeedZeroOptimizer(
+            optimizer = DeepSpeedZeroOptimizer(
                 optimizer,
                 timers=timers,
                 static_loss_scale=self.loss_scale(),
@@ -1399,9 +1382,9 @@ class DeepSpeedEngine(Module):
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             assert not self.has_moe_layers, "MoE not supported with Stage 3"
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
-            from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-            optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
+            from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
+            optimizer = DeepSpeedZeroOptimizer_Stage3(
                 self.module,
                 optimizer,
                 timers=timers,
...
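Since this is purely a renaming refactor, downstream code that imported the old optimizer classes only needs its imports and isinstance checks updated; a minimal sketch (hypothetical user code, not part of this commit):

# Old names (deepspeed < 0.5.9):
#   from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
#   from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
# New names introduced by this refactor:
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

def uses_partitioned_optimizer(engine):
    # engine: a deepspeed.DeepSpeedEngine; True when a ZeRO stage 1/2/3 optimizer is active
    return isinstance(engine.optimizer,
                      (DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer_Stage3))

The listing that follows is the legacy FP16_DeepSpeedZeroOptimizer_Stage1 implementation, the deprecated stage-1 code path that the engine change above now rejects.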
import math
import torch
import torch.distributed as dist
from collections import defaultdict
from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
from deepspeed.utils import logger, log_dist
from deepspeed.ops.op_builder import UtilsBuilder
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
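# Number of padding elements needed so this sub-partition is fully backed by real data:
# zero if the sub-partition ends inside the flattened group, otherwise the overhang past
# the end of the group (capped at the sub-partition size).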
sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
if sub_partition_high_limit <= flattened_lean_size:
return 0
else:
return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
group_paddings = []
flattened_size = sum([tensor.numel() for tensor in tensor_list])
for i in range(sub_partition_count):
padding = get_alignment_padding(flattened_size, i, sub_partition_size)
group_paddings.append(padding)
return group_paddings
def _single_range_check(current_index, start_index, end_index, tensor_size):
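# Check whether the tensor starting at current_index overlaps the half-open range
# [start_index, end_index); returns (overlaps, offset) where offset is how far into
# the tensor the range begins (0 if the tensor starts inside the range).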
offset = 0
if (current_index >= start_index) and (current_index < end_index):
# Fully inside bounds
return True, offset
elif (start_index > current_index) and (start_index < (current_index + tensor_size)):
# Partially contained, compute offset
offset = start_index - current_index
return True, offset
else:
return False, offset
def _range_check(current_index, element_intervals, tensor_size):
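# Check the tensor starting at current_index against each of a rank's sub-partition
# (start, end) ranges (one per communication interval); returns a list of
# (contained, offset, comm_idx) matches, or [(False, 0, -1)] if there are none.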
results = []
for comm_idx, interval in enumerate(element_intervals):
start_index, end_index = interval
contained, offset = _single_range_check(current_index, start_index, end_index, tensor_size)
if contained:
results.append((contained, offset, comm_idx))
if len(results) == 0:
return [(False, 0, -1)]
return results
class FP16_DeepSpeedZeroOptimizer_Stage1(object):
"""
FP16_DeepSpeedZeroOptimizer_Stage1 is designed to reduce the memory footprint
required for training large deep learning models.
For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models,
https://arxiv.org/abs/1910.02054
This version aligns with stage-1 in the paper above.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
dp_process_group=None,
partition_size=None,
mpu=None,
all_gather_partitions=True,
allgather_size=500000000,
clip_grad=0.0,
max_elements_per_comm=5e8,
elastic_checkpoint=True,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_average=True):
# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
if dp_process_group is None:
dp_process_group = _initialize_parameter_parallel_groups(partition_size)
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.verbose = verbose
self.dp_process_group = dp_process_group
self.postscale_gradients = postscale_gradients
self.gradient_predivide_factor = gradient_predivide_factor
self.gradient_average = gradient_average
self._global_grad_norm = 0.
# TODO: automatically turn off if #params > some_limit
self.all_gather_partitions = all_gather_partitions
self.allgather_size = allgather_size
# self.max_elements_per_comm = max_elements_per_comm
# logger.info("max_elements_per_comm={}".format(max_elements_per_comm))
self.elastic_checkpoint = elastic_checkpoint
logger.info(f'ZeRO Elastic Checkpoint = {elastic_checkpoint}')
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
# Setup bookkeeping data structures depending on partitioning type
# parallel_sub_partitioned_fp16_groups[group-idx] -> [comm-ids] -> [rank-ids]
self.parallel_sub_partitioned_fp16_groups = []
# same underlying data as above but viewed as: [groups] -> [rank-ids] -> [comm-ids]
self.parallel_comm_sub_partitioned_fp16_groups = []
# 32-bit sub-partitions of the parallel partitioned parameters
# that this process will update
self.local_sub_partitions_of_fp32_groups = []
# param partition info
# parameters in each group that will not be updated by this process directly
self.params_not_local = []
# parameters that will be updated by this process directly
self.params_in_rank_sub_partitions = []
# parameter offsets for parameters in sub-partitions. Parameter
# boundaries may not align with sub-partition boundaries
# so we need to keep track of the offsets
self.params_in_rank_sub_partitions_offsets = []
# number of elements per sub-partition in each group
self.sub_partition_sizes = []
# number of communication intervals for each group
self.num_comm_intervals_per_group = []
local_rank = dist.get_rank(group=self.dp_process_group)
self.group_paddings = []
self.partition_count = dist.get_world_size(group=self.dp_process_group)
self.default_device = self.optimizer.param_groups[0]['params'][0].device
# max elems per param group
self.max_elems_per_comm = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
self.fp16_groups.append(param_group['params'])
# calculate the best max elements per comm so as to minimize padding
self.max_elems_per_comm.append(
self.best_max_elems_per_comm(
num_elements=sum(t.numel() for t in self.fp16_groups[i]),
max_elements_per_comm=max_elements_per_comm,
dp=dist.get_world_size(group=self.dp_process_group)))
# flatten all tensors into a single 1d tensor aligned with the sub-partition size for later dividing
# RS: create aligned sub-partitions
flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=self.fp16_groups[i],
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elems_per_comm[i],
pg=self.dp_process_group)
self.fp16_groups_flat.append(flat_aligned_params)
# TODO: I don't think this does anything?
# set model fp16 weight to slices of flattened buffer
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# divide the flat weights into near-equal partitions according to the data parallel degree
# each process will compute on a different part of the partition
# RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=self.fp16_groups_flat[i],
max_elements_per_comm=self.max_elems_per_comm[i],
world_size=dist.get_world_size(
group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
self.parallel_comm_sub_partitioned_fp16_groups.append(
comm_partitions) # comm -> rank
self.parallel_sub_partitioned_fp16_groups.append(
dp_sub_partitions) # rank -> comm
self.sub_partition_sizes.append(sub_partition_size)
self.num_comm_intervals_per_group.append(num_comm_intervals)
# data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i])
# self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
# a partition of the fp32 master weights that will be updated by this process
# RS: store/detach/cast our local sub-partitions
local_sub_partitions = []
for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][
local_rank]:
fp32_sub_partition = sub_partition.clone().float().detach()
fp32_sub_partition.requires_grad = True
local_sub_partitions.append(fp32_sub_partition)
self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
# Compute sub_partition paddings
sub_partition_paddings = get_group_alignment_padding(
tensor_list=self.fp16_groups[i],
sub_partition_size=sub_partition_size,
sub_partition_count=num_comm_intervals * self.partition_count)
self.group_paddings.append(sub_partition_paddings)
# modify the optimizer to have flat master weights
# self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
# RS: divide up the sub-partitions and keep track of offsets for each param
# partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
tensor_list=self.fp16_groups[i],
all_element_intervals=element_intervals,
local_rank=local_rank,
world_size=dist.get_world_size(group=self.dp_process_group)
)
self.params_in_rank_sub_partitions.append(params_in_rank_sub_partition)
self.params_not_local.append(params_not_local)
self.params_in_rank_sub_partitions_offsets.append(
params_in_rank_sub_partitions_offsets)
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
else:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
self.dynamic_loss_scale = True
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=static_loss_scale)
self.cur_iter = 0
self.mpu = mpu
self.clip_grad = clip_grad
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
zero_reduce_scatter=True)
self._initialize_optimizer_states()
def _initialize_optimizer_states(self):
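# Attach zero-filled gradients to every local fp32 sub-partition and run one optimizer
# step so that per-parameter optimizer state (e.g. Adam moments) is allocated up front,
# then drop the temporary gradients.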
for group_idx, group in enumerate(self.local_sub_partitions_of_fp32_groups):
for idx, sub_partition_param in enumerate(group):
sub_partition_grad = torch.zeros(int(
self.sub_partition_sizes[group_idx]),
dtype=sub_partition_param.dtype).cuda()
sub_partition_param.grad = sub_partition_grad
self.optimizer.step()
for group in self.local_sub_partitions_of_fp32_groups:
for idx, sub_partition_param in enumerate(group):
sub_partition_param.grad = None
@staticmethod
def best_max_elems_per_comm(num_elements, max_elements_per_comm, dp):
# if we use max-elems-per-comm as is, how many comm intervals will there be
max_comm_intervals = math.ceil(num_elements / max_elements_per_comm)
padding_for_max_comm = (max_elements_per_comm *
max_comm_intervals) - num_elements
# if we use 1 less comm interval how much extra comm padding would be required
min_comm_intervals = num_elements // max_elements_per_comm
if min_comm_intervals == 0:
log_dist(f'Using default max_elements_per_comm {max_elements_per_comm}',
ranks=[0])
return max_elements_per_comm
padding_for_min_comm = math.ceil(num_elements / (dp * min_comm_intervals))
# choose padding that uses least amount of overhead
if padding_for_max_comm > padding_for_min_comm:
new_max_elements_per_comm = padding_for_min_comm + max_elements_per_comm
log_dist(
f'Updating max_elements_per_comm from {max_elements_per_comm} -> {new_max_elements_per_comm}',
ranks=[0])
return new_max_elements_per_comm
else:
log_dist(f'Using default max_elements_per_comm {max_elements_per_comm}',
ranks=[0])
return max_elements_per_comm
@staticmethod
def get_data_parallel_sub_partitions(tensor,
max_elements_per_comm,
world_size,
dp_process_group=None):
total_num_elements = tensor.numel()
# if total elements is less than our max, revert to splitting into dp partitions
max_elements_per_comm = min(total_num_elements, max_elements_per_comm)
sub_partition_size = int(max_elements_per_comm // world_size)
# Ensure partition alignment was done correctly
num_sub_partitions = int(total_num_elements // sub_partition_size)
assert total_num_elements % sub_partition_size == 0, "{} % {} != 0".format(total_num_elements, sub_partition_size)
# Ensure comm interval alignment was done correctly.
num_comm_intervals = int(num_sub_partitions // world_size)
assert num_sub_partitions % world_size == 0, "{} % {} != 0".format(num_sub_partitions, world_size)
if not dist.is_initialized() or dist.get_rank(group=dp_process_group) == 0:
logger.info("**** partition info:")
logger.info("\t total_num_elements=%s", total_num_elements)
logger.info("\t world_size=%s", world_size)
logger.info("\t max_elements_per_comm=%s", max_elements_per_comm)
logger.info("\t sub_partition_size=%s", sub_partition_size)
logger.info("\t num_sub_partitions=%s", num_sub_partitions)
logger.info("\t num_comm_intervals=%s", num_comm_intervals)
logger.info("****")
# [comm_id] -> [rank]
comm_partitions = []
for _ in range(num_comm_intervals):
comm_partitions.append([])
start = 0
comm_id = 0
element_intervals = defaultdict(
list) # [rank] -> [(start,end), (start,end), ...]
for idx in range(num_sub_partitions):
rank_id = idx % world_size
sub_partition = tensor.narrow(0, start, sub_partition_size).detach()
element_intervals[rank_id].append((start, start + sub_partition_size))
comm_partitions[comm_id].append(sub_partition)
start = start + sub_partition_size
if rank_id == (world_size - 1):
comm_id += 1
# [rank] -> [comm_id]
sub_partitions = []
for _ in range(world_size):
sub_partitions.append([])
for comm_id, partitions in enumerate(comm_partitions):
for rank_id, partition in enumerate(partitions):
sub_partitions[rank_id].append(partition)
return comm_partitions, sub_partitions, element_intervals, sub_partition_size, num_comm_intervals
@staticmethod
def get_all_sub_partition_info(tensor_list,
all_element_intervals,
local_rank,
world_size):
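# For every rank, record which params fall into each of that rank's sub-partitions
# (grouped by communication interval) and at what offsets, plus the params that do not
# land in any sub-partition owned by the local rank.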
params_not_local = []
# [rank] -> [comm-id] -> [param/offset]
params_in_rank_sub_partition = []
params_in_rank_sub_partitions_offsets = []
for rank in range(world_size):
params_in_local_sub_partition = []
local_sub_partition_offsets = []
comm_tensor_list = []
comm_offset_list = []
current_index = 0
prev_comm_idx = 0
for iii, tensor in enumerate(tensor_list):
tensor_size = tensor.numel()
#if local_rank == 0:
# # logger.info("rank={}, current_index={}, tensor_size={}, tensor-idx={}".format(rank,
# current_index, tensor_size, iii))
results_list = _range_check(current_index,
all_element_intervals[rank],
tensor_size)
for contained, offset, comm_idx in results_list:
#if local_rank == 0:
# logger.info("rank={}, contained={}, offset={}, comm_idx={}".format(rank, contained,
# offset, comm_idx))
if contained:
if prev_comm_idx != comm_idx:
params_in_local_sub_partition.append(comm_tensor_list)
comm_tensor_list = []
local_sub_partition_offsets.append(comm_offset_list)
comm_offset_list = []
comm_tensor_list.append(tensor)
comm_offset_list.append(offset)
prev_comm_idx = comm_idx
elif rank == local_rank:
params_not_local.append(tensor)
current_index = current_index + tensor_size
#assert len(comm_tensor_list) > 0
#assert len(comm_offset_list) > 0
params_in_local_sub_partition.append(comm_tensor_list)
local_sub_partition_offsets.append(comm_offset_list)
params_in_rank_sub_partition.append(params_in_local_sub_partition)
params_in_rank_sub_partitions_offsets.append(local_sub_partition_offsets)
return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local
def get_flat_sub_partitions(self,
comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
default_device,
num_comm_intervals=None,
return_partition_params=False):
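# Build one flattened tensor (of the requested dtype) per communication interval from the
# listed params' gradients, slicing by the recorded offsets and zero-padding up to
# sub_partition_size (and up to num_comm_intervals sub-partitions if requested);
# optionally also return the contributing params and their (offset, numel) positions.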
partition_params = []
final_param_offsets = []
flat_sub_partitions = []
for tensor_list, param_offsets in zip(comm_tensor_list, comm_param_offsets):
flat_tensor_list = []
current_size = 0
my_offsets = []
my_params = []
for i, tensor in enumerate(tensor_list):
if tensor.grad is None:
tensor.grad = torch.zeros(tensor.size(),
dtype=tensor.dtype,
device=tensor.device)
param = tensor
tensor = tensor.grad
num_elements = tensor.numel()
tensor_offset = 0
#we need to offset to get to the right element
if i == 0 and param_offsets[i] > 0:
tensor_offset = param_offsets[i]
num_elements = num_elements - tensor_offset
# We don't need all elements of the tensor if this tensor is
# larger than we have space for in our curr sub-partition
if num_elements > (sub_partition_size - current_size):
num_elements = sub_partition_size - current_size
#we need a narrow view of the tensor based on the tensor offset and number of elements that
#we need from this tensor
if tensor_offset > 0 or num_elements < tensor.numel():
flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
0,
int(tensor_offset),
int(num_elements)).to(dtype))
else:
flat_tensor_list.append(tensor.to(dtype))
my_params.append(param)
#remember offset into partition and #elems for this tensor
my_offsets.append((current_size, num_elements))
current_size = current_size + num_elements
#this means it's the last partition and does not align with the dp boundary; we need to pad before flattening
if current_size < sub_partition_size:
my_offsets.append((None, None))
my_params.append(None)
if len(tensor_list) == 0:
assert default_device is not None
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=default_device))
else:
flat_tensor_list.append(
torch.zeros(int(sub_partition_size - current_size),
dtype=dtype,
device=tensor_list[0].device))
partition_params.append(my_params) #flat_tensor_list)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(self.flatten(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
# logger.info("padding w. sub partitions to ensure uniform communication")
device = flat_sub_partitions[0].device
for _ in range(num_comm_intervals - len(flat_sub_partitions)):
flat_sub_partitions.append(
torch.zeros(int(sub_partition_size),
dtype=dtype,
device=device))
partition_params.append([None])
final_param_offsets.append([(None, None)])
if return_partition_params:
assert len(flat_sub_partitions) == len(partition_params)
assert len(partition_params) == len(final_param_offsets), "{} {}".format(len(partition_params), len(final_param_offsets))
return flat_sub_partitions, partition_params, final_param_offsets
return flat_sub_partitions
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def free_grad_in_param_list(self, param_list):
for p in param_list:
if isinstance(p, list):
for _p in p:
_p.grad = None
else:
p.grad = None
def flatten_dense_tensors_sub_partition_aligned(self,
tensor_list,
dp,
max_elements_per_comm,
pg):
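# Flatten tensor_list into a single 1-D tensor, appending one zero pad tensor so the total
# element count is a multiple of sub_partition_size * dp, i.e. every rank gets the same
# number of equally sized sub-partitions.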
assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"
num_elements = sum(t.numel() for t in tensor_list)
log_dist(
"Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm),
ranks=[0])
# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(num_elements / dp)
# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(max_elements_per_comm // dp)
if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size
# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size * dp) - num_elements
log_dist(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
ranks=[0])
log_dist(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
ranks=[0])
if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]
flat_tensors = self.flatten(aligned_tensor_list)
return flat_tensors
def reduce_gradients(self, pipeline_parallel=False):
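# For every param group, build each rank's flattened gradient sub-partitions and
# reduce-scatter them one communication interval at a time, pre- or post-scaling by the
# configured predivide factor / world size.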
postscale_gradients = self.postscale_gradients
gradient_predivide_factor = self.gradient_predivide_factor
gradient_average = self.gradient_average
world_size = dist.get_world_size(group=self.dp_process_group)
local_rank = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
num_comm_intervals = self.num_comm_intervals_per_group[i]
all_sub_partitions = []
for rank in range(world_size):
# gsp is list of partitions indexed by comm_idx
grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][rank],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
[rank],
dtype=torch.half,
default_device=self.default_device,
sub_partition_size=self.sub_partition_sizes[i],
num_comm_intervals=self.num_comm_intervals_per_group[i])
all_sub_partitions.append(grad_sub_partitions)
assert len(grad_sub_partitions) == num_comm_intervals
local_comm_partitions = []
for comm_idx in range(num_comm_intervals):
single_comm_all_partitions = []
for rank in range(world_size):
single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx])
if postscale_gradients:
if gradient_predivide_factor != 1.0:
for partition in single_comm_all_partitions:
partition.mul_(1. / gradient_predivide_factor)
dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
input_list=single_comm_all_partitions,
group=self.dp_process_group)
if gradient_average:
# Only need to average our local grads in post scaling
if gradient_predivide_factor != world_size:
single_comm_all_partitions[local_rank].mul_(
gradient_predivide_factor / world_size)
else:
for partition in single_comm_all_partitions:
partition.div_(world_size)
dist.reduce_scatter(output=single_comm_all_partitions[local_rank],
input_list=single_comm_all_partitions,
group=self.dp_process_group)
def step(self, closure=None):
# First compute norm for all group so we know if there is overflow
self.overflow = self.overflow_checker.check()
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.zero_grad()
if self.verbose:
logger.info(
"[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.loss_scale))
return self.overflow
norm_groups = []
local_sub_partitions_grad_groups = []
partition_id = dist.get_rank(group=self.dp_process_group)
for i, group in enumerate(self.fp16_groups):
#TODO RS: update get grad norm to support sub partitions
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
#RS: update free grads w.r.t. sub partitions
#free gradients for all the parameters that are not updated by this process
self.free_grad_in_param_list(self.params_not_local[i])
# create flat gradient partitions for parameters updated by this process
local_grad_sub_partitions = self.get_flat_sub_partitions(
comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
[partition_id],
sub_partition_size=self.sub_partition_sizes[i],
dtype=self.local_sub_partitions_of_fp32_groups[i][0].dtype,
num_comm_intervals=self.num_comm_intervals_per_group[i],
default_device=self.default_device)
#RS: update all our local params with sub-partition grads
for idx, sub_partition_param in enumerate(self.local_sub_partitions_of_fp32_groups[i]):
sub_partition_param.grad = local_grad_sub_partitions[idx]
#RS: update free grads for sub-partitions
#release all the gradients since we have already created a necessary copy in dp_grad_partition
self.free_grad_in_param_list(
self.params_in_rank_sub_partitions[i][partition_id])
local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
#RS: update unscale/clip with sub partitions
self.unscale_and_clip_grads(local_sub_partitions_grad_groups,
self._global_grad_norm)
self.optimizer.step()
#RS: clear our sub partition grads
#get rid of the fp32 gradients. Not needed anymore
for group in self.local_sub_partitions_of_fp32_groups:
for idx, sub_partition_param in enumerate(group):
sub_partition_param.grad = None
#group.grad = None
#NOTE RS: removed norm_groups outer loop from original code, I don't think it's needed
#RS: copy all sub-partition fp32 data to fp16 sub partitions
# copy fp32 param data to fp16 partitions w.r.t. our local rank
for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
local_sub_partition_param_fp16.data.copy_(
local_sub_partition_param_fp32.data)
#RS: all_gather/broadcast sub-partitions in separate comm calls
#gather the updated weights from everyone
for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups:
for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions):
dist.all_gather(sub_partitions,
sub_partitions[partition_id],
group=self.dp_process_group)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.loss_scale
for grad in grad_groups_flat:
if isinstance(grad, list):
sub_partitions = grad
for g in sub_partitions:
g.data.mul_(1. / combined_scale)
else:
grad.data.mul_(1. / combined_scale)
def backward(self, loss, retain_graph=False):
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
cur_scale = property(_get_loss_scale, _set_loss_scale)
# Return communication interval paddings for local rank and group
def _get_local_group_paddings(self, group_index):
local_rank = dist.get_rank(group=self.dp_process_group)
sub_partition_indices = [
local_rank + (comm_idx * self.partition_count)
for comm_idx in range(self.num_comm_intervals_per_group[group_index])
]
group_paddings = [
self.group_paddings[group_index][sub_idx]
for sub_idx in sub_partition_indices
]
return group_paddings
# Return group tensor after removing paddings that are added for alignment to DP world size.
# This method works on the assumption that each group contains sub partitions.
def _get_groups_without_padding(self, groups_with_padding):
groups_without_padding = []
for group_index, group in enumerate(groups_with_padding):
group_paddings = self._get_local_group_paddings(group_index)
lean_sub_partitions = []
for sub_partition, padding in zip(group, group_paddings):
lean_length = sub_partition.numel() - padding
lean_sub_partitions.append(sub_partition[:lean_length])
groups_without_padding.append(lean_sub_partitions)
return groups_without_padding
# Return optimizer state after removing paddings that are added for alignment.
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value
return lean_state
# Return base optimizer states.
# This method assumes that each param group contains a single flattened tensor.
def _get_base_optimizer_state(self):
optimizer_groups_state = []
for group_index, group in enumerate(self.optimizer.param_groups):
param_paddings = self._get_local_group_paddings(group_index)
group_lean_state = []
for param_idx, param in enumerate(group['params']):
lean_state = self._get_state_without_padding(self.optimizer.state[param],
param_paddings[param_idx])
group_lean_state.append(lean_state)
optimizer_groups_state.append(group_lean_state)
return optimizer_groups_state
def _rigid_state_dict(self):
"""
Returns a dict that can be loaded for continued training with same DP degree
"""
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self.optimizer.state_dict()
state_dict[
'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
return state_dict
def _elastic_state_dict(self):
"""
Returns a dict that can be loaded for elastic training with different DP degree
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict['partition_count'] = self.partition_count
state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding(
self.local_sub_partitions_of_fp32_groups)
state_dict['local_sub_partitions_of_fp32_groups'] = fp32_groups_without_padding
return state_dict
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
if self.elastic_checkpoint:
return self._elastic_state_dict()
return self._rigid_state_dict()
# Extract the fp32 weights of the current rank from checkpoint by merging the
# sub partitions of communication intervals across ranks.
# Let sub_i_j = sub partition of rank i and comm interval j
# For 2 ranks and 2 comm intervals, checkpoints (minus padding) are as follows:
# rank 0 = [sub_0_0, sub_0_1]
# rank 1 = [sub_1_0, sub_1_1]
# Merge to get [sub_0_0, sub_1_0, sub_0_1, sub_1_1] => original un-padded flattened tensor.
def _retrieve_group_sub_partition_weights(self,
all_partition_fp32_weights,
max_elems_per_comm):
num_partitions = len(all_partition_fp32_weights)
num_comm_intervals = len(all_partition_fp32_weights[0])
num_sub_partitions = num_partitions * num_comm_intervals
all_sub_partition_weights = [None] * num_sub_partitions
for rank, partition_weights in enumerate(all_partition_fp32_weights):
for comm_idx, sub_partition_weights in enumerate(partition_weights):
#all_sub_partition_weights.append(sub_partition_weights)
sub_partition_idx = (comm_idx * num_partitions) + rank
all_sub_partition_weights[sub_partition_idx] = sub_partition_weights
flat_merged_weights = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_sub_partition_weights,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
pg=self.dp_process_group)
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=flat_merged_weights,
max_elements_per_comm=max_elems_per_comm,
world_size=dist.get_world_size(group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
partition_id = dist.get_rank(group=self.dp_process_group)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
# Restore base optimizer fp32 weights from checkpoint by:
# 1) Merging fp32 weights from checkpoints of all partitions
# 2) Extracting fp32 weights for current partition from merged weights
# 3) Using extracted weights to update base optimizer weights directly.
def _restore_from_fp32_weights(self, all_state_dict):
sub_partition_of_fp32_groups = []
for group_idx in range(len(self.local_sub_partitions_of_fp32_groups)):
all_partition_fp32_weights = [
sd['local_sub_partitions_of_fp32_groups'][group_idx]
for sd in all_state_dict
]
max_elems_per_comm = self.max_elems_per_comm[group_idx]
sub_partition_weights = self._retrieve_group_sub_partition_weights(
all_partition_fp32_weights,
max_elems_per_comm)
sub_partition_of_fp32_groups.append(sub_partition_weights)
for current_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, sub_partition_of_fp32_groups):
for current_sub_part, saved_sub_part in zip(current_group, saved_group):
current_sub_part.data.copy_(saved_sub_part.data)
# Extract optimizer state for current partition from merged states of all partitions
def _partition_base_optimizer_state(self,
state_key,
all_partition_states,
max_elems_per_comm):
if not torch.is_tensor(all_partition_states[0]):
return all_partition_states[0]
alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
pg=self.dp_process_group)
comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
self.get_data_parallel_sub_partitions(
tensor=flat_merged_partitions,
max_elements_per_comm=max_elems_per_comm,
world_size=dist.get_world_size(group=self.dp_process_group),
dp_process_group=self.dp_process_group
)
partition_id = dist.get_rank(group=self.dp_process_group)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
# Compute the optimizer state partitions for the group by
# 1) Merging state values across the previous partitioning.
# 2) Repartition state values for the new partitioning
# 3) Return state corresponding to local partition
def _retrieve_group_optimizer_states(self, all_partition_states, max_elems_per_comm):
merged_optimizer_states = {}
num_partitions = len(all_partition_states)
num_comm_intervals = len(all_partition_states[0])
num_sub_partitions = num_partitions * num_comm_intervals
for rank, partition_state in enumerate(all_partition_states):
for comm_idx, sub_partition_state in enumerate(partition_state):
for key, value in sub_partition_state.items():
if key not in merged_optimizer_states:
merged_optimizer_states[key] = [None] * num_sub_partitions
sub_partition_idx = (comm_idx * num_partitions) + rank
merged_optimizer_states[key][sub_partition_idx] = value
group_optimizer_states = {}
for key, value in merged_optimizer_states.items():
group_optimizer_states[key] = self._partition_base_optimizer_state(
key,
value,
max_elems_per_comm)
return group_optimizer_states
# Restore base optimizer state from checkpoint by
# 1) Merging optimizer state from checkpoints of all partitions
# 2) Extracting optimizer state for current partition from the merged state
# 3) Using the extracted value to directly update the base optimizer.
def _restore_base_optimizer_state(self, state_dict_list):
base_optimizer_group_states = []
for group_idx in range(len(self.optimizer.param_groups)):
all_partition_group_states = [
sd['base_optimizer_state'][group_idx] for sd in state_dict_list
]
max_elems_per_comm = self.max_elems_per_comm[group_idx]
group_optimizer_states = self._retrieve_group_optimizer_states(
all_partition_group_states,
max_elems_per_comm)
base_optimizer_group_states.append(group_optimizer_states)
for group_idx, group in enumerate(self.optimizer.param_groups):
for param_idx, param in enumerate(group['params']):
for key, saved in base_optimizer_group_states[group_idx].items():
if torch.is_tensor(self.optimizer.state[param][key]):
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
else:
self.optimizer.state[param][key] = saved
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
partition_id = dist.get_rank(group=self.dp_process_group)
for fp16_partitions, fp32_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
for fp16_sub_partition, fp32_sub_partition in zip(fp16_partitions[partition_id], fp32_partitions):
fp32_sub_partition.data.copy_(fp16_sub_partition.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict['base_optimizer_state'])
for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
for curr_param, saved_param in zip(curr_group, saved_group):
curr_param.data.copy_(saved_param.data)
def _elastic_load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict_list[0]['loss_scaler']
self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
self.overflow = state_dict_list[0]['overflow']
if load_optimizer_states:
self._restore_base_optimizer_state(state_dict_list)
if load_from_fp32_weights:
self._restore_from_fp32_weights(state_dict_list)
else:
self._restore_from_fp16_weights()
def load_state_dict(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
if self.elastic_checkpoint:
self._elastic_load_state_dict(state_dict_list,
load_optimizer_states,
load_from_fp32_weights)
else:
self._rigid_load_state_dict(
state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states)
def _dump_optimizer_state(self, message):
logger.info(f'{message}')
for i, group in enumerate(self.optimizer.param_groups):
for j, param in enumerate(group['params']):
for key, value in self.optimizer.state[param].items():
t_stats = [
value.min(),
value.max(),
(value.max() - value.min()),
value.mean()
]
stats = [float(t) for t in t_stats]
logger.info(
f'group/param/key/min/max/delta/mean = {i}, {j}, {key}: {stats}')
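The elastic checkpoint helpers above (_retrieve_group_sub_partition_weights and _retrieve_group_optimizer_states) depend on the round-robin layout produced by get_data_parallel_sub_partitions: within each communication interval, consecutive sub-partitions belong to consecutive ranks. A small standalone sketch of that merge order, reusing the example values from the comments above (illustrative data only):

# Two ranks, two communication intervals; rank i checkpoints [sub_i_0, sub_i_1].
world_size, num_comm_intervals = 2, 2
per_rank = {0: ["sub_0_0", "sub_0_1"], 1: ["sub_1_0", "sub_1_1"]}

merged = [None] * (world_size * num_comm_intervals)
for rank, subs in per_rank.items():
    for comm_idx, sub in enumerate(subs):
        # same indexing as sub_partition_idx = (comm_idx * num_partitions) + rank above
        merged[comm_idx * world_size + rank] = sub

print(merged)  # ['sub_0_0', 'sub_1_0', 'sub_0_1', 'sub_1_1']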
@@ -592,7 +592,7 @@ class PostBackwardFunction(torch.autograd.Function):
 INITIAL_MICRO_STEP_ID = -1
-class FP16_DeepSpeedZeroOptimizer_Stage3(object):
+class DeepSpeedZeroOptimizer_Stage3(object):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
     required for training large deep learning models.
...
@@ -70,7 +70,7 @@ def print_rank_msg(msg):
     print(f"rank {dist.get_rank()} - {msg}")
-class FP16_DeepSpeedZeroOptimizer(object):
+class DeepSpeedZeroOptimizer(object):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
     required for training large deep learning models.
@@ -2135,8 +2135,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
         ckpt_version = state_dict_list[0].get("ds_version", False)
         error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
             "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
-            "please set 'legacy_stage1': true in your zero config json. This old version of " \
-            "stage 1 will be removed in v0.4.0."
+            "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json."
         assert ckpt_version, f"Empty ds_version! {error_str}"
         assert required_version <= pkg_version.parse(ckpt_version), f"Old version: {ckpt_version} {error_str}"
...
@@ -241,7 +241,7 @@ Example of <i>**scheduler**</i>
 **Note:** this mode cannot be combined with the `fp16` mode described above.
 {: .notice--warning}
-**Note:** this mode is only compatible with ZeRO stage 2.
+**Note:** this mode is only compatible with ZeRO stages 1 and 2.
 {: .notice--warning}
 <i>**bfloat16**</i>: [dictionary]
...
@@ -3,8 +3,7 @@ import torch
 import torch.distributed as dist
 import deepspeed
-from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
-from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.utils import groups
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
@@ -15,7 +14,7 @@ PipeTopo = PipeDataParallelTopology
 from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder
-from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
+from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
 from util import required_torch_version
 import argparse
@@ -60,23 +59,17 @@ def compare_model_states(saved_model,
     if not compare_optimizer:
         return
-    if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
+    if DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(
             saved_model.optimizer,
-            FP16_DeepSpeedZeroOptimizer_Stage3):
+            DeepSpeedZeroOptimizer_Stage3):
         for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
-    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
+    elif isinstance(saved_model.optimizer, DeepSpeedZeroOptimizer):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups):
             assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
-    elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
-        for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
-            for p0, p1 in zip(partition0, partition1):
-                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
-                assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
             assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
@@ -444,8 +437,8 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                                       hidden_dim,
                                       load_optimizer_states):
     if zero_stage == 3:
-        global FP16_DeepSpeedZeroOptimizer_Stage3
-        from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
+        global DeepSpeedZeroOptimizer_Stage3
+        from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
         with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
@@ -525,8 +518,8 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optim
                                  load_optimizer_states,
                                  load_lr_scheduler_states):
     if zero_stage == 3:
-        global FP16_DeepSpeedZeroOptimizer_Stage3
-        from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
+        global DeepSpeedZeroOptimizer_Stage3
+        from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
         with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
...