diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc index 75953dc0b4289f0dfbbb2187b038d2fa4affb073..ff39196b92b3a046eb49507be94155cb0bdeca8c 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc @@ -260,6 +260,57 @@ std::shared_ptr ProcessGroupBKCL::AllGather( use_calc_stream); } +std::shared_ptr ProcessGroupBKCL::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + phi::DenseTensor output_t(*output); + const auto& place = input.place(); + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + switch (input.dtype()) { + case phi::DataType::FLOAT32: + calc_ctx->template Alloc(&output_t); + break; + case phi::DataType::FLOAT16: + calc_ctx->template Alloc(&output_t); + break; + case phi::DataType::INT32: + calc_ctx->template Alloc(&output_t); + break; + default: + VLOG(0) << "Error: type " << input.dtype() << " not supported for " + << GetBackendName(); + break; + } + int ret = + bkcl_all_reduce(comm, + input.data(), + output_t.data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + if (rank_ == opts.root_rank) { + *output = output_t; + } + return ret; + }, + CommType::ALLREDUCE, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupBKCL::Barrier( const BarrierOptions& opts) { PADDLE_ENFORCE_GE(opts.device_id, diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.h b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h index b4a47e83fdd8a07c063403bccd4d1f34add00bf7..79d97609d9274e86349c1e06e8166112ecaca071 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h @@ -107,6 +107,12 @@ class ProcessGroupBKCL : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; diff --git a/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py b/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py index 2317e38cb28d0d7a16d8e92ab844645adf0f5f64..a106c630f3634c776852d56c15317647ae710dd3 100644 --- a/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py +++ b/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py @@ -168,6 +168,27 @@ class TestProcessGroupFp32(unittest.TestCase): "rank {}: test allgather api2 ok\n".format(pg.rank()) ) + # test Reduce + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + y = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + tensor_y = paddle.to_tensor(y) + sum_result = tensor_x + tensor_y + if pg.rank() == 0: + task = dist.reduce(tensor_x, 0, sync_op=True) + paddle.device.xpu.synchronize() + # rank 1 + else: + task = dist.reduce(tensor_y, 0, sync_op=False) + task.wait() + paddle.device.xpu.synchronize() + if pg.rank() == 0: + assert np.array_equal(tensor_x, sum_result) + sys.stdout.write( + "rank {}: test reduce sum api ok\n".format(pg.rank()) + ) + class TestProcessGroupFp16(TestProcessGroupFp32): def setUp(self):