Unverified commit ddffbae0 authored by Ammar Ahmad Awan, committed by GitHub

Remove duplicate clip grad function in deepspeed (#1333)

* Remove the wrong function with a duplicate name

* fix format.

* add mpu check. fix tests.
Parent b9ece257
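The fix is easier to follow with one Python detail in mind: when a module defines two functions with the same name, the later definition silently replaces the earlier one, so only one of the two `clip_grad_norm_` implementations in the same module was ever reachable by callers. A minimal, self-contained illustration of that shadowing (toy bodies, not DeepSpeed code):

def clip_grad_norm_(parameters, max_norm):
    return "first definition"

def clip_grad_norm_(parameters, max_norm):  # noqa: F811, this redefinition shadows the one above
    return "second definition"

print(clip_grad_norm_([], 0.1))  # prints "second definition"; the first is unreachable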
@@ -292,9 +292,15 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     else:
         total_norm = 0
         for p in parameters:
-            if p.model_parallel or (get_model_parallel_rank() == 0):
-                param_norm = p.grad.data.norm(norm_type)
+            if mpu is not None:
+                if (mpu.get_model_parallel_rank() == 0
+                    ) or is_model_parallel_parameter(p):
+                    param_norm = p.grad.data.norm(norm_type)
+                    total_norm += param_norm.item()**norm_type
+            else:
+                param_norm = p.grad.data.float().norm(norm_type)
                 total_norm += param_norm.item()**norm_type
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         if mpu is not None:
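The new `mpu` branch above follows the Megatron-style counting rule for tensor (model) parallelism: a parameter that is replicated on every model-parallel rank should contribute to the norm only once (here, on rank 0), while a genuinely sharded parameter contributes its local shard on every rank. A toy sketch of that rule; the values, the `is_sharded` flag, and `local_sq_norm` are illustrative stand-ins for real gradients and for `is_model_parallel_parameter()`:

def local_sq_norm(params, rank):
    # params: list of (squared grad norm, is the parameter sharded across model-parallel ranks?)
    total = 0.0
    for grad_sq, is_sharded in params:
        if rank == 0 or is_sharded:  # mirrors: rank 0 or is_model_parallel_parameter(p)
            total += grad_sq
    return total

params = [(1.0, False), (4.0, True)]  # one replicated parameter, one sharded parameter
# The later all_reduce over the model parallel group sums the per-rank totals:
# the replicated 1.0 is counted once, the sharded 4.0 once per rank's shard.
print(local_sq_norm(params, rank=0) + local_sq_norm(params, rank=1))  # 5.0 + 4.0 = 9.0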
@@ -306,9 +312,8 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     # Need to average total_norm across different GPUs due to the presence of moe params
     pg = groups.get_data_parallel_group()
     scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
-    scaled_norm_tensor = torch.tensor(scaled_norm,
-                                      device=self.fp32_groups_flat[i].device,
-                                      dtype=torch.float)
+    scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
     dist.all_reduce(scaled_norm_tensor, group=pg)
     total_norm = scaled_norm_tensor.item()
@@ -319,105 +324,6 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
     return total_norm
-def clip_grad_norm_(parameters,
-                    max_norm: float,
-                    norm_type: float = 2.0,
-                    mpu=None,
-                    ignore_expert_params: bool = True) -> float:
-    """Clips gradient norm of a list of parameters.
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters. Note that
-    the gradients are modified in place. Taken from Nvidia Megatron.
-    Additionally, we also handle MoE parameters
-    Arguments:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float or int): max norm of the gradients
-        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
-        ignore_expert_params (bool): ignore mixture of experts parameters.
-            If set to False, then we do an (expensive) all_reduce of
-            the MoE parameters. If set to True, then we ignore the MoE expert
-            parameters during norm computation
-    Returns:
-        Total norm of the parameters (viewed as a single vector).
-    """
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    shared_params, expert_params = split_params_into_shared_and_expert_params(parameters)
-    if norm_type == inf:
-        total_shared_norm = max(p.grad.data.abs().max() for p in shared_params)
-        total_expert_norm = float("-inf")
-        if not ignore_expert_params and len(expert_params) > 0:
-            local_expert_norm = max(p.grad.data.abs().max() for p in expert_params)
-            local_expert_norm_cuda = torch.cuda.FloatTensor([float(local_expert_norm)])
-            # Get max across all experts in an expert parallel group
-            dist.all_reduce(local_expert_norm_cuda,
-                            op=dist.ReduceOp.MAX,
-                            group=groups.get_expert_parallel_group())
-            total_expert_norm = local_expert_norm_cuda[0].item()
-        total_norm = max(total_shared_norm, total_expert_norm)
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        # Take max across all GPUs.
-        if mpu is not None:
-            torch.distributed.all_reduce(total_norm_cuda,
-                                         op=torch.distributed.ReduceOp.MAX,
-                                         group=mpu.get_model_parallel_group())
-        total_norm = total_norm_cuda[0].item()
-    else:
-        total_shared_norm = 0.
-        for p in shared_params:
-            if mpu is not None:
-                if (mpu.get_model_parallel_rank() == 0
-                    ) or is_model_parallel_parameter(p):
-                    param_norm = p.grad.data.float().norm(norm_type)
-                    total_shared_norm += param_norm.item()**norm_type
-            else:
-                param_norm = p.grad.data.float().norm(norm_type)
-                total_shared_norm += param_norm.item()**norm_type
-        total_expert_norm = 0.
-        if not ignore_expert_params and len(expert_params) > 0:
-            for p in expert_params:
-                if mpu is not None:
-                    # TODO(bapatra): I am not sure what the right thing to do here is
-                    if (mpu.get_model_parallel_rank() == 0
-                        ) or is_model_parallel_parameter(p):
-                        param_norm = p.grad.data.float().norm(norm_type)
-                        total_expert_norm += param_norm.item()**norm_type
-                else:
-                    param_norm = p.grad.data.float().norm(norm_type)
-                    total_expert_norm += param_norm.item()**norm_type
-            total_expert_norm_cuda = torch.cuda.FloatTensor([float(total_expert_norm)])
-            dist.all_reduce(total_expert_norm_cuda,
-                            op=dist.ReduceOp.SUM,
-                            group=groups.get_expert_parallel_group())
-            total_expert_norm = total_expert_norm_cuda[0].item()
-        total_norm = total_shared_norm + total_expert_norm
-        # Sum across all model parallel GPUs.
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        if mpu is not None:
-            torch.distributed.all_reduce(total_norm_cuda,
-                                         op=torch.distributed.ReduceOp.SUM,
-                                         group=mpu.get_model_parallel_group())
-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
-    # now rescale based on total norm
-    clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
-        for p in parameters:
-            p.grad.detach().mul_(clip_coef)
-    return total_norm
 def get_grad_norm(parameters, norm_type=2, mpu=None):
     """Get grad norm of an iterable of parameters.
......
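The comment in the averaging hunk above ("Need to average total_norm across different GPUs due to the presence of moe params") rests on a small identity: if every data-parallel rank divides its local total_norm by the data-parallel world size and the scaled values are then summed with all_reduce, every rank ends up holding the mean of the local norms. A worked check of that arithmetic in plain Python, with made-up values and no process group:

local_norms = [4.0, 2.0]  # hypothetical total_norm on rank 0 and rank 1
world_size = len(local_norms)
scaled = [n / world_size for n in local_norms]  # each rank's contribution, as in scaled_norm above
all_reduced = sum(scaled)  # what dist.all_reduce(scaled_norm_tensor, group=pg) leaves on every rank
assert all_reduced == sum(local_norms) / world_size  # 3.0, the mean of the local norms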
@@ -22,32 +22,32 @@ def test_call_to_str():
     assert c2s('hello', 1138, val=3) == 'hello(1138, val=3)'
-@pytest.mark.parametrize('ignore_expert_params', [(False), (True)])
-def test_clip_grad_norm_(ignore_expert_params: bool):
+def test_clip_grad_norm_():
     @distributed_test(world_size=[2])
-    def _test_clip_grad_norm_(ignore_expert_params: bool) -> None:
+    def _test_clip_grad_norm_() -> None:
         param1 = torch.nn.Parameter(torch.Tensor([0]))
         param1.grad = torch.Tensor([1])
         param2 = torch.nn.Parameter(torch.Tensor([0]))
         param2.grad = torch.Tensor([dist.get_rank() + 1])
-        param2.allreduce = False
+        # param2 is now MoE parameter
+        param2.allreduce = False
         parameters = [param1, param2]
-        if not ignore_expert_params:
-            groups.initialize_model_parallel(1)
-            groups.initialize_expert_parallel(2)
-        norm = ds_utils.clip_grad_norm_(parameters,
-                                        max_norm=0.1,
-                                        ignore_expert_params=ignore_expert_params)
-        if ignore_expert_params:
-            # Ignore param2.grad
-            assert norm == 1.0
-        else:
-            # Use param2.grad from both ranks
-            assert torch.isclose(torch.Tensor([norm]), torch.sqrt(torch.Tensor([6])))
-    return _test_clip_grad_norm_(ignore_expert_params)
+        groups.initialize_model_parallel(1)
+        groups.initialize_expert_parallel(2)
+        norm = ds_utils.clip_grad_norm_(parameters, max_norm=0.1)
+        norm = torch.Tensor([norm]).to(dist.get_rank())
+        world_size = dist.get_world_size()
+        gathered_norm = [torch.zeros(1).cuda() for i in range(world_size)]
+        torch.distributed.all_gather(gathered_norm, norm)
+        assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1"
+    return _test_clip_grad_norm_()
 @pytest.mark.parametrize("check_using_norm", [(False), (True)])
......
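The magic number in the removed assertion, torch.sqrt(torch.Tensor([6])), follows directly from the gradients the test sets up: the shared param1.grad is 1 on both ranks and is counted once, while the expert param2.grad is rank + 1, so the expert-parallel all_reduce sums 1**2 and 2**2 across the two ranks. A quick worked check of that arithmetic (plain Python, not part of the test):

import math

shared_sq = 1.0 ** 2             # param1.grad == 1, counted once as a shared parameter
expert_sq = 1.0 ** 2 + 2.0 ** 2  # param2.grad == rank + 1, summed by the expert-parallel all_reduce
print(math.sqrt(shared_sq + expert_sq))  # 2.449..., i.e. sqrt(6), the value the old test asserted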