From 0457bb1cb661c4783712489b529317437e086bc7 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 30 Sep 2021 21:25:42 -0700
Subject: [PATCH] Add assert to ensure we don't skip unsupported grad dtypes
 (#1418)

---
 deepspeed/runtime/engine.py     | 21 ++++++------
 docs/_pages/config-json.md      |  2 +-
 tests/unit/common.py            |  3 ++
 tests/unit/test_sparse_grads.py | 61 +++++++++++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 12 deletions(-)
 create mode 100644 tests/unit/test_sparse_grads.py

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index e8c89116..52bbc516 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -78,14 +78,18 @@ except ImportError:
 
 
 def split_half_float_double_csr(tensors):
-    dtypes = [
+    supported_types = [
         "torch.cuda.HalfTensor",
         "torch.cuda.FloatTensor",
         "torch.cuda.DoubleTensor",
         CSRTensor.type()
     ]
+
+    for t in tensors:
+        assert t.type() in supported_types, f"attempting to reduce an unsupported grad type: {t.type()}"
+
     buckets = []
-    for i, dtype in enumerate(dtypes):
+    for i, dtype in enumerate(supported_types):
         bucket = [t for t in tensors if t.type() == dtype]
         if bucket:
             buckets.append((dtype, bucket))
@@ -1761,7 +1765,7 @@ class DeepSpeedEngine(Module):
                 grads.append(grad_data)
 
         split_buckets = split_half_float_double_csr(grads)
-        for i, bucket_tuple in enumerate(split_buckets):
+        for _, bucket_tuple in enumerate(split_buckets):
            bucket_type, bucket = bucket_tuple
 
            if self.pipeline_parallelism:
@@ -1770,22 +1774,17 @@ class DeepSpeedEngine(Module):
                dp_group = groups.get_data_parallel_group()
 
            if bucket_type == CSRTensor.type():
-                # TODO: do we have to do something here?
                self.csr_allreduce_no_retain(bucket, dp_group=dp_group)
-                #groups.get_data_parallel_group() if self.pipeline_parallelism else self.mpu.get_data_parallel_group())
            else:
-                self.allreduce_no_retain(
-                    bucket,
-                    dp_group=dp_group,
-                    #groups.get_data_parallel_group(),
-                    numel_per_bucket=elements_per_buffer)
+                self.allreduce_no_retain(bucket,
+                                         dp_group=dp_group,
+                                         numel_per_bucket=elements_per_buffer)
 
        if self.has_moe_layers:
            expert_split_buckets = split_half_float_double_csr(expert_grads)
            for i, bucket_tuple in enumerate(expert_split_buckets):
                bucket_type, bucket = bucket_tuple
                if bucket_type == CSRTensor.type():
-                    # TODO: do we have to do something here?
                    self.csr_allreduce_no_retain(bucket,
                                                 groups.get_expert_data_parallel_group())
                else:
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index fda6ba72..af2d2cc8 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -173,7 +173,7 @@ Example of **scheduler**
 
 | Description                                                                                                                | Default |
 | ------------------------------------------------------------------------------------------------------------------------ | ------- |
-| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` |
+| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. This feature is essentially deprecated, as we no longer see significant use cases for it. Note that it is not compatible with [torch.sparse](https://pytorch.org/docs/stable/sparse.html)-related features. | `false` |
 
 ### FP16 training options
 
diff --git a/tests/unit/common.py b/tests/unit/common.py
index c2afbc97..a6fb7d8b 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -42,6 +42,9 @@ def distributed_test(world_size=2, backend='nccl'):
        os.environ['RANK'] = str(local_rank)
        os.environ['WORLD_SIZE'] = str(num_procs)
 
+        # turn off NCCL logging if set
+        os.environ.pop('NCCL_DEBUG', None)
+
        deepspeed.init_distributed(dist_backend=backend)
 
        if torch.cuda.is_available():
diff --git a/tests/unit/test_sparse_grads.py b/tests/unit/test_sparse_grads.py
new file mode 100644
index 00000000..458acaf1
--- /dev/null
+++ b/tests/unit/test_sparse_grads.py
@@ -0,0 +1,61 @@
+import torch
+import torch.distributed as dist
+import deepspeed
+import pytest
+from common import distributed_test
+
+
+def test_sparse_adam(tmpdir):
+    config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True}
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
+            self.linear = torch.nn.Linear(3, 1)
+
+        def forward(self, x, offsets):
+            return self.linear(self.emb(x, offsets))
+
+    class Adam(torch.optim.Optimizer):
+        def __init__(self, dense_params, sparse_params):
+            super().__init__(dense_params + sparse_params, defaults={})
+            self.adam = torch.optim.Adam(dense_params)
+            self.adam_sparse = torch.optim.SparseAdam(sparse_params)
+
+        @torch.no_grad()
+        def step(self, closure=None):
+            loss_1 = self.adam.step(closure)
+            loss_2 = self.adam_sparse.step(closure)
+
+            if loss_1 is not None and loss_2 is not None:
+                return loss_1 + loss_2
+            return loss_1 or loss_2
+
+    model = Model()
+    optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))
+
+    @distributed_test(world_size=[2])
+    def _test(model, optimizer):
+        engine, _, _, _ = deepspeed.initialize(model=model,
+                                               optimizer=optimizer,
+                                               config=config_dict)
+        loss = torch.nn.BCEWithLogitsLoss()
+        x = torch.tensor([1,
+                          2,
+                          4,
+                          5,
+                          4,
+                          3,
+                          2,
+                          9],
+                         dtype=torch.long,
+                         device=engine.device)
+        offsets = torch.tensor([0, 4], dtype=torch.long, device=engine.device)
+        y = torch.tensor([[1.0], [0.0]], device=engine.device)
+        res = engine(x, offsets)
+        with pytest.raises(AssertionError):
+            engine.backward(loss(res, y))
+            engine.step()
+
+    _test(model, optimizer)
-- 
GitLab
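
Usage note: the sketch below illustrates the dtype gating this patch adds to split_half_float_double_csr. It is a minimal, standalone approximation rather than DeepSpeed code: bucket_by_type and SUPPORTED_TYPES are illustrative names, the CSRTensor entry is omitted, and the only assumption made is that Tensor.type() returns legacy type strings such as "torch.cuda.FloatTensor" or "torch.cuda.sparse.FloatTensor".

# Minimal sketch (illustrative names, not the DeepSpeed API): bucket gradient
# tensors by their legacy type string and assert on anything unsupported,
# rather than silently leaving it out of the allreduce buckets.
import torch

SUPPORTED_TYPES = [
    "torch.cuda.HalfTensor",
    "torch.cuda.FloatTensor",
    "torch.cuda.DoubleTensor",
    # The real helper also accepts DeepSpeed's CSRTensor wrapper; omitted here.
]


def bucket_by_type(tensors):
    # Fail loudly before bucketing, mirroring the new assert in the patch.
    for t in tensors:
        assert t.type() in SUPPORTED_TYPES, \
            f"attempting to reduce an unsupported grad type: {t.type()}"
    buckets = []
    for dtype in SUPPORTED_TYPES:
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((dtype, bucket))
    return buckets


if __name__ == "__main__":
    if torch.cuda.is_available():
        dense = [torch.zeros(4, device="cuda"), torch.zeros(2, device="cuda").half()]
        print([dtype for dtype, _ in bucket_by_type(dense)])  # Half and Float buckets

        sparse_grad = torch.zeros(3, 3, device="cuda").to_sparse()
        try:
            bucket_by_type(dense + [sparse_grad])
        except AssertionError as err:
            # e.g. "attempting to reduce an unsupported grad type: torch.cuda.sparse.FloatTensor"
            print(err)

As the new test_sparse_grads.py exercises, a sparse gradient that reaches this path now raises an AssertionError instead of being silently dropped from the reduction.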