未验证 提交 6d3f56f3 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add barrier op (#54283)

* [CustomDevice] add barrier op

* update
上级 1b33f5c9
...@@ -644,6 +644,30 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> { ...@@ -644,6 +644,30 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> {
} }
}; };
// Custom-device kernel for the `barrier` op.
//
// Compute() issues a SUM all-reduce of input "X" into output "Out" through
// the XCCL communicator looked up by the `ring_id` attribute for the current
// place, using that communicator's stream.
//
// NOTE(review): no explicit stream synchronization is issued in this method
// after the all-reduce call -- confirm whether callers depend on one.
template <typename T>
class BarrierOpCustomDeviceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* y = ctx.Output<phi::DenseTensor>("Out");
    auto place = ctx.GetPlace();

    const int64_t count = x->numel();
    // The input pointer is taken before the output is allocated; keep this
    // order. The CCL entry point takes a mutable send buffer, hence the
    // const_cast.
    void* send_ptr = const_cast<void*>(x->data());
    void* recv_ptr = ctx.device_context().Alloc<T>(y);

    const int ring_id = ctx.Attr<int>("ring_id");
    const auto device_type = place.GetDeviceType();
    auto xccl_comm =
        platform::XCCLCommContext::Instance(device_type).Get(ring_id, place);

    phi::DeviceManager::CCLAllReduce(device_type,
                                     send_ptr,
                                     recv_ptr,
                                     count,
                                     phi::ccl::ToCCLDataType(x->dtype()),
                                     phi::ccl::CCLReduceOp::SUM,
                                     xccl_comm->comm(),
                                     *(xccl_comm->stream()));
  }
};
template <typename Context> template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx, void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x, const phi::ExtendedTensor& x,
...@@ -890,6 +914,10 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { ...@@ -890,6 +914,10 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
paddle::operators::CBroadcastOpCustomDeviceKernel<double>, paddle::operators::CBroadcastOpCustomDeviceKernel<double>,
paddle::operators::CBroadcastOpCustomDeviceKernel< paddle::operators::CBroadcastOpCustomDeviceKernel<
paddle::platform::float16>) {} paddle::platform::float16>) {}
// Register the custom-device kernel for the `barrier` op; only the `int`
// instantiation of BarrierOpCustomDeviceKernel is registered here.
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
barrier,
device_type,
paddle::operators::BarrierOpCustomDeviceKernel<int>) {}
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册