Unverified commit a6dba72a, authored by Ammar Ahmad Awan, committed by GitHub

NCCL-based 1-bit Adam + Code Refactor for Comm. Backends (#594)

* NCCL based 1-bit Implementation + Refactor to add communication backends (#593)

* add nccl 1-bit optim.

* temporary commit to save stuff.

* Use dist collectives instead of mpi routines.

* remove old code for comm.

* Fix bugs. still does not work.

* modify to test the nccl side code path

* Initial gather impl. Works intra-node.

* Updates to comm. phase 2. nccl comm. passed the tests.

* refactor code to introduce nccl/mpi as backends for onebit adam.

* Refactor updates to test/engine.

* Fix compile/runtime errors.

* simplify support for nccl/mpi backends.

* Add missing file

* Add compression backend in constructor. Revert later.

* modify test with some perf counting.

* Implement a true non-blocking gather for nccl side.

* Revert "Add compression backend in constructor. Revert later."

This reverts commit df8c40d3.

* improve the 1-bit adam test.

* Refactor comm. and compression backend in 1-bit adam.

* Fix the test.

* Fix runtime errors and typos in nccl backend

* fix mpi backend. modify tests.

* modify nccl perf test.

* fix mpi side errors.

* Add an mpi perf test

* Sync DSE.

* Remove old collectives file.

* Undo a typo.

* Graceful failure for torch versions that don't support nccl pt2pt.
Parent 7300f3e3
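This commit splits the communication code for 1-bit Adam into pluggable backends, selected through a new comm_backend_name argument on the optimizer ('nccl' by default, 'mpi' as the alternative). The snippet below is only an illustrative sketch of that selection logic; build_comm_backend is a hypothetical helper, and the real dispatch happens inside the optimizer constructor shown later in this commit.

import torch.distributed as dist

def build_comm_backend(comm_backend_name='nccl', cuda_aware=False):
    # Sketch of the backend dispatch introduced by this commit.
    if comm_backend_name == 'nccl':
        # The NCCL path relies on torch.distributed point-to-point ops,
        # so the process group must already be initialized.
        assert dist.is_initialized(), "initialize torch.distributed first"
        from deepspeed.runtime.comm.nccl import NcclBackend
        return NcclBackend()
    elif comm_backend_name == 'mpi':
        from deepspeed.runtime.comm.mpi import MpiBackend
        return MpiBackend(cuda_aware)
    raise ValueError("unsupported comm_backend_name: {}".format(comm_backend_name))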
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import cupy
import time
import numpy as np
from mpi4py import MPI
from deepspeed.runtime.compression.cupy import CupyBackend
class MpiBackend(object):
def __init__(self, cuda_aware):
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.cuda_aware = cuda_aware
self.compression_backend = CupyBackend()
def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(self,
rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = self.my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(self,
rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed
for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
# 2. use numpy buffers for communication
requests = []
for idx in range(world_size):
req_sign = self.my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(self,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(self,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros(
[comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()
# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
def compressed_allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
local_rank):
all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()
if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None
compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.compression_backend.torch2cupy(
compensated_buffer_m)
compensated_buffer_m = None
cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
cupy_compensated_buffer_m,
self.size)
cupy_compensated_buffer_m = None
cupy_recvbuf_sign = cupy.zeros(
[self.size,
cupy_sign_list_packed[self.rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1
gather_start = time.time()
if self.cuda_aware:
self.gather_cuda(self.rank,
self.size,
self.comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = self.gather_host(self.rank,
self.size,
self.comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()
cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
self.size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.compression_backend.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
1 / self.size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None
compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
cupy_compensated_server_m = self.compression_backend.torch2cupy(
compensated_server_m)
compensated_server_m = None
cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
cupy_compensated_server_m,
1)
cupy_recvbuf_sign_server = cupy.zeros(
[self.size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale_server = cupy.zeros([self.size,
1],
dtype=cupy_worker_scale.dtype)
# Communication Phase 2
if self.cuda_aware:
self.allgather_cuda(self.comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
else:
cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = self.allgather_host(self.comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(self.size,
-1)
cupy_recvbuf_sign_server = None
server_unpacked_sign = self.compression_backend.cupy2torch(
cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None
server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]
return buffer_m
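The compressed_allreduce above (in the MpiBackend that the tests import from deepspeed.runtime.comm.mpi) implements 1-bit sign compression with error feedback: each worker folds its running error into the buffer, sends only the signs plus one scale (the RMS magnitude), and keeps what the quantized message could not represent as the next step's error. Below is a minimal local sketch of that quantization step, with no communication, mirroring the sign/scale/error arithmetic of the method above.

import torch

def one_bit_quantize(buffer_m, worker_error):
    # Error feedback: fold the residual from the previous step into the buffer.
    buffer_m = buffer_m + worker_error
    # One scalar scale per worker: ||x|| / sqrt(n), the RMS magnitude.
    worker_scale = buffer_m.norm() / buffer_m.numel() ** 0.5
    # 1-bit payload: signs mapped to {-1, +1} (zeros map to +1, as in the code above).
    signs = torch.where(buffer_m >= 0,
                        torch.ones_like(buffer_m),
                        -torch.ones_like(buffer_m))
    # New residual: whatever scale * sign fails to capture.
    new_error = buffer_m - worker_scale * signs
    return worker_scale, signs, new_error

# e.g. scale, signs, err = one_bit_quantize(torch.randn(1024), torch.zeros(1024))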
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import torch
import torch.distributed as dist
import time
import cupy
import numpy as np
from deepspeed.runtime.compression.cupy import CupyBackend
class NcclBackend(object):
def __init__(self):
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
self.rank = dist.get_rank(group=self.world_group)
self.size = dist.get_world_size(group=self.world_group)
self.compression_backend = CupyBackend()
def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
else:
recvbuf[rank] = sendbuf
else:
req.append(dist.isend(sendbuf, group=group, dst=root))
return req
def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
if rank == root:
for idx in range(size):
if idx != rank:
dist.recv(recvbuf[idx], src=idx, group=group)
else:
recvbuf[rank] = sendbuf
else:
dist.send(sendbuf, group=group, dst=root)
def compressed_allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
local_rank):
all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()
if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None
compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.compression_backend.torch2cupy(compensated_buffer_m)
compensated_buffer_m = None
cupy_sign_list_packed = self.compression_backend.compress_by_chunk(cupy_compensated_buffer_m,
self.size)
cupy_compensated_buffer_m = None
cupy_recvbuf_sign = cupy.zeros(
[self.size,
cupy_sign_list_packed[self.rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
sign_list_packed = [None] * self.size
for idx in range(self.size):
sign_list_packed[idx] = self.compression_backend.cupy2torch(cupy_sign_list_packed[idx])
recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
# communication phase 1
gather_start = time.time()
requests = []
for idx in range(self.size):
requests += self.my_igather(self.rank,
self.size,
self.world_group,
sign_list_packed[idx],
recvbuf_sign,
root=idx)
requests += self.my_igather(self.rank,
self.size,
self.world_group,
worker_scale,
recvbuf_scale,
root=idx)
for i in range(len(requests)):
requests[i].wait()
gather_end = time.time()
cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
cupy_recvbuf_scale = self.compression_backend.torch2cupy(recvbuf_scale)
cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
self.size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.compression_backend.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(1 / self.size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None
compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
cupy_compensated_server_m = self.compression_backend.torch2cupy(compensated_server_m)
compensated_server_m = None
cupy_server_sign_packed = self.compression_backend.compress_by_chunk(cupy_compensated_server_m, 1)
cupy_recvbuf_sign_server = cupy.zeros(
[self.size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
server_sign_packed = [None] * 1
recvbuf_sign_server = [None] * self.size
for idx in range(self.size):
recvbuf_sign_server[idx] = self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx])
server_sign_packed[0] = self.compression_backend.cupy2torch(cupy_server_sign_packed[0])
server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
cupy_recvbuf_scale_server = cupy.zeros([self.size,
1],
dtype=cupy_worker_scale.dtype)
recvbuf_scale_server = [None] * self.size
for idx in range(self.size):
recvbuf_scale_server[idx] = self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx])
# Communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
dist.all_gather(recvbuf_scale_server, server_scale)
# need to convert from a tensor list to a single tensor
# dist.all_gather only provides a tensor list as the recv/output buffer
recvbuf_sign_server = torch.stack(recvbuf_sign_server)
cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(recvbuf_sign_server)
cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(self.size,
-1)
cupy_recvbuf_sign_server = None
server_unpacked_sign = self.compression_backend.cupy2torch(cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None
server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]
return buffer_m
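In Communication Phase 2 above, NcclBackend uses dist.all_gather, which fills a Python list with one tensor per rank rather than a single contiguous buffer, hence the torch.stack before converting back to cupy. A small standalone illustration of that pattern (it assumes torch.distributed has already been initialized):

import torch
import torch.distributed as dist

def allgather_stacked(local_chunk):
    # dist.all_gather needs a pre-allocated list with one tensor per rank ...
    gathered = [torch.empty_like(local_chunk) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_chunk)
    # ... which is then stacked into a single [world_size, ...] tensor,
    # matching the layout the cupy unpackbits/reshape step expects.
    return torch.stack(gathered)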
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
class CupyBackend(object):
def __init__(self):
pass
def torch2cupy(self, tensor):
return cupy.fromDlpack(to_dlpack(tensor))
def cupy2torch(self, cupy_tensor):
return from_dlpack(cupy_tensor.toDlpack())
def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
packed_sign = cupy.packbits(cupy_bool_tensor)
sign_list_packed = cupy.split(packed_sign, num_chunks)
cupy.cuda.get_current_stream().synchronize()
return sign_list_packed
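CupyBackend moves tensors between torch and cupy through DLPack (zero-copy on the same GPU) and bit-packs boolean sign tensors with cupy.packbits, which gives the 8x reduction behind the "1-bit" name. A quick hedged round-trip sketch using the class above; it assumes a CUDA device and a tensor length divisible by 8 so unpackbits reproduces the input exactly.

import torch
import cupy

backend = CupyBackend()
signs = torch.randn(1024, device='cuda') > 0                 # bool sign tensor
packed = backend.compress_by_chunk(backend.torch2cupy(signs), 1)
# packed[0] holds 1024 / 8 = 128 uint8 values
restored = backend.cupy2torch(cupy.unpackbits(packed[0])).bool()
assert bool((restored == signs).all())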
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
from mpi4py import MPI
import numpy as np
import cupy
def my_igather(rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed
for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])
numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
# 2. use numpy buffers for communication
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])
cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros([comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()
# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
@@ -664,8 +664,8 @@ class DeepSpeedEngine(Module):
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
from deepspeed.runtime.fp16.onebit.adam import Adam
optimizer = Adam(model_parameters, self, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -6,18 +6,14 @@ import torch
import importlib
import numpy as np
import time
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from deepspeed.utils.logging import logger
import torch.distributed as dist
from mpi4py import MPI
from deepspeed.runtime.custom_collectives import gather_cuda, gather_host, allgather_cuda, allgather_host
from deepspeed.utils.logging import logger
class OnebitAdam(torch.optim.Optimizer):
class Adam(torch.optim.Optimizer):
"""Implements the 1-bit Adam algorithm. Currently GPU-only.
For usage example please see, TODO DeepSpeed Tutorial
For usage example please see, https://www.deepspeed.ai/tutorials/onebit-adam/
It has been proposed in APMSqueeze (https://arxiv.org/abs/2008.11343)
Arguments:
@@ -42,6 +38,8 @@ class OnebitAdam(torch.optim.Optimizer):
second moment estimate as in the original paper. (default: False)
cuda_aware (boolean, required): Set True if the underlying MPI implementation
supports CUDA-Aware communication. (default: False)
comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
@@ -60,10 +58,12 @@ class OnebitAdam(torch.optim.Optimizer):
weight_decay=0.,
max_grad_norm=0.,
amsgrad=False,
cuda_aware=False):
cuda_aware=False,
comm_backend_name='nccl'):
if amsgrad:
raise RuntimeError('1-bit Adam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
@@ -71,161 +71,39 @@ class OnebitAdam(torch.optim.Optimizer):
weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(OnebitAdam, self).__init__(params, defaults)
from mpi4py import MPI
super(Adam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
assert (dist.is_initialized())
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.comm_time = 0.0
self.step_time = 0.0
self.ave_step = 1
self.bk_time = 0.0
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
self.deepspeed = deepspeed
self.adam_freeze_key = False
self.initialize = False
self.freeze_step = freeze_step
self.cuda_aware = cuda_aware
def torch2cupy(self, tensor):
return cupy.fromDlpack(to_dlpack(tensor))
def cupy2torch(self, cupy_tensor):
return from_dlpack(cupy_tensor.toDlpack())
def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
packed_sign = cupy.packbits(cupy_bool_tensor)
sign_list_packed = cupy.split(packed_sign, num_chunks)
cupy.cuda.get_current_stream().synchronize()
return sign_list_packed
def Compressed_Allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
rank,
world_size,
comm,
local_rank):
all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()
if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None
compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.torch2cupy(compensated_buffer_m)
compensated_buffer_m = None
cupy_sign_list_packed = self.compress_by_chunk(cupy_compensated_buffer_m,
world_size)
cupy_compensated_buffer_m = None
cupy_recvbuf_sign = cupy.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1
gather_start = time.time()
if self.cuda_aware:
gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()
cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
world_size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.cupy2torch(cupy_recvbuf_scale).mul_(1 / world_size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None
compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None
compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.torch2cupy(server_scale)
cupy_compensated_server_m = self.torch2cupy(compensated_server_m)
compensated_server_m = None
cupy_server_sign_packed = self.compress_by_chunk(cupy_compensated_server_m, 1)
cupy_recvbuf_sign_server = cupy.zeros(
[world_size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale_server = cupy.zeros([world_size,
1],
dtype=cupy_worker_scale.dtype)
# Communication Phase 2
if self.cuda_aware:
allgather_cuda(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
else:
cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = allgather_host(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)
self.comm_backend_name = comm_backend_name
cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(world_size,
-1)
cupy_recvbuf_sign_server = None
# Empty initializer. Set handle based on the comm backend as follows.
self.comm_backend_handle = None
server_unpacked_sign = self.cupy2torch(cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None
if self.comm_backend_name == 'nccl':
assert torch.__version__.startswith("1.8."), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.comm_backend_handle = NcclBackend()
server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]
elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
self.comm_backend_handle = MpiBackend(cuda_aware)
return buffer_m
self.size = self.comm_backend_handle.size
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
def step(self, closure=None, grads=None):
"""Performs a single optimization step.
@@ -337,13 +215,12 @@ class OnebitAdam(torch.optim.Optimizer):
if self.size > 1:
exp_avg.set_(
self.Compressed_Allreduce(exp_avg,
state['worker_error'],
state['server_error'],
self.rank,
self.size,
self.comm,
self.deepspeed.local_rank))
self.comm_backend_handle.compressed_allreduce(
exp_avg,
state['worker_error'],
state['server_error'],
self.deepspeed.local_rank))
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
@@ -362,7 +239,7 @@ class OnebitAdam(torch.optim.Optimizer):
self.adam_freeze_key = False
self.initialize = True
print(
f"Finished the initialization step at rant {torch.distributed.get_rank()}"
f"Finished the initialization step at rank {torch.distributed.get_rank()}"
)
return loss
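After this refactor the optimizer holds no MPI or cupy plumbing of its own; it simply calls compressed_allreduce on whichever backend comm_backend_name selects, and the engine.py hunk above constructs it from the DeepSpeed config. A hedged sketch of the corresponding config fragment, written as a Python dict; the field names follow the options documented above and the values are illustrative only.

ds_config = {
    "train_batch_size": 32,                  # illustrative
    "optimizer": {
        "type": "OneBitAdam",                # routed to deepspeed.runtime.fp16.onebit.adam.Adam
        "params": {
            "lr": 1e-3,
            "freeze_step": 23000,            # illustrative: uncompressed warmup steps before 1-bit communication
            "cuda_aware": False,
            "comm_backend_name": "nccl",     # or "mpi" for the MPI backend
        },
    },
    "fp16": {"enabled": True},
}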
@@ -4,7 +4,8 @@ import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
from deepspeed.runtime.comm.mpi import MpiBackend
comm = MPI.COMM_WORLD
size = comm.Get_size()
@@ -12,18 +13,17 @@ rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2245',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
dummy_model = [torch.nn.Parameter(torch.ones(10))]
# Set cuda_aware to True to use CUDA buffers for communication
dummy_optim = OnebitAdam(dummy_model, cuda_aware=True)
# Change cuda_aware to True to test out CUDA-Aware MPI communication
backend = MpiBackend(cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
# A simulated compression function using torch.distributed
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
@@ -52,21 +52,20 @@ if tensor_size % (8 * size) != 0:
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
@@ -74,13 +73,16 @@ diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True
# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for 1bit Adam at Rank {}'.format(rank))
if test_correctness:
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print('Successfully passed the test for MPI Backend at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.mpi import MpiBackend
# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean
timers = SynchronizedWallClockTimer()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
backend = MpiBackend(cuda_aware=False)
device = torch.device('cuda', rank % torch.cuda.device_count())
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
warmup = 10
iters = 100
local_rank = rank % torch.cuda.device_count()
# Warmup
for i in range(warmup):
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
time_list = []
for i in range(iters):
timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
timers('compressed_allreduce').stop()
time_list.append(timers('compressed_allreduce').elapsed())
timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
places = 2
convert = 1e3
float_size = 4
if rank == 0:
for i in range(iters):
lat = time_list[i]
print("latency = ", lat * convert)
minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.nccl import NcclBackend
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
backend = NcclBackend()
device = torch.device('cuda', rank % torch.cuda.device_count())
# A simulated compression function using torch.distributed
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
torch.distributed.barrier()
return a_server_compressed, worker_error, server_error
tensor_size = 100 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
test_correctness = True
# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if test_correctness:
if torch.sum(diff_server_mask) == 0:
print('Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
else:
check_mag_mask = mpi_server[diff_mask] > magnitude_threshold
if torch.sum(check_mag_mask) == 0:
print(
'Successfully passed the test for NCCL Backend at Rank {}'.format(rank))
else:
print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.comm.nccl import NcclBackend
# Configure wall clock timer
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean
timers = SynchronizedWallClockTimer()
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
#TODO: Detect the hostname we are running on automatically
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-0:2245',
world_size=size,
rank=rank)
backend = NcclBackend()
device = torch.device('cuda', rank % torch.cuda.device_count())
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
right_tensor_size = tensor_size
right_server_size = right_tensor_size // size
# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)
warmup = 10
iters = 100
local_rank = rank % torch.cuda.device_count()
# Warmup
for i in range(warmup):
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
time_list = []
for i in range(iters):
timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
timers('compressed_allreduce').stop()
time_list.append(timers('compressed_allreduce').elapsed())
timer_names = ['compressed_allreduce']
timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
places = 2
convert = 1e3
float_size = 4
if rank == 0:
for i in range(iters):
lat = time_list[i]
print("latency = ", lat * convert)
minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))