Unverified commit ad168a69, authored by Michael Wyatt, committed by GitHub

Fix for dist not being initialized when constructing main config (#3324)

* move dist init out of Engine
Parent dd8df20f
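The diff below moves distributed initialization out of DeepSpeedEngine.__init__ and into deepspeed.initialize(), so the communication backend is brought up before the main DeepSpeed config is constructed. The following sketch is not part of the commit; it is a simplified illustration of the resulting order of operations inside deepspeed.initialize(), with everything unrelated to dist elided and a hypothetical function name.

# Simplified, illustrative sketch only; it mirrors the code added to
# deepspeed/__init__.py in the hunks below.
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator


def initialize_flow_sketch(dist_init_required=None):
    # 1) Bring up the communication backend first -- this is the logic that
    #    moved here from DeepSpeedEngine.__init__.
    dist_backend = get_accelerator().communication_backend_name()
    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)

    # 2) Only afterwards is the engine (and with it the main config) built,
    #    so config construction no longer runs against an uninitialized dist.
    ...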
deepspeed/__init__.py
@@ -15,6 +15,7 @@ from packaging import version as pkg_version
from . import ops
from . import module_inject
from .accelerator import get_accelerator
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -50,6 +51,9 @@ __version_major__, __version_minor__, __version_patch__ = _parse_version(__version__)
__git_hash__ = git_hash
__git_branch__ = git_branch
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None
def initialize(args=None,
               model: torch.nn.Module = None,
@@ -119,6 +123,11 @@ def initialize(args=None,
    assert model is not None, "deepspeed.initialize requires a model"

    global dist
    from deepspeed import comm as dist

    dist_backend = get_accelerator().communication_backend_name()
    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)

    # Set config using config_params for backwards compat
    if config is None and config_params is not None:
        config = config_params
......
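For context (not part of the commit), the new ordering matters for configs whose global batch size has to be derived from the world size while the config is parsed. Below is a minimal, hedged usage sketch, meant to be launched with a distributed launcher such as deepspeed train.py; the model and hyperparameters are stand-ins.

# Illustrative usage only; the config keys are standard DeepSpeed keys and the
# model is a placeholder. train_batch_size is derived as
# micro_batch_per_gpu * gradient_accumulation_steps * world_size, which is why
# dist must already be initialized when the main config is constructed.
import torch
import deepspeed

model = torch.nn.Linear(16, 16)
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)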
deepspeed/comm/__init__.py
@@ -3,48 +3,5 @@
# DeepSpeed Team
import torch
from .utils import *
from deepspeed import utils
supported_torch_version = False
# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The PG API in torch versions lesser than 1.8 are different so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+.
if older_torch():
    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
    # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
    supported_torch_version = False
    from torch.distributed import *

    def get_world_group():
        return group.WORLD

    def get_global_rank(group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import all_gather, get_world_size
        from torch import chunk
        output_tensors = list(chunk(output_tensor, get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)

    def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import reduce_scatter, get_world_size
        from torch import chunk
        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
        return reduce_scatter(output_tensor, input_tensor_lst, group=group)

    def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None):
        utils.logger.warn("Communication logging is not supported in torch versions older than 1.8")

else:
    supported_torch_version = True
    from .comm import *
from .comm import *
deepspeed/comm/torch.py
@@ -21,7 +21,6 @@ class TorchBackend(Backend):
    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.torch_version_before_18 = older_torch()
        self.all_gather_function = self.get_all_gather_function()
        self.reduce_scatter_function = self.get_reduce_scatter_function()
        self.initialized = True
......
deepspeed/comm/utils.py
@@ -4,25 +4,10 @@
# DeepSpeed Team
import os
import torch
import inspect
from deepspeed.utils import get_caller_func
def older_torch():
    '''
    Helper to lookup torch version. For versions less than 1.8, torch.dist
    used torch.distributed.group.WORLD as the default group argument instead of None.
    See more details at: https://github.com/pytorch/pytorch/pull/48767
    '''
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
        return True
    else:
        return False

def get_local_rank_from_launcher():
    # DeepSpeed launcher will set it so get from there
......
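The helper removed above gated the pre-1.8 fallbacks that this commit also deletes from deepspeed/comm/__init__.py. For reference only (not part of the commit), an equivalent stand-alone check written with packaging, which deepspeed/__init__.py already imports as pkg_version, could look like this; the function name is hypothetical.

# Reference-only equivalent of the removed older_torch() helper.
import torch
from packaging import version as pkg_version


def torch_older_than_1_8() -> bool:
    # parse() tolerates local version suffixes such as "2.0.1+cu118";
    # .release is the numeric (major, minor, patch) tuple.
    return pkg_version.parse(torch.__version__).release < (1, 8)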
deepspeed/runtime/engine.py
@@ -21,6 +21,7 @@ from typing import Callable, Dict, Union, Iterable
import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
@@ -96,9 +97,6 @@ from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.runtime.config import DtypeEnum
# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
DeepSpeedOptimizerCallable = \
@@ -232,8 +230,6 @@ class DeepSpeedEngine(Module):
        self.checkpoint_engine = None

        global dist
        from deepspeed import comm as dist

        self._is_gradient_accumulation_boundary = None
        self.scale_wrt_gas = None
@@ -243,22 +239,6 @@
        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
        self.param_names = {param: name for name, param in model.named_parameters()}

        from deepspeed.comm import supported_torch_version

        # This supported_torch_version check is for torch1.2 compatibility only
        if supported_torch_version:
            dist.init_distributed(dist_backend=self.dist_backend, dist_init_required=dist_init_required)
        else:
            if dist_init_required is None:
                dist_init_required = not dist.is_initialized()

            if dist_init_required is False:
                assert (
                    dist.is_initialized() is True
                ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
            else:
                if not dist.is_initialized():
                    dist.init_process_group(backend=self.dist_backend)

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()
......
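With the block above removed, the engine assumes the backend has already been brought up by deepspeed.initialize(). Callers that need the process group before constructing an engine can still initialize it explicitly; a hedged sketch of that pattern follows.

# Sketch only: explicitly initializing the backend ahead of deepspeed.initialize().
# init_distributed() is effectively idempotent -- if torch.distributed is already
# initialized it does not re-initialize it -- which matches the dist_init_required
# handling the engine used to perform itself.
import deepspeed
from deepspeed import comm as dist

deepspeed.init_distributed()  # defaults to the accelerator's communication backend
assert dist.is_initialized()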