diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index d5ae2f84b4ed1deda9e10d966ef4ec8e4bb15356..729e840e89d3e9b437e1654074687bb64d2df450 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -644,6 +644,30 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel { } }; +template +class BarrierOpCustomDeviceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + auto place = ctx.GetPlace(); + int64_t numel = in->numel(); + const void* sendbuff = in->data(); + void* recvbuff = ctx.device_context().Alloc(out); + int rid = ctx.Attr("ring_id"); + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + phi::DeviceManager::CCLAllReduce(place.GetDeviceType(), + const_cast(sendbuff), + recvbuff, + numel, + phi::ccl::ToCCLDataType(in->dtype()), + phi::ccl::CCLReduceOp::SUM, + comm->comm(), + *(comm->stream())); + } +}; + template void FeedDenseTensorKernel(const Context& dev_ctx, const phi::ExtendedTensor& x, @@ -890,6 +914,10 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { paddle::operators::CBroadcastOpCustomDeviceKernel, paddle::operators::CBroadcastOpCustomDeviceKernel< paddle::platform::float16>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + barrier, + device_type, + paddle::operators::BarrierOpCustomDeviceKernel) {} #endif }