Unverified commit 5aa58b38 authored by Jeff Rasley, committed by GitHub

Init distributed torch only if needed (#108)

* add auto-detect to torch dist init

* update tests to infer distributed init status

* prevent crash if dist_init_required is True but already initialized

* only init if safe to do so (forgot to add this file in prev commit)
Parent 6efee45c
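The practical effect for callers: when dist_init_required is left unset, deepspeed.initialize() now checks whether torch.distributed is already up and only creates the process group if it is not. A minimal sketch of the new default call path, assuming a parsed args namespace and a torch.nn.Module named model already exist (both placeholders, not part of this diff):

import deepspeed

# dist_init_required is omitted, so it defaults to None and DeepSpeed
# auto-detects whether torch.distributed still needs to be initialized.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters())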
@@ -32,7 +32,7 @@ def initialize(args,
                training_data=None,
                lr_scheduler=None,
                mpu=None,
-               dist_init_required=True,
+               dist_init_required=None,
                collate_fn=None):
     r"""Initialize the DeepSpeed Engine.
@@ -56,7 +56,8 @@ def initialize(args,
         mpu: Optional: A model parallelism unit object that implements
             get_{model,data}_parallel_{rank,group,world_size}()
-        dist_init_required: Optional: Initializes torch.distributed
+        dist_init_required: Optional: None will auto-initialize torch.distributed if needed,
+            otherwise the user can force it to be initialized or not via boolean.
         collate_fn: Optional: Merges a list of samples to form a
             mini-batch of Tensor(s). Used when using batched loading from a
......
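Passing an explicit boolean still forces the behavior either way. A hedged sketch of the opt-out case, assuming the caller has already set up torch.distributed itself (args and model are placeholders as above):

import torch.distributed as dist
import deepspeed

# The caller owns process-group setup, e.g. the standard env:// rendezvous
# driven by RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl", init_method="env://")

# dist_init_required=False tells DeepSpeed not to touch torch.distributed.
model_engine, _, _, _ = deepspeed.initialize(args=args,
                                             model=model,
                                             model_parameters=model.parameters(),
                                             dist_init_required=False)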
@@ -97,7 +97,7 @@ class DeepSpeedLight(Module):
                  training_data=None,
                  lr_scheduler=None,
                  mpu=None,
-                 dist_init_required=True,
+                 dist_init_required=None,
                  collate_fn=None):
         super(DeepSpeedLight, self).__init__()
@@ -119,8 +119,19 @@
         self.gradient_average = True
         self.warn_unscaled_loss = True
+        if dist_init_required is None:
+            dist_init_required = not dist.is_initialized()
         self.dist_backend = "nccl"
         if dist_init_required:
-            dist.init_process_group(backend="nccl")
+            if not dist.is_initialized():
+                logging.info("Initializing torch distributed with backend: {}".format(
+                    self.dist_backend))
+                dist.init_process_group(backend=self.dist_backend)
+            else:
+                logging.warning(
+                    "Was given dist_init_required=True but detected that torch "
+                    "distributed was already initialized, cannot initialize twice.")
         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
......
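The guard added above reduces to a small idempotent-initialization pattern. A standalone sketch of the same idea, using a hypothetical helper name (not part of DeepSpeed), for readers who want to mirror it in their own launch code:

import logging
import torch.distributed as dist

def maybe_init_distributed(backend="nccl"):
    # Hypothetical helper: create the default process group only if one
    # does not already exist, so repeated calls are safe.
    if not dist.is_initialized():
        logging.info("Initializing torch distributed with backend: %s", backend)
        dist.init_process_group(backend=backend)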
@@ -142,8 +142,7 @@ def test_deprecated_deepscale_config(tmpdir):
     def _test_deprecated_deepscale_config(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=5,
                                         hidden_dim=hidden_dim,
@@ -154,3 +153,45 @@ def test_deprecated_deepscale_config(tmpdir):
             model.step()
     _test_deprecated_deepscale_config(args=args, model=model, hidden_dim=hidden_dim)
+def test_dist_init_true(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args(args='')
+    args.deepscale_config = config_path
+    args.local_rank = 0
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim)
+    @distributed_test(world_size=[1])
+    def _test_dist_init_true(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters(),
+                                             dist_init_required=True)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+    _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)
@@ -32,8 +32,7 @@ def test_lamb_fp16_basic(tmpdir):
     def _test_lamb_fp16_basic(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -70,8 +69,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -102,8 +100,7 @@ def test_adamw_fp16_basic(tmpdir):
         optimizer = torch.optim.AdamW(params=model.parameters())
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             optimizer=optimizer,
-                                             dist_init_required=False)
+                                             optimizer=optimizer)
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -134,8 +131,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
         optimizer = torch.optim.AdamW(params=model.parameters())
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             optimizer=optimizer,
-                                             dist_init_required=False)
+                                             optimizer=optimizer)
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
......