Unverified commit 6786c012, authored by J jameszhang, committed by GitHub

[kunlun] support reduce_scatter (#50792)

* [kunlun] support reduce_scatter

* uncomment unittest

* update xccl to 1.0.10
Parent 2eeaaa7d
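For context on what this change enables: reduce-scatter first reduces (here, sums) the tensors contributed by every rank element-wise and then scatters the result, so each rank keeps only its own contiguous chunk, whose element count is the input's divided by the number of ranks. A minimal host-side sketch of that semantics with NumPy (two ranks; names and shapes are illustrative, not part of this change):

import numpy as np

# Inputs contributed by rank 0 and rank 1; each is twice as large as
# the per-rank output, mirroring the unit test added in this commit.
rank0_in = np.random.random((4, 3)).astype("float32")
rank1_in = np.random.random((4, 3)).astype("float32")

reduced = rank0_in + rank1_in   # element-wise sum across ranks
rank0_out = reduced[:2]         # chunk kept by rank 0
rank1_out = reduced[2:]         # chunk kept by rank 1

assert rank0_out.shape == (2, 3) and rank1_out.shape == (2, 3)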
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
set(XPU_RT_LIB_NAME "libxpurt.so")
set(XPU_BASE_DATE "20230220")
-set(XPU_XCCL_BASE_VERSION "1.0.9")
+set(XPU_XCCL_BASE_VERSION "1.0.10")
if(NOT DEFINED XPU_BASE_URL)
  set(XPU_BASE_URL_WITHOUT_DATE
@@ -367,6 +367,34 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          BKCLContext_t comm,
          const XPUStream& stream) {
        // Each rank contributes the full input; BKCL reduces element-wise
        // across ranks and leaves this rank's contiguous chunk in output,
        // so the receive count passed to BKCL is output->numel().
        return bkcl_reduce_scatter(
            comm,
            input.data(),
            output->data(),
            output->numel(),
            platform::ToBKCLDataType(
                framework::TransToProtoVarType(input.type())),
            ToBKCLRedType(opts.reduce_op),
            stream);
      },
      CommType::REDUCE_SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
    const BarrierOptions& opts) {
  PADDLE_ENFORCE_GE(opts.device_id,
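On the Python side, this backend path is reached through paddle.distributed.reduce_scatter, as the unit test added below does. A minimal per-rank sketch, assuming two XPU ranks started by the usual Paddle distributed launcher (tensor names and shapes are illustrative):

import paddle
import paddle.distributed as dist

# Assumes this process is one of two ranks running on kunlun/XPU devices.
dist.init_parallel_env()

# The input's leading dimension is world_size times the output's;
# each rank ends up with its own chunk of the summed input.
local_in = paddle.rand([4, 3], dtype="float32")
local_out = paddle.zeros([2, 3], dtype="float32")

task = dist.reduce_scatter(local_out, local_in, sync_op=True)
task.wait()
# local_out now holds rows [2 * rank : 2 * (rank + 1)] of the
# element-wise sum of all ranks' local_in.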
@@ -115,6 +115,13 @@ class ProcessGroupBKCL : public ProcessGroupWithStream {
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                           int src_rank,
                                           int64_t offset,
@@ -175,6 +175,34 @@ class TestProcessGroupFp32(unittest.TestCase):
assert np.array_equal(tensor_y, old_tensor_y)
sys.stdout.write("rank {}: test reduce sum api ok\n".format(pg.rank()))

# test reduce_scatter
# Each rank contributes an input with twice as many rows as the output;
# the summed result is split row-wise: rank 0 keeps the first half,
# rank 1 the second.
in_shape = list(self.shape)
in_shape[0] *= 2
x = np.random.random(in_shape).astype(self.dtype)
y = np.random.random(in_shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
need_result = tensor_x + tensor_y
need_result0 = paddle.slice(need_result, [0], [0], [self.shape[0]])
need_result1 = paddle.slice(
    need_result, [0], [self.shape[0]], [in_shape[0]]
)
out = np.random.random(self.shape).astype(self.dtype)
tensor_out = paddle.to_tensor(out)
if pg.rank() == 0:
    task = dist.reduce_scatter(tensor_out, tensor_x, sync_op=True)
else:
    task = dist.reduce_scatter(tensor_out, tensor_y, sync_op=False)
task.wait()
paddle.device.xpu.synchronize()
if pg.rank() == 0:
    assert np.array_equal(need_result0, tensor_out)
else:
    assert np.array_equal(need_result1, tensor_out)
sys.stdout.write(
    "rank {}: test reduce_scatter sum api ok\n".format(pg.rank())
)
# test send async api
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
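Since the C++ path converts opts.reduce_op through ToBKCLRedType, reductions other than SUM should follow the same route. A hedged sketch of checking MAX in the same two-rank setup (the op keyword mirrors other Paddle collectives and is an assumption here, as are the variable names):

import numpy as np
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# Deterministic per-rank inputs so the expected result is easy to form.
x = np.arange(12, dtype="float32").reshape(4, 3)
y = -x
tensor_in = paddle.to_tensor(x if rank == 0 else y)
tensor_out = paddle.zeros([2, 3], dtype="float32")

# Element-wise MAX across ranks, then keep this rank's row chunk.
# The 'op=' keyword is assumed to be supported as in other collectives.
task = dist.reduce_scatter(tensor_out, tensor_in, op=dist.ReduceOp.MAX, sync_op=True)
task.wait()

expected = np.maximum(x, y)[2 * rank : 2 * (rank + 1)]
assert np.array_equal(expected, tensor_out.numpy())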