未验证 提交 1bd7a143 编写于 作者: mhhhh1's avatar mhhhh1 提交者: GitHub

[MLU] add mlu kernel for c_comm_init op (#39364)

上级 c35b4b8e
......@@ -47,6 +47,10 @@
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#endif
namespace pten {
class DenseTensor;
class SelectedRows;
......@@ -181,6 +185,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId, platform::BKCLCommunicator,
#endif
#if defined(PADDLE_WITH_CNCL)
cnclCliqueId,
#endif
int, float, Vocab>;
template <typename T>
......
......@@ -20,12 +20,15 @@ limitations under the License. */
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#endif
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
......@@ -56,18 +59,23 @@ class CCommInitOp : public framework::OperatorBase {
using UniqueId = BKCLUniqueId;
using Place = platform::XPUPlace;
using CommContext = platform::BKCLCommContext;
#elif defined(PADDLE_WITH_CNCL)
using UniqueId = cnclCliqueId;
using Place = platform::MLUPlace;
using CommContext = platform::CNCLCommContext;
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU or XPU."));
"PaddlePaddle should be compiled with GPU or XPU or MLU."));
#endif
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place) || platform::is_xpu_place(place), true,
platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu place only."));
platform::is_gpu_place(place) || platform::is_xpu_place(place) ||
platform::is_mlu_place(place),
true, platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu or mlu place only."));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input con not be empty."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册