Unverified commit 5f995d3f, authored by J james, committed by GitHub

processgroup bkcl support reduce (#48232)

Note: this is a temporary solution and should be replaced once a reduce kernel
is natively supported on KL2
Parent f254d0a0
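In effect, the patch emulates reduce with the all_reduce primitive that BKCL does provide: every rank all-reduces into a scratch tensor, and only the root rank copies the result into its output. A minimal standalone sketch of that pattern (the helper name and callback signature below are illustrative, not Paddle APIs):

#include <functional>
#include <vector>

// Hypothetical illustration of the workaround, not Paddle code:
// reduce(in, root) behaves like all_reduce(in) where only `root`
// keeps the reduced values.
template <typename T>
void ReduceViaAllReduce(
    const std::vector<T>& in,
    std::vector<T>* out,
    int rank,
    int root,
    const std::function<void(const std::vector<T>&, std::vector<T>*)>&
        all_reduce) {
  std::vector<T> tmp(in.size());
  all_reduce(in, &tmp);  // every rank computes the full reduction
  if (rank == root) {
    *out = tmp;  // only the root rank observes the result
  }
  // non-root ranks simply discard the temporary buffer
}

The cost is that every rank pays for a full all_reduce rather than a cheaper rooted reduction, which is why the commit message flags this as temporary.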
@@ -260,6 +260,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
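        // Workaround: BKCL has no native reduce kernel on KL2, so run an
        // all_reduce into a temporary tensor and copy the result to the
        // output only on the root rank.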
phi::DenseTensor output_t(*output);
const auto& place = input.place();
auto* calc_ctx = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
switch (input.dtype()) {
case phi::DataType::FLOAT32:
calc_ctx->template Alloc<float>(&output_t);
break;
case phi::DataType::FLOAT16:
calc_ctx->template Alloc<float16>(&output_t);
break;
case phi::DataType::INT32:
calc_ctx->template Alloc<int>(&output_t);
break;
default:
VLOG(0) << "Error: type " << input.dtype() << " not supported for "
<< GetBackendName();
break;
}
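        // Every rank computes the full reduction; the result lands in the
        // temporary tensor.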
int ret =
bkcl_all_reduce(comm,
input.data(),
output_t.data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
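        // Only the root rank keeps the reduced result; other ranks leave
        // their output untouched.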
if (rank_ == opts.root_rank) {
*output = output_t;
}
return ret;
},
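      // Tagged ALLREDUCE because that is the primitive actually issued.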
CommType::ALLREDUCE,
sync_op,
use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id,
......
@@ -107,6 +107,12 @@ class ProcessGroupBKCL : public ProcessGroupStream {
                                                bool sync_op,
                                                bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;
......
@@ -168,6 +168,27 @@ class TestProcessGroupFp32(unittest.TestCase):
"rank {}: test allgather api2 ok\n".format(pg.rank())
)
# test Reduce
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
sum_result = tensor_x + tensor_y
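        # This check assumes every rank generated identical x and y (e.g.
        # via a shared np.random seed in setUp), so rank 0's local sum
        # equals the cross-rank reduction.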
        if pg.rank() == 0:
            task = dist.reduce(tensor_x, 0, sync_op=True)
            paddle.device.xpu.synchronize()
        # rank 1
        else:
            task = dist.reduce(tensor_y, 0, sync_op=False)
            task.wait()
            paddle.device.xpu.synchronize()
        if pg.rank() == 0:
            assert np.array_equal(tensor_x, sum_result)
        sys.stdout.write(
            "rank {}: test reduce sum api ok\n".format(pg.rank())
        )

class TestProcessGroupFp16(TestProcessGroupFp32):
    def setUp(self):
......