From 1b9e6790f632df0739a256d2bfb890b769a6041b Mon Sep 17 00:00:00 2001 From: maxhuiy <1508399706@qq.com> Date: Mon, 14 Feb 2022 10:30:49 +0800 Subject: [PATCH] [MLU] add mlu kernel for c_broadcast op (#39470) --- .../collective/c_broadcast_op_mlu.cc | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 paddle/fluid/operators/collective/c_broadcast_op_mlu.cc diff --git a/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc b/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc new file mode 100644 index 0000000000..c4061254ed --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op_mlu.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOPMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CNCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int numel = x->numel(); + cnclDataType_t dtype = platform::ToCNCLDataType(x->type()); + + int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto comm = platform::CNCLCommContext::Instance().Get(rid, place); + + mluStream stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + int root = ctx.Attr("root"); + if (root == comm->rank()) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnclBcast(reinterpret_cast(const_cast(x->data())), + numel, dtype, root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " + << x->numel(); + + if (out != x) { + framework::TensorCopy( + *static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + PADDLE_ENFORCE_MLU_SUCCESS(cnclBcast(out->mutable_data(place), numel, + dtype, root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved " + << framework::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with MLU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(c_broadcast, ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel, + ops::CBroadcastOPMLUKernel); -- GitLab