Unverified commit 2e99f6ed authored by Zhen Zhang, committed by GitHub

[DRAFT] Tentative implementation of MiCS (#2964)

* include mics config and optimizer

* change private vars to public vars

so the child class can initialize these vars

* Port the init function from stage3

* adding a model test file for mics

* adapt to the get_accelerator api and fp16 group defrag

* WIP: porting mics modification to ms master

* WIP: included gradient all-reduce among replication groups

* WIP: ported hierarchical all-gather part

did basic loss test on a simple MLP model

* [Bug fix] using the comm group attached on the param

* torch2.0 support

* remove print

* delegate wait op

* [Bug] fix naming

* adding doc string

* resolving recursive import

* fix formatting, typos and license

* fix license and unit test error

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-191.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-70.us-west-2.compute.internal>
Co-authored-by: Zhen Zhang <zhzhn@amazon.com>
Co-authored-by: zhzhn <zhzhn@ip-10-2-57-114.us-west-2.compute.internal>
Parent d9253950
......@@ -440,6 +440,30 @@ def reduce_scatter(output,
return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)
def has_all_reduce_coalesced():
""""""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
return cdb.has_all_reduce_coalesced
def has_coalescing_manager():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
return cdb.has_coalescing_manager
def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)
@timed_op
def all_reduce(tensor,
op=ReduceOp.SUM,
......@@ -457,6 +481,18 @@ def all_reduce(tensor,
return cdb.all_reduce(tensor, op, group, async_op)
@timed_op
def all_reduce_coalesced(tensors,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='all_reduce',
debug=get_caller_func()):
global cdb
return cdb.all_reduce_coalesced(tensors, op, group, async_op)
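# Usage sketch (illustrative only; assumes the backend has been initialized,
# e.g. via deepspeed.init_distributed()):
#   from deepspeed import comm as dist
#   grads = [torch.ones(4, device='cuda'), torch.ones(8, device='cuda')]
#   dist.all_reduce_coalesced(grads)  # op defaults to ReduceOp.SUM
#   outputs = [torch.empty(g.numel() * dist.get_world_size(), device='cuda') for g in grads]
#   dist.all_gather_coalesced(outputs, grads)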
def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(
......
......@@ -9,6 +9,32 @@ from .utils import *
from .backend import *
from .comm import *
import torch
def is_torch_two():
TORCH_MAJOR = int(torch.__version__.split('.')[0])
if TORCH_MAJOR >= 2:
return True
else:
return False
def has_coalescing_manager():
has_c10d = hasattr(torch.distributed, 'distributed_c10d')
return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')
def has_all_reduce_coalesced():
return hasattr(torch.distributed, "all_reduce_coalesced")
def get_coalescing_manager(group, device, reqs):
if is_torch_two():
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
else:
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
class TorchBackend(Backend):
"""
......@@ -21,6 +47,8 @@ class TorchBackend(Backend):
def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.has_all_reduce_coalesced = has_all_reduce_coalesced()
self.has_coalescing_manager = has_coalescing_manager()
self.all_gather_function = self.get_all_gather_function()
self.reduce_scatter_function = self.get_reduce_scatter_function()
self.initialized = True
......@@ -66,6 +94,16 @@ class TorchBackend(Backend):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
"""
if not self.has_all_reduce_coalesced:
raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
f"api (torch.__version__: {torch.__version__})")
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
......@@ -89,11 +127,34 @@ class TorchBackend(Backend):
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
# customized PyTorch
return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
input_tensors,
group=group,
async_op=async_op)
elif has_coalescing_manager():
reqs = []
with get_coalescing_manager(group, input_tensors[0].device, reqs):
for output, input in zip(output_tensors, input_tensors):
handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
input,
group=group,
async_op=True)
reqs.append(handle)
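# exiting the coalescing manager flushes the queued all-gathers as a single
# batched launch; the final handle is assumed to cover the whole batch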
if async_op:
return reqs[-1]
else:
reqs[-1].wait()
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
......
......@@ -779,6 +779,8 @@ class DeepSpeedConfig(object):
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.zero_config = get_zero_config(param_dict)
self.mics_shard_size = self.zero_config.mics_shard_size
self.mics_hierarchial_params_gather = self.zero_config.mics_hierarchical_params_gather
self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0
......
......@@ -691,6 +691,9 @@ class DeepSpeedEngine(Module):
def zero_optimization_stage(self):
return self._config.zero_optimization_stage
def mics_shard_size(self):
return self._config.mics_shard_size
def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
......@@ -1368,6 +1371,8 @@ class DeepSpeedEngine(Module):
def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
mics_shard_size = self.mics_shard_size()
model_dtype, grad_accum_dtype = self.get_data_types()
timers = self.timers if self.wall_clock_breakdown() else None
......@@ -1443,6 +1448,14 @@ class DeepSpeedEngine(Module):
offload_param_config=self.zero_offload_param(),
mpu=self.mpu)
else:
log_dist(
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
f' MiCS is enabled {mics_shard_size>0},'
f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
......@@ -1479,6 +1492,37 @@ class DeepSpeedEngine(Module):
return optimizer
def _return_mics_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.mics import MiCS_Optimizer
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)
return optimizer
def _configure_eigenvalue(self):
eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(),
......
......@@ -11,3 +11,5 @@ from .partition_parameters import register_external_parameter
from .tiling import TiledLinear
from .tiling import TiledLinearReturnBias
from .mics import MiCS_Init
......@@ -249,6 +249,9 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
between optimizer steps) or GPU count (increased parallelism).
"""
mics_shard_size: int = Field(-1, new_param="mics_shard_size")
"""
Number of devices over which MiCS partitions the model states. The default of -1 leaves MiCS disabled.
"""
mics_hierarchical_params_gather: bool = False
"""
Whether MiCS uses a two-stage hierarchical all-gather for parameters in the forward computation.
Useful when model states are partitioned across nodes and the cross-node bandwidth is slow.
"""
memory_efficient_linear: bool = True
"""
Use memory efficient linear implementation, for Stage 3.
......
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from dataclasses import dataclass
from typing import List
import numpy as np
import torch
from torch import Tensor
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import logger
def _log_rank0(msg):
if dist.get_rank() == 0:
logger.info(msg)
@torch.jit.script
def scale_tensors(tensors: List[Tensor], scale: int):
for t in tensors:
t.div_(scale)
@dataclass
class MiCS_CommGroups:
""""""
param_shard_group = None
param_shard_size = -1
param_shard_rank = -1
param_repli_group = None
param_repli_size = -1
param_repli_rank = -1
param_intra_node_group = None
param_inter_node_shard_group = None
def create_mics_comm_groups(
shard_size,
dp_group,
hierarchical_allgather=False,
mpu=None,
):
"""
Create the parameter shard groups and replication groups from the given
shard size and data-parallel group.
TODO: consider broadcasting the config from rank 0
Returns:
MiCS_CommGroups
"""
# env var for debugging purpose
ndevices_per_node = int(os.environ.get("NDEV_PER_NODE", get_accelerator().device_count()))
_log_rank0(f'creating MiCS communication groups with per node device size {ndevices_per_node}')
groups = MiCS_CommGroups()
if mpu is not None:
assert dp_group == mpu.get_data_parallel_group()
# full size of the world
world_size = dist.get_world_size()
# global rank
global_rank = dist.get_rank()
config = _generate_mics_config(world_size, ndevices_per_node, shard_size, 1)
ranks_of_shard_group = config['shard_groups']
ranks_of_repli_group = config['replicate_groups']
if len(ranks_of_repli_group) == 0:
assert len(ranks_of_shard_group) == 1, "replicate groups are empty only for single shard group"
for r in ranks_of_shard_group[0]:
ranks_of_repli_group.append([r])
# for simplicity
assert _sizes_all_same(ranks_of_repli_group), "replicate groups must have the same size"
assert _sizes_all_same(ranks_of_shard_group), "shard groups must have the same size"
assert sum([len(g) for g in ranks_of_shard_group]) == dist.get_world_size(), "shard groups must cover all ranks"
if len(ranks_of_shard_group) > 1:  # if sharding uses only one group then there is no need for replicate groups
assert len(ranks_of_shard_group) == len(
ranks_of_repli_group[0]), "the number of shard groups must equal the size of each replicate group"
global_rank = dist.get_rank()
# create shard groups
for shard_ranks in ranks_of_shard_group:
_group = dist.new_group(shard_ranks)
if global_rank in shard_ranks:
groups.param_shard_group = _group
groups.param_shard_size = len(shard_ranks)
groups.param_shard_rank = dist.get_rank(_group)
logger.info(f'rank {global_rank}, shard group'
f' {groups.param_shard_rank}/{dist.get_world_size(group=_group)}')
# create replicate groups
for repli_ranks in ranks_of_repli_group:
if len(repli_ranks) > 1:
_group = dist.new_group(repli_ranks)
if global_rank in repli_ranks:
groups.param_repli_group = _group
groups.param_repli_size = len(repli_ranks)
groups.param_repli_rank = dist.get_rank(group=_group)
logger.info(f'rank {global_rank} '
f'replicate group {groups.param_repli_rank}/{dist.get_world_size(group=_group)}')
else:
groups.param_repli_group = None
groups.param_repli_size = 1
groups.param_repli_rank = 0
logger.info(f'rank {global_rank} replicate group 0/1')
# sanity check: the shard group this rank joined matches the configured shard size
assert groups.param_shard_size == len(ranks_of_shard_group[0])
if hierarchical_allgather:
# create hierarchy inter-node, intra-node groups
# n_span_nodes = config['shard_span']
n_span_nodes = config['span_nodes']
assert n_span_nodes > 1, "sharding spans a single node, hierarchical all-gather is unnecessary"
assert len(ranks_of_shard_group[0]) % n_span_nodes == 0
n_gpu_per_node = len(ranks_of_shard_group[0]) // n_span_nodes
intra_node_ranks_group = []
inter_node_ranks_group = []
for shard_group in ranks_of_shard_group:
_intra_node_ranks = []
for i in range(0, len(shard_group), n_gpu_per_node):
_intra_node_ranks.append(shard_group[i:i + n_gpu_per_node])
_inter_node_ranks = []
for i in range(n_gpu_per_node):
_ranks = [_g[i] for _g in _intra_node_ranks]
_inter_node_ranks.append(_ranks)
intra_node_ranks_group.append(_intra_node_ranks)
inter_node_ranks_group.append(_inter_node_ranks)
_log_rank0(f"create for hierarchy all-gather groups: intra nodes {intra_node_ranks_group}")
_log_rank0(f"create for hierarchy all-gather groups: inter nodes {inter_node_ranks_group}")
# create communicators
for shard_group in intra_node_ranks_group:
for intra_node_ranks in shard_group:
_group = dist.new_group(intra_node_ranks)
if global_rank in intra_node_ranks:
groups.param_intra_node_group = _group
_log_rank0(f'create group for intra node ranks {intra_node_ranks}')
for shard_group in inter_node_ranks_group:
for inter_node_ranks in shard_group:
_group = dist.new_group(inter_node_ranks)
if global_rank in inter_node_ranks:
groups.param_inter_node_shard_group = _group
_log_rank0(f'create group for inter node ranks {inter_node_ranks}')
return groups
def _generate_mics_config(world_size, ndev_per_node, shard_size, pp_size=1):
"""Generating the configuration for sharding This shard config generation assume
that the pipeline stages are partitioned in order, i.e., first ranks
hold the stage0, etc.
Args:
shard_size (int): zero3 data-parallel shard size, FIXME:
change the name later
pp_size (int): pipeline parallel size, currently, only work with
pipeline parallelism + zero
"""
assert world_size % pp_size == 0
assert (world_size // pp_size) % shard_size == 0, \
f"dp group size is not dividable by dp_shard_size, "\
f" (world_size {world_size}, pp_size {pp_size}, dp_shard_size {shard_size})"
config = {}
shard_groups = np.arange(world_size).reshape(-1, shard_size)
replicate_groups = []
for i in range(shard_size):
same_shard_ranks = shard_groups[:, i].tolist()
n_ranks = len(same_shard_ranks)
replicate_size = n_ranks // pp_size
replicate_groups.extend([same_shard_ranks[j:j + replicate_size] for j in range(0, n_ranks, replicate_size)])
config['replicate_groups'] = replicate_groups
config['shard_groups'] = shard_groups.tolist()
config["span_nodes"] = len(shard_groups[0]) // ndev_per_node
return config
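# Worked example (illustration): _generate_mics_config(world_size=8,
# ndev_per_node=4, shard_size=4, pp_size=1) returns
#   shard_groups     = [[0, 1, 2, 3], [4, 5, 6, 7]]
#   replicate_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
#   span_nodes       = 1
# i.e. each rank shards parameters with the three other ranks in its shard
# group and replicates them with its peer rank in the other shard group.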
def _sizes_all_same(groups):
"""all groups have same length"""
all_same = True
for g in groups:
if len(g) != len(groups[0]):
return False
return all_same
......@@ -528,12 +528,12 @@ class AllGatherHandle:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to be available")
self.__handle = handle
self.__param = param
self.handle = handle
self.param = param
def wait(self) -> None:
instrument_w_nvtx(self.__handle.wait)()
self.__param.ds_status = ZeroParamStatus.AVAILABLE
instrument_w_nvtx(self.handle.wait)()
self.param.ds_status = ZeroParamStatus.AVAILABLE
class AllGatherCoalescedHandle:
......@@ -545,32 +545,34 @@ class AllGatherCoalescedHandle:
partitions: List[Tensor],
world_size: int,
) -> None:
self.__allgather_handle = allgather_handle
self.__params = params
self.__partitions = partitions
self.__world_size = world_size
self.__complete = False
for param in self.__params:
# rename these fields to drop the double-underscore (name-mangled) prefix so
# that subclasses can access and override them
self.allgather_handle = allgather_handle
self.params = params
self.partitions = partitions
self.world_size = world_size
self.complete = False
for param in self.params:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
@instrument_w_nvtx
def wait(self) -> None:
if self.__complete:
if self.complete:
return
instrument_w_nvtx(self.__allgather_handle.wait)()
instrument_w_nvtx(self.allgather_handle.wait)()
# split the single tensor out into individual tensors
param_offset = 0
for param in self.__params:
for param in self.params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
partitions: List[Tensor] = []
for rank in range(self.__world_size):
for rank in range(self.world_size):
param_start = rank * param.ds_tensor.ds_numel
if param_start < param.ds_numel:
part_to_copy = self.__partitions[rank].narrow(
part_to_copy = self.partitions[rank].narrow(
0, param_offset, min(param.ds_numel - param_start, param.ds_tensor.ds_numel))
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
......@@ -581,7 +583,7 @@ class AllGatherCoalescedHandle:
param_offset += param.ds_tensor.ds_numel
self.__complete = True
self.complete = True
def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle:
......@@ -733,7 +735,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
self.ds_process_group = data_parallel_group
self.rank = dist.get_rank(group=self.ds_process_group)
self.world_size = dist.get_world_size(group=self.ds_process_group)
self.dp_world_size = dist.get_world_size(group=self.ds_process_group)
# Local device is the device where the parameters are consumed, must be default device.
# It is the device where parameters are fully instantiated using allgather
......@@ -773,7 +775,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
def _update_persist_config(self, ds_config):
Init.apply_param_persistence = True
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.world_size
Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions
def _convert_to_zero_parameters(self, param_list):
for param in param_list:
......@@ -811,7 +813,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")
if get_accelerator().on_accelerator(param):
dist.broadcast(param, 0, self.ds_process_group)
dist.broadcast(param, 0, self.get_dp_process_group())
else:
if dist.get_rank() == 0:
logger.warn(f"param `{name}` in {module.__class__.__name__} "
......@@ -876,7 +878,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# fetches from nvme if the partition is not available and in nvme
self._ensure_availability_of_partitioned_params(params)
if self.world_size == 1:
if self.num_partitions == 1:
return _no_gather_coalesced(params)
for param in params:
......@@ -907,34 +909,35 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# have an opportunity to avoid some intermediate memory allocations
param, = params
param_buffer = torch.empty(
math.ceil(param.ds_numel / self.world_size) * self.world_size,
math.ceil(param.ds_numel / self.num_partitions) * self.num_partitions,
dtype=param.dtype,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
handle = _dist_allgather_fn(param.ds_tensor.to(get_accelerator().current_device_name()), param_buffer,
self.ds_process_group)
self.get_partition_dp_group(param))
param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
return AllGatherHandle(handle, param)
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)
flat_tensor = torch.empty(partition_sz * self.world_size,
flat_tensor = torch.empty(partition_sz * self.num_partitions,
dtype=get_only_unique_item(p.dtype for p in params),
device=get_accelerator().current_device_name(),
requires_grad=False)
partitions: List[Parameter] = []
for i in range(self.world_size):
for i in range(self.num_partitions):
partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[self.rank])
handle = _dist_allgather_fn(partitions[self.rank], flat_tensor, self.ds_process_group)
out=partitions[self.get_partition_rank()])
handle = _dist_allgather_fn(partitions[self.get_partition_rank()], flat_tensor,
self.get_partition_dp_group(params[0]))
return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=partitions,
world_size=self.world_size,
world_size=self.num_partitions,
)
def partition(param_list=None, hierarchy=0, has_been_updated=False):
......@@ -989,6 +992,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
"grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None,
"persist": slf.ds_persist,
"active_sub_modules": slf.ds_active_sub_modules,
"ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None
}
def convert_to_zero_parameters(param_list):
......@@ -1025,8 +1029,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
return param.ds_numel + self._padding_size(param)
def _padding_size(self, param):
remainder = param.ds_numel % self.world_size
return (self.world_size - remainder) if remainder else 0
remainder = param.ds_numel % self.num_partitions
return (self.num_partitions - remainder) if remainder else 0
def _partition_numel(self, param):
return param.ds_tensor.ds_numel
......@@ -1122,7 +1126,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
return
tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.world_size
partition_size = tensor_size // self.num_partitions
if param.ds_tensor is None:
final_location = None
......@@ -1153,7 +1157,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
param.ds_tensor.final_location = final_location
start = partition_size * self.rank
start = partition_size * self.get_partition_rank()
end = start + partition_size
one_dim_param = param.contiguous().view(-1)
......@@ -1207,7 +1211,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
partition_size = param.ds_tensor.ds_numel
tensor_size = partition_size * self.world_size
tensor_size = partition_size * self.num_partitions
aligned_param_size = self._aligned_size(param)
assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'
......@@ -1235,20 +1239,22 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# param.data = replicated_tensor.data
# return None
if self.use_all_gather_into_tensor:
# try the all_gather_into_tensor on PyTorch master branch
handle = dist.all_gather_into_tensor(flat_tensor,
param.ds_tensor.to(get_accelerator().device_name()),
group=self.ds_process_group,
group=self.get_partition_dp_group(param),
async_op=async_op)
else:
partitions = []
for i in range(self.world_size):
for i in range(self.num_partitions):
partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
if i == dist.get_rank(group=self.ds_process_group):
if i == dist.get_rank(group=self.get_partition_dp_group(param)):
partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
handle = dist.all_gather(partitions, partitions[self.rank], group=self.ds_process_group, async_op=async_op)
handle = dist.all_gather(partitions,
partitions[self.get_partition_rank()],
group=self.get_partition_dp_group(param),
async_op=async_op)
replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
param.data = replicated_tensor.data
......@@ -1261,7 +1267,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
if len(param_list) == 0:
return
if self.world_size == 1:
if self.num_partitions == 1:
handle = _no_gather_coalesced(param_list)
handle.wait()
return None
......@@ -1276,27 +1282,25 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# allocate memory for allgather params
allgather_params = []
for psize in partition_sizes:
tensor_size = psize * self.world_size
tensor_size = psize * self.num_partitions
flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device).view(-1)
flat_tensor.requires_grad = False
allgather_params.append(flat_tensor)
# launch
launch_handles = []
# backend = get_backend(self.ds_process_group)
# with _batch_p2p_manager(backend):
for param_idx, param in enumerate(param_list):
input_tensor = local_tensors[param_idx].view(-1)
if self.use_all_gather_into_tensor:
# try the all_gather_into_tensor from Pytorch master
# try the _all_gather_base from Pytorch master
h = dist.all_gather_into_tensor(allgather_params[param_idx],
input_tensor,
group=self.ds_process_group,
group=self.get_partition_dp_group(param),
async_op=True)
else:
output_list = []
for i in range(self.world_size):
for i in range(self.num_partitions):
psize = partition_sizes[param_idx]
partition = allgather_params[param_idx].narrow(0, i * psize, psize)
output_list.append(partition)
......@@ -1304,8 +1308,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
logger.warning(
f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}')
# back to old all_gather function signature
h = dist.all_gather(output_list, input_tensor, group=self.ds_process_group, async_op=True)
# back to old all_gather function
h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True)
launch_handles.append(h)
# Wait ensures the operation is enqueued, but not necessarily complete.
......@@ -1327,16 +1331,16 @@ class Init(InsertPostInitMethodToModuleSubClasses):
partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
tensor_size = partition_size * self.world_size
tensor_size = partition_size * self.num_partitions
flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device)
flat_tensor.requires_grad = False
partitions = []
for i in range(self.world_size):
for i in range(self.num_partitions):
start = partition_size * i
partitions.append(flat_tensor.narrow(0, start, partition_size))
if i == self.rank:
if i == self.get_partition_rank():
offset = 0
for param in param_list:
param_numel = param.ds_tensor.ds_numel
......@@ -1345,7 +1349,10 @@ class Init(InsertPostInitMethodToModuleSubClasses):
offset += param_numel
dist.all_gather(partitions, partitions[self.rank], group=self.ds_process_group, async_op=False)
dist.all_gather(partitions,
partitions[self.get_partition_rank()],
group=self.get_partition_dp_group(param),
async_op=False)
param_offset = 0
for param in param_list:
......@@ -1353,7 +1360,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
param_size = param.ds_numel
replicated_tensor = torch.empty(param.ds_shape, dtype=param.dtype, device=self.local_device)
for i in range(self.world_size):
for i in range(self.num_partitions):
start = i * partition_size
......@@ -1391,7 +1398,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
# For these ranks the output of reduce scatter is a separate buffer and needs
# to be copied in
partition_size = param.ds_tensor.ds_numel
start = self.rank * partition_size
start = self.get_partition_rank() * partition_size
end = start + partition_size
#print_rank_0("REduce scatter was executed for praam {param.ds_id}")
if start < param.ds_numel and end > param.ds_numel:
......@@ -1403,10 +1410,10 @@ class Init(InsertPostInitMethodToModuleSubClasses):
partition_size = param.ds_tensor.ds_numel
#output = torch.empty(partition_size, dtype=param.dtype, device=param.device)
total_size = partition_size * self.world_size
total_size = partition_size * self.num_partitions
input_list = []
for i in range(self.world_size):
for i in range(self.num_partitions):
start = i * partition_size
end = start + partition_size
......@@ -1423,8 +1430,11 @@ class Init(InsertPostInitMethodToModuleSubClasses):
#print("after reduce scatter gradients")
input_list.append(input)
rank = dist.get_rank(group=self.ds_process_group)
handle = dist.reduce_scatter(input_list[rank], input_list, group=self.ds_process_group, async_op=True)
rank = dist.get_rank(group=self.get_partition_dp_group(param))
handle = dist.reduce_scatter(input_list[rank],
input_list,
group=self.get_partition_dp_group(param),
async_op=True)
return handle, input_list[rank]
......@@ -1436,6 +1446,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate)
def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
#import pdb;pdb.set_trace()
# param.grad=None
# param.grad.test()
......@@ -1452,7 +1463,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
assert partition_buffer.numel(
) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
rank = dist.get_rank(group=self.ds_process_group)
rank = dist.get_rank(group=self.get_partition_dp_group(param))
start = partition_size * rank
end = start + partition_size
......@@ -1496,6 +1507,22 @@ class Init(InsertPostInitMethodToModuleSubClasses):
param.grad.data = dest_tensor_full_buffer.data
see_memory_usage("After partitioning gradients", force=False)
def get_partition_dp_group(self, param):
return param.ds_process_group
def get_partition_rank(self):
"""subclass can overload to specify different relative rank in
parameter partition group"""
return self.rank
@property
def num_partitions(self):
return self.dp_world_size
def get_dp_process_group(self):
""" Return the communication group with all data-parallel ranks """
return self.ds_process_group
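# Note: get_partition_dp_group, get_partition_rank, num_partitions and
# get_dp_process_group are the indirection points introduced here so that a
# subclass (e.g. MiCS_Init) can partition parameters over a shard group
# rather than over the full data-parallel group.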
class GatheredParameters:
......
......@@ -147,17 +147,17 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0
self.parameter_offload = DeepSpeedZeRoOffload(module=module,
timers=timers,
ds_config=ds_config,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
max_live_parameters=max_live_parameters,
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_param_config=offload_optimizer_config,
mpu=mpu)
self.parameter_offload = self.initialize_ds_offload(module=module,
timers=timers,
ds_config=ds_config,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
max_live_parameters=max_live_parameters,
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_optimizer_config=offload_optimizer_config,
mpu=mpu)
self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)
......@@ -165,21 +165,21 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.module = module
self.elastic_checkpoint = elastic_checkpoint
self.__inf_or_nan_tracker: Tensor = torch.zeros(1,
dtype=torch.bool,
device=get_accelerator().current_device_name(),
requires_grad=False)
self.inf_or_nan_tracker: Tensor = torch.zeros(1,
dtype=torch.bool,
device=get_accelerator().current_device_name(),
requires_grad=False)
self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam)
self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
### streams used for overlapping computation with communication
self.__reduce_and_partition_stream = get_accelerator().Stream() if overlap_comm else get_accelerator(
self.reduce_and_partition_stream = get_accelerator().Stream() if overlap_comm else get_accelerator(
).default_stream()
############################################################################
self.__n_caching_allocator_flushes = 0
self.n_caching_allocator_flushes = 0
#-------------Stage 3 Setup-------------------#
......@@ -261,12 +261,12 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.swap_optimizer:
self._configure_tensor_swapping(offload_optimizer_config, aio_config)
self.__params_in_ipg_bucket: List[Parameter] = []
self.params_in_ipg_bucket = []
self.is_gradient_accumulation_boundary: bool = True
self.__param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
# TODO. make this configurable via JSON
self.__max_param_reduce_events: int = 2
self.max_param_reduce_events: int = 2
self.param_dict = {}
......@@ -340,6 +340,32 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
def destroy(self):
self.parameter_offload.destroy()
def initialize_ds_offload(
self,
module,
timers,
ds_config,
overlap_comm,
prefetch_bucket_size,
max_reuse_distance,
max_live_parameters,
param_persistence_threshold,
model_persistence_threshold,
offload_optimizer_config,
mpu,
):
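"""Construct the DeepSpeedZeRoOffload coordinator. Kept as a separate method
so that a subclass (such as the MiCS optimizer added in this commit) can
substitute its own offload coordinator."""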
return DeepSpeedZeRoOffload(module=module,
timers=timers,
ds_config=ds_config,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
max_live_parameters=max_live_parameters,
param_persistence_threshold=param_persistence_threshold,
model_persistence_threshold=model_persistence_threshold,
offload_param_config=offload_optimizer_config,
mpu=mpu)
def _get_trainable_parameter_groups(self):
param_groups = []
for param_group in self.optimizer.param_groups:
......@@ -479,7 +505,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
@property
def elements_in_ipg_bucket(self):
return sum(p.ds_numel for p in self.__params_in_ipg_bucket)
return sum(p.ds_numel for p in self.params_in_ipg_bucket)
def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
'''If flat buffer is None then the parameters in the param_list are
......@@ -928,7 +954,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.__reduce_and_partition_ipg_grads()
self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)
self.__reduce_and_partition_stream.synchronize()
self.reduce_and_partition_stream.synchronize()
# if dist.get_rank() == 0:
# logger.info("Params already reduced %s", self.params_already_reduced)
......@@ -1026,11 +1052,11 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
@instrument_w_nvtx
@torch.no_grad()
def __add_grad_to_ipg_bucket(self, param: Parameter) -> None:
self.__reduce_and_partition_stream.wait_stream(get_accelerator().default_stream())
self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream())
if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() < self.reduce_bucket_size:
# move the gradient to a contiguous buffer
with get_accelerator().stream(self.__reduce_and_partition_stream):
with get_accelerator().stream(self.reduce_and_partition_stream):
# move the parameter's gradient to the contiguous flat buffer
new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow(0, self.elements_in_ipg_bucket,
param.grad.numel()).view_as(param.grad)
......@@ -1038,40 +1064,40 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
param.grad.record_stream(get_accelerator().current_stream())
param.grad.data = new_grad_tensor
self.__params_in_ipg_bucket.append(param)
self.params_in_ipg_bucket.append(param)
@instrument_w_nvtx
@torch.no_grad()
def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
if not self.__params_in_ipg_bucket:
if not self.params_in_ipg_bucket:
return
for param in self.__params_in_ipg_bucket:
for param in self.params_in_ipg_bucket:
if param.grad.numel() != param.ds_numel:
raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter "
f"gradients whose size is not same as the params")
self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id)
assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len(self.__params_in_ipg_bucket)
assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket)
while self.__param_reduce_events and self.__param_reduce_events[0].query():
self.__param_reduce_events.popleft()
if len(self.__param_reduce_events) > self.__max_param_reduce_events:
self.__param_reduce_events.popleft().synchronize()
while self.param_reduce_events and self.param_reduce_events[0].query():
self.param_reduce_events.popleft()
if len(self.param_reduce_events) > self.max_param_reduce_events:
self.param_reduce_events.popleft().synchronize()
with get_accelerator().stream(self.__reduce_and_partition_stream):
with get_accelerator().stream(self.reduce_and_partition_stream):
if safe_mode:
assert_ints_same_as_other_ranks([p.ds_id for p in self.__params_in_ipg_bucket])
assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket])
grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket)
self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions)
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
self.partition_grads(self.params_in_ipg_bucket, grad_partitions)
self.__params_in_ipg_bucket.clear()
self.params_in_ipg_bucket.clear()
event = get_accelerator().Event()
event.record()
self.__param_reduce_events.append(event)
self.param_reduce_events.append(event)
@instrument_w_nvtx
def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]:
......@@ -1154,9 +1180,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
return total_norm
@instrument_w_nvtx
def __partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
offload_fp32_gradients = {}
offload_fp32_offsets = {}
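# collect the reduced gradient partition buffers; they are returned at the end
# of this method so that a subclass (e.g. the MiCS optimizer) can all-reduce
# them across its replication group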
buffers = []
for param, grad_partition in zip(params_to_release, grad_partitions):
contains_real_data = param.partition_numel() * dist.get_rank(self.dp_process_group) < param.ds_numel
......@@ -1167,6 +1194,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# move or accumulate gradient partition to target buffer
grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel())
buffers.append(grad_buffer)
if self.micro_step_id == 0: # don't accumulate
grad_buffer.copy_(grad_partition, non_blocking=True)
# ensure grad buffer is a CUDA buffer to speed up the next few
......@@ -1184,14 +1212,14 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer
if hasattr(self.__inf_or_nan_tracker, "logical_or_"):
self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0
self.inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0
# offload the gradient partition if applicable
if self.offload_optimizer:
......@@ -1221,6 +1249,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
gradient_offsets=offload_fp32_offsets[i],
gradient_tensors=offload_fp32_gradients[i])
return buffers
def reduce_ready_partitions_and_remove_grads(self, param, i):
#print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
......@@ -1782,7 +1811,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# warn user about caching allocator flushes
memory_stats = get_accelerator().memory_stats()
alloc_retries = memory_stats["num_alloc_retries"] if memory_stats != None else 0
if alloc_retries > self.__n_caching_allocator_flushes:
if alloc_retries > self.n_caching_allocator_flushes:
if dist.get_rank() == 0:
logger.warning(
"%d pytorch allocator cache flushes since last step. this happens "
......@@ -1792,8 +1821,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
"make the cache flushes go away consider adding "
"get_accelerator().empty_cache() calls in your training loop to ensure "
"that all ranks flush their caches at the same time",
alloc_retries - self.__n_caching_allocator_flushes)
self.__n_caching_allocator_flushes = alloc_retries
alloc_retries - self.n_caching_allocator_flushes)
self.n_caching_allocator_flushes = alloc_retries
def dump_pre_step_gradients(self, debug_fp32_grads):
# Dump gradient norms for debugging
......@@ -1855,9 +1884,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
@instrument_w_nvtx
def has_overflow(self, partition_gradients=True):
if partition_gradients:
with get_accelerator().stream(self.__reduce_and_partition_stream):
self.local_overflow = bool(self.__inf_or_nan_tracker.item())
self.__inf_or_nan_tracker.zero_()
with get_accelerator().stream(self.reduce_and_partition_stream):
self.local_overflow = bool(self.inf_or_nan_tracker.item())
self.inf_or_nan_tracker.zero_()
overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
......@@ -1931,7 +1960,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
"""get fp32 gradient partition dictionary
accessed as grad_dict[parameter_group_index][parameter_index]
"""
self.__reduce_and_partition_stream.synchronize()
self.reduce_and_partition_stream.synchronize()
grad_dict = collections.defaultdict(dict)
if self.offload_optimizer:
for group in self.fp16_groups:
......@@ -1965,7 +1994,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if not param.requires_grad:
return None
self.__reduce_and_partition_stream.synchronize()
self.reduce_and_partition_stream.synchronize()
if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
......@@ -1980,7 +2009,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if not param.requires_grad:
return None
self.__reduce_and_partition_stream.synchronize()
self.reduce_and_partition_stream.synchronize()
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
if self._swappable_optimizer_subgroup(group_idx):
......
......@@ -155,7 +155,39 @@ Example ZeRO-3 Configurations
...
}
MiCS Configurations
===================
All MiCS configurations are set with `DeepSpeedZeroConfig`. MiCS assumes that
ZeRO stage 3 optimization is enabled. For now, MiCS has two configuration
fields, `mics_shard_size` and `mics_hierarchical_params_gather`.
`mics_shard_size` controls how many devices are used to partition the model
states. `mics_hierarchical_params_gather` controls whether a two-stage
hierarchical scheme is used to gather parameters during the forward
computation; it is useful when the model states are partitioned across
multiple nodes and the cross-node bandwidth is slow. By default it is turned
off.
Example MiCS Configurations
===========================
#. Use MiCS to partition the model states (including optimizer states,
gradients, and parameters). The following config example partitions the model
states across eight devices and assumes the eight devices are located within a
single node (`mics_hierarchical_params_gather` is `False`).
.. code-block:: python
:emphasize-lines: 3
{
"zero_optimization": {
"stage": 3,
"mics_shard_size": 8,
"mics_hierarchical_params_gather": False,
},
...
}
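A minimal end-to-end sketch of wiring such a configuration into a training
script is shown below. The model here is a placeholder, and `MiCS_Init` is
assumed to accept the same `config_dict_or_path` argument as
`deepspeed.zero.Init`; adapt it to your own model and launcher setup.

.. code-block:: python

    import torch
    import deepspeed

    ds_config = {
        "train_batch_size": 8,
        "optimizer": {"type": "Adam", "params": {"lr": 1.5e-4}},
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "mics_shard_size": 8,
            "mics_hierarchical_params_gather": False,
        },
    }

    # constructing the model under MiCS_Init shards the parameters as they
    # are created, mirroring deepspeed.zero.Init for plain ZeRO-3
    with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config):
        model = torch.nn.Linear(1024, 1024)

    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)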
Assumptions
===========
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import json
import argparse
import torch
import deepspeed
from torch.utils.data.distributed import DistributedSampler
import deepspeed.comm as dist
class SimpleModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False):
super(SimpleModel, self).__init__()
self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
if empty_grad:
self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def forward(self, x, y):
hidden = x
hidden = self.linear(hidden)
return self.cross_entropy_loss(hidden, y)
def create_config_from_dict(tmpdir, config_dict):
config_path = os.path.join(tmpdir, 'temp_config.json')
with open(config_path, 'w') as fd:
json.dump(config_dict, fd)
return config_path
def get_data_loader(model, total_samples, hidden_dim, device):
batch_size = model.train_micro_batch_size_per_gpu()
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
sampler = DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
return train_loader
def get_args(tmpdir, config_dict):
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument('--zero', type=int, default=3)
args = parser.parse_args() #args=''
config_dict["zero_optimization"]["stage"] = args.zero
# print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
config_path = create_config_from_dict(tmpdir, config_dict)
args.deepspeed_config = config_path
return args
def print0(msg):
if dist.get_rank() == 0:
print(msg, flush=True)
rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)
config_dict = {
"train_batch_size": 8,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015,
}
},
"fp16": {
"enabled": True,
"initial_scale_power": 15
},
"zero_optimization": {
"stage": 3,
"reduce_bucket_size": 20,
"mics_shard_size": 4,
"mics_hierarchical_params_gather": True,
"stage3_model_persistence_threshold": 10
}
}
# "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
hidden_dim = 32
# with deepspeed.zero.Init():
model = SimpleModel(hidden_dim, empty_grad=False)
# print('------> init model with deepspeed.zero.Init()')
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters(),
dist_init_required=True)
def print_params(tag, model):
if dist.get_rank() == 0:
for n, p in model.named_parameters():
print0("{} {}:{}".format(tag, n, p))
data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device)
#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
if dist.get_rank() == 0:
print("LOSS:", loss.item())
model.backward(loss)
model.step()
#print_params('step={}'.format(n), model)
if n == 5: break
......@@ -82,12 +82,13 @@ config_dict = {
},
"zero_optimization": {
"stage": 0,
"reduce_bucket_size": 20
"reduce_bucket_size": 20,
"stage3_model_persistence_threshold": 10
}
}
# "initial_scale_power": 15
args = get_args('/tmp/', config_dict)
hidden_dim = 4
hidden_dim = 32
model = SimpleModel(hidden_dim, empty_grad=False)
......