未验证 提交 8a12f459 编写于 作者: L lilong12 提交者: GitHub

[Cherry-pick] fix the bug for nccl barrier and alltoall (#42042)

* fix_nccl_barrier (#41970)

* be compatible with the old version of alltoall (#42007)
Co-authored-by: NBaibaifan <39549453+Baibaifan@users.noreply.github.com>
上级 f5a937eb
...@@ -353,21 +353,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( ...@@ -353,21 +353,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
std::vector<phi::GPUPlace> places; // Only support single card single process
std::vector<phi::GPUPlace> places = {place_};
if (!opts.place_ids.empty()) {
for (auto place_id : opts.place_ids) {
places.emplace_back(place_id);
}
} else if (!used_place_ids_.empty()) {
for (auto place_id : used_place_ids_) {
places.emplace_back(place_id);
}
} else {
auto numGPUs = GetSize();
int place_id = static_cast<int>(rank_ % numGPUs);
places.emplace_back(place_id);
}
std::vector<phi::DenseTensor> barrierTensors; std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size()); barrierTensors.reserve(places.size());
...@@ -375,7 +362,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( ...@@ -375,7 +362,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
platform::CUDADeviceGuard gpuGuard; platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) { for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId()); gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::GPUPlace()); auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
barrierTensors.push_back( barrierTensors.push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl())); *std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
} }
......
...@@ -860,9 +860,12 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): ...@@ -860,9 +860,12 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
if in_dygraph_mode(): if in_dygraph_mode():
group = _get_default_group() if group is None else group group = _get_default_group() if group is None else group
if len(tensor_list) == 0:
tensor_shape = list(tensor.shape) tensor_shape = list(tensor.shape)
tensor_shape[0] *= group.nranks tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype) out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather(tensor, out) task = group.process_group.all_gather(tensor, out)
task.wait() task.wait()
tensor_list.clear() tensor_list.clear()
...@@ -1783,6 +1786,11 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): ...@@ -1783,6 +1786,11 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
temp = paddle.concat(in_tensor_list, axis=0) temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list) nranks = len(in_tensor_list)
if in_dygraph_mode(): if in_dygraph_mode():
if len(out_tensor_list) == 0:
tensor_shape = list(in_tensor_list[0].shape)
tensor_shape[0] *= nranks
out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
else:
out = paddle.concat(out_tensor_list, axis=0) out = paddle.concat(out_tensor_list, axis=0)
task = group.process_group.alltoall(temp, out) task = group.process_group.alltoall(temp, out)
task.wait() task.wait()
......
...@@ -185,6 +185,24 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -185,6 +185,24 @@ class TestProcessGroupFp32(unittest.TestCase):
assert np.array_equal(tensor_y, out_2) assert np.array_equal(tensor_y, out_2)
print("test allgather api ok\n") print("test allgather api ok\n")
if pg.rank() == 0:
task = pg.all_gather(tensor_x, tensor_out)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
tensor_out_list = []
task = dist.all_gather(
tensor_out_list, tensor_y, use_calc_stream=False)
paddle.device.cuda.synchronize()
tensor_out = paddle.concat(tensor_out_list)
out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
[out_shape[0]])
assert np.array_equal(tensor_x, out_1)
assert np.array_equal(tensor_y, out_2)
print("test allgather api2 ok\n")
# test alltoall # test alltoall
# rank 0 # rank 0
x = np.random.random(self.shape).astype(self.dtype) x = np.random.random(self.shape).astype(self.dtype)
...@@ -219,6 +237,38 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -219,6 +237,38 @@ class TestProcessGroupFp32(unittest.TestCase):
assert np.array_equal(out2_1, raw_tensor_x_2) assert np.array_equal(out2_1, raw_tensor_x_2)
print("test alltoall api ok\n") print("test alltoall api ok\n")
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
out1 = np.random.random(self.shape).astype(self.dtype)
out2 = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
tensor_out1 = paddle.to_tensor(out1)
tensor_out2 = paddle.to_tensor(out2)
raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
[self.shape[0]])
raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
[self.shape[0] // 2])
if pg.rank() == 0:
task = pg.alltoall(tensor_x, tensor_out1)
task.wait()
# rank 1
else:
in_1, in_2 = paddle.split(tensor_y, 2)
out_1, out_2 = paddle.split(tensor_out2, 2)
out_tensor_list = []
task = dist.alltoall([in_1, in_2], out_tensor_list)
paddle.device.cuda.synchronize()
tensor_out2 = paddle.concat(out_tensor_list)
out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
[self.shape[0]])
out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
if pg.rank() == 0:
assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
else:
assert np.array_equal(out2_1, raw_tensor_x_2)
print("test alltoall api2 ok\n")
# test Reduce # test Reduce
# rank 0 # rank 0
x = np.random.random(self.shape).astype(self.dtype) x = np.random.random(self.shape).astype(self.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册