Unverified commit 52c5a936 authored by Jeff Rasley, committed by GitHub

add allreduce test (#7)

* add allreduce test

* comment out set rank to cuda for now

* switched back to gloo
Parent b61a2217
@@ -11,7 +11,7 @@ import pytest
 DEEPSPEED_UNIT_WORKER_TIMEOUT = 5
-def distributed_test(world_size=2):
+def distributed_test(world_size=2, backend='gloo'):
     """A decorator for executing a function (e.g., a unit test) in a distributed manner.
     This decorator manages the spawning and joining of processes, initialization of
     torch.distributed, and catching of errors.
@@ -33,14 +33,14 @@ def distributed_test(world_size=2):
         """Initialize torch.distributed and execute the user function. """
         os.environ['MASTER_ADDR'] = '127.0.0.1'
         os.environ['MASTER_PORT'] = '29500'
-        dist.init_process_group(backend='nccl',
+        dist.init_process_group(backend=backend,
                                 init_method='env://',
                                 rank=local_rank,
                                 world_size=num_procs)
         # XXX temporarily disabled due to CUDA runtime error?
         #if torch.cuda.is_available():
         #    torch.cuda.set_device(local_rank)
         run_func(*func_args, **func_kwargs)
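
For context, here is a minimal sketch of how a decorator like distributed_test can spawn one process per rank and initialize torch.distributed with a configurable backend, in the spirit of the hunk above. It is an illustration only, not the repository's actual helper: the names _dist_worker and distributed_test_sketch are made up, and the argument forwarding (func_args/func_kwargs), error handling, and worker timeout of the real decorator are omitted.

import os
from functools import wraps

import torch.distributed as dist
import torch.multiprocessing as mp


def _dist_worker(local_rank, world_size, backend, run_func):
    # Each spawned process initializes its own torch.distributed rank.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend=backend,
                            init_method='env://',
                            rank=local_rank,
                            world_size=world_size)
    try:
        run_func()
    finally:
        dist.destroy_process_group()


def distributed_test_sketch(world_size=2, backend='gloo'):
    # Decorator factory: the wrapped function runs once per rank across `world_size` processes.
    def decorator(run_func):
        @wraps(run_func)
        def wrapper():
            # mp.spawn pickles its arguments, so run_func should be a module-level function.
            mp.spawn(_dist_worker,
                     args=(world_size, backend, run_func),
                     nprocs=world_size,
                     join=True)
        return wrapper
    return decorator
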
......
import torch
import torch.distributed as dist
from common import distributed_test
@@ -26,3 +27,11 @@ def test_dist_args(number, color):
     """Ensure that we can parse args to distributed_test decorated functions. """
     _test_dist_args_helper(number, color=color)
 
+
+@distributed_test(world_size=2)
+def test_dist_allreduce():
+    x = torch.ones(1, 3) * (dist.get_rank() + 1)
+    result = torch.ones(1, 3) * 3
+
+    dist.all_reduce(x)
+    assert torch.all(x == result)
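
The arithmetic behind the new test: with world_size=2, rank 0 contributes ones * 1 and rank 1 contributes ones * 2, and dist.all_reduce defaults to a SUM, so every rank ends up holding ones * 3. Below is a standalone sketch of the same check, assuming the gloo backend, two processes, and a free port 29501; the helper name _allreduce_demo is illustrative and not part of the repository.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _allreduce_demo(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group(backend='gloo', rank=rank, world_size=world_size)

    x = torch.ones(1, 3) * (rank + 1)   # rank 0 -> ones, rank 1 -> 2 * ones
    dist.all_reduce(x)                  # default op is SUM, so x == 3 * ones on every rank
    assert torch.all(x == torch.ones(1, 3) * 3)

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_allreduce_demo, args=(2,), nprocs=2, join=True)
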