Unverified commit 82c498d9 authored by Ma, Guokai, committed by GitHub

Fix deadlock when SHM-based allreduce spins too fast (#4048)

* Fix deadlock when allreduce spins too fast

* Change state to enum to increase readability

---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 7f90ef4b
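For context, the handshake this patch repairs can be reduced to a two-rank simulation. The sketch below is illustrative only and is not part of the DeepSpeed sources; the names slot, spin_until, spin_until_not, and allreduce_step are invented for the example, and std::atomic stands in for the shared-memory state slots. It shows the root cause: rank 0 resets its slot to coll_begin and may immediately overwrite it with coll_allreduce_naive__copy_in_done for the next allreduce, so a rank that waits for equality with coll_begin can miss the window and spin forever, while waiting until the slot is not coll_allreduce_naive__reduce_done always terminates.

#include <atomic>
#include <cstdio>
#include <thread>

enum coll_state {
    coll_begin = 0,
    coll_allreduce_naive__copy_in_done,
    coll_allreduce_naive__reduce_done,
    coll_allreduce_naive__copy_out_done,
};

// One state slot per rank, standing in for allreduce_workspace[i].state.
std::atomic<int> slot[2];

void spin_until(std::atomic<int>& s, int v)
{
    while (s.load() != v) {}
}
void spin_until_not(std::atomic<int>& s, int v)
{
    while (s.load() == v) {}
}

void allreduce_step(int rank)
{
    slot[rank] = coll_allreduce_naive__copy_in_done;  // "copy in" finished
    if (rank == 0) {
        spin_until(slot[1], coll_allreduce_naive__copy_in_done);
        slot[0] = coll_allreduce_naive__reduce_done;  // reduction finished
        spin_until(slot[1], coll_allreduce_naive__copy_out_done);
        // Reset; the next call may overwrite this with copy_in_done before
        // rank 1 ever observes coll_begin, which is the root of the deadlock.
        slot[0] = coll_begin;
    } else {
        spin_until(slot[0], coll_allreduce_naive__reduce_done);
        slot[1] = coll_allreduce_naive__copy_out_done;  // "copy out" finished
        // The fix: rank 0's slot is now coll_begin, or already copy_in_done of
        // the next call, but never reduce_done, so this wait always returns.
        spin_until_not(slot[0], coll_allreduce_naive__reduce_done);
        slot[1] = coll_begin;
    }
}

int main()
{
    std::thread r1([] {
        for (int i = 0; i < 100000; i++) allreduce_step(1);
    });
    for (int i = 0; i < 100000; i++) allreduce_step(0);
    r1.join();
    std::puts("completed without deadlock");
}

With the equality wait in place of spin_until_not, the same loop hangs as soon as rank 0 races through its reset, which is exactly the failure this commit fixes.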
@@ -17,6 +17,15 @@
 #include <iostream>
 #include <oneapi/ccl.hpp>
 
+// states for collectives
+enum coll_state {
+    coll_begin = 0,
+    // coll states for naive allreduce
+    coll_allreduce_naive__copy_in_done,   // this state is for rank != 0
+    coll_allreduce_naive__reduce_done,    // this state is for rank == 0
+    coll_allreduce_naive__copy_out_done,  // this state is for rank != 0
+};
+
 // SHM building blocks
 struct SharedData {
     const char* name;
@@ -63,19 +72,27 @@ void shared_close(SharedData* data)
 #define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
 SharedData allreduce_buffer;
 struct allreduce_workspace {
-    int state;
+    enum coll_state state;
     char buffer[MAX_BUF_SIZE];
 };
 struct allreduce_workspace* workspace;
 
-void wait_buffer_state_until(int index, int state)
+void wait_buffer_state_until(int index, enum coll_state state)
 {
-    volatile int* state_ptr = &(workspace[index].state);
+    volatile enum coll_state* state_ptr = &(workspace[index].state);
     while (*state_ptr != state)
         ;
 }
 
+void wait_buffer_state_until_not(int index, enum coll_state state)
+{
+    volatile enum coll_state* state_ptr = &(workspace[index].state);
+    while (*state_ptr == state)
+        ;
+}
+
 __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
 inline __m512 cvt_bf16_to_fp32(const __m256i src)
 {
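These helpers spin on a volatile pointer; ordering of the data copies around the state transitions comes from the explicit std::atomic_thread_fence(std::memory_order_release) calls in inference_all_reduce below. For comparison only, and not what the patch does, the same two waits could be phrased against a std::atomic state with acquire loads (assuming the coll_state enum from the hunk above; the names wait_state_until and wait_state_until_not are hypothetical):

#include <atomic>

inline void wait_state_until(const std::atomic<coll_state>& s, coll_state v)
{
    // acquire load pairs with a release store on the writer side
    while (s.load(std::memory_order_acquire) != v) {}
}

inline void wait_state_until_not(const std::atomic<coll_state>& s, coll_state v)
{
    while (s.load(std::memory_order_acquire) == v) {}
}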
@@ -313,7 +330,7 @@ void initialize(int size, int rank, torch::Tensor& kvs_data)
                       workspace,
                       size * sizeof(struct allreduce_workspace));
         workspace = (struct allreduce_workspace*)allreduce_buffer.bytes;
-        for (int i = 0; i < size; i++) { workspace[i].state = 0; }
+        for (int i = 0; i < size; i++) { workspace[i].state = coll_begin; }
     }
     CCLCHECK(ccl::barrier(_get_comm_from_group()).wait());
     if (rank != 0) {
@@ -506,33 +523,38 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
     memcpy(workspace[world_rank].buffer, data_ptr, data_size);
     std::atomic_thread_fence(std::memory_order_release);
-    workspace[world_rank].state = 1;
+    workspace[world_rank].state = coll_allreduce_naive__copy_in_done;
 
     if (world_rank == 0) {
         // compute allreduce result on rank 0
         for (int i = 1; i < world_size; i++) {
             // wait until the other rank finishes copying in its buffer
-            wait_buffer_state_until(i, 1);
+            wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done);
         }
         reduce_all_buffers(workspace, numel, data.scalar_type(), world_size);
         std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = 2;
+        workspace[world_rank].state = coll_allreduce_naive__reduce_done;
         memcpy(data_ptr, workspace[0].buffer, data_size);
     }
     if (world_rank != 0) {
-        wait_buffer_state_until(0, 2);
+        wait_buffer_state_until(0, coll_allreduce_naive__reduce_done);
         memcpy(data_ptr, workspace[0].buffer, data_size);
         std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = 2;
+        workspace[world_rank].state = coll_allreduce_naive__copy_out_done;
     }
     if (world_rank == 0) {
-        for (int i = 1; i < world_size; i++) { wait_buffer_state_until(i, 2); }
+        for (int i = 1; i < world_size; i++) {
+            wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done);
+        }
         std::atomic_thread_fence(std::memory_order_release);
-        workspace[world_rank].state = 0;
+        workspace[world_rank].state = coll_begin;
     }
     if (world_rank != 0) {
-        wait_buffer_state_until(0, 0);
-        workspace[world_rank].state = 0;
+        // If rank 0 spins too fast, it could already be in copy_in_done of the
+        // next allreduce, and waiting for coll_begin may deadlock. What is
+        // certain is that once rank 0 finishes, its state won't be reduce_done.
+        wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done);
+        workspace[world_rank].state = coll_begin;
     }
 }
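The case analysis behind the new wait: by the time a non-zero rank has published coll_allreduce_naive__copy_out_done, rank 0's slot can only hold coll_allreduce_naive__reduce_done (still collecting copy-out acknowledgements), coll_begin (reset complete), or coll_allreduce_naive__copy_in_done (the next allreduce has begun). Rank 0 never publishes copy_out_done, and it cannot reach the next reduce_done before this rank posts its own next copy_in_done, so spinning while the slot equals reduce_done is guaranteed to terminate. The helper below is a hypothetical illustration of that invariant, not part of this commit:

// Hypothetical invariant check (illustrative only): states rank 0's slot may
// hold once a non-zero rank has published copy_out_done.
inline bool rank0_state_reachable_after_copy_out(enum coll_state s)
{
    switch (s) {
        case coll_allreduce_naive__reduce_done:    // still gathering copy-out acks
        case coll_begin:                           // reset for this allreduce finished
        case coll_allreduce_naive__copy_in_done:   // next allreduce already started
            return true;
        case coll_allreduce_naive__copy_out_done:  // never published by rank 0
            return false;
    }
    return false;
}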