Unverified  Commit 52c5a936  authored by Jeff Rasley, committed by GitHub

add allreduce test (#7)

* add allreduce test

* comment out set rank to cuda for now

* switched back to gloo
Parent b61a2217
@@ -11,7 +11,7 @@ import pytest
 DEEPSPEED_UNIT_WORKER_TIMEOUT = 5
-def distributed_test(world_size=2):
+def distributed_test(world_size=2, backend='gloo'):
     """A decorator for executing a function (e.g., a unit test) in a distributed manner.
     This decorator manages the spawning and joining of processes, initialization of
     torch.distributed, and catching of errors.
@@ -33,14 +33,14 @@ def distributed_test(world_size=2):
         """Initialize torch.distributed and execute the user function. """
         os.environ['MASTER_ADDR'] = '127.0.0.1'
         os.environ['MASTER_PORT'] = '29500'
-        dist.init_process_group(backend='nccl',
+        dist.init_process_group(backend=backend,
                                 init_method='env://',
                                 rank=local_rank,
                                 world_size=num_procs)
         # XXX temporarily disabled due to CUDA runtime error?
         #if torch.cuda.is_available():
         #    torch.cuda.set_device(local_rank)
         run_func(*func_args, **func_kwargs)
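Note on the backend change: the decorator now defaults to the gloo backend (CPU tensors), since initializing NCCL here hit a CUDA runtime error and the torch.cuda.set_device call is commented out above. A test that wants GPU collectives could opt back in through the new backend argument. The following is only an illustrative sketch, not part of this commit; the test name is hypothetical and it assumes a machine with at least world_size CUDA devices:

import torch
import torch.distributed as dist
from common import distributed_test

@distributed_test(world_size=2, backend='nccl')
def test_dist_allreduce_nccl():
    # Hypothetical GPU variant. Place each rank's tensor on its own device
    # explicitly, since torch.cuda.set_device(local_rank) is currently
    # disabled in the decorator.
    device = torch.device('cuda', dist.get_rank())
    x = torch.ones(1, 3, device=device) * (dist.get_rank() + 1)
    dist.all_reduce(x)
    assert torch.all(x == torch.ones(1, 3, device=device) * 3)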
+import torch
 import torch.distributed as dist
 from common import distributed_test
@@ -26,3 +27,11 @@ def test_dist_args(number, color):
     """Ensure that we can parse args to distributed_test decorated functions. """
     _test_dist_args_helper(number, color=color)
+@distributed_test(world_size=2)
+def test_dist_allreduce():
+    x = torch.ones(1, 3) * (dist.get_rank() + 1)
+    result = torch.ones(1, 3) * 3
+    dist.all_reduce(x)
+    assert torch.all(x == result)
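The expected value of 3 follows from the default SUM reduction: with world_size=2, rank 0 contributes ones and rank 1 contributes twos, so every element becomes 1 + 2 = 3. A variant that derives the expectation from the world size instead of hard-coding it could look like the sketch below (illustrative only, not part of this commit; the test name is hypothetical):

import torch
import torch.distributed as dist
from common import distributed_test

@distributed_test(world_size=4)
def test_dist_allreduce_sum():
    # Each rank i contributes (i + 1), so the all-reduced value is
    # 1 + 2 + ... + n = n * (n + 1) / 2 in every element.
    n = dist.get_world_size()
    x = torch.ones(1, 3) * (dist.get_rank() + 1)
    dist.all_reduce(x)
    assert torch.all(x == torch.ones(1, 3) * (n * (n + 1) / 2))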