未验证 提交 1bd7a143 编写于 作者: mhhhh1's avatar mhhhh1 提交者: GitHub

[MLU] add mlu kernel for c_comm_init op (#39364)

上级 c35b4b8e
...@@ -47,6 +47,10 @@ ...@@ -47,6 +47,10 @@
#include "xpu/bkcl.h" #include "xpu/bkcl.h"
#endif #endif
#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#endif
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
class SelectedRows; class SelectedRows;
...@@ -181,6 +185,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -181,6 +185,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#endif #endif
#if defined(PADDLE_WITH_XPU_BKCL) #if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId, platform::BKCLCommunicator, BKCLUniqueId, platform::BKCLCommunicator,
#endif
#if defined(PADDLE_WITH_CNCL)
cnclCliqueId,
#endif #endif
int, float, Vocab>; int, float, Vocab>;
template <typename T> template <typename T>
......
...@@ -20,12 +20,15 @@ limitations under the License. */ ...@@ -20,12 +20,15 @@ limitations under the License. */
#if defined(PADDLE_WITH_XPU_BKCL) #if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h" #include "xpu/bkcl.h"
#endif #endif
#if defined(PADDLE_WITH_CNCL)
#include <cncl.h>
#endif
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#endif #endif
...@@ -56,18 +59,23 @@ class CCommInitOp : public framework::OperatorBase { ...@@ -56,18 +59,23 @@ class CCommInitOp : public framework::OperatorBase {
using UniqueId = BKCLUniqueId; using UniqueId = BKCLUniqueId;
using Place = platform::XPUPlace; using Place = platform::XPUPlace;
using CommContext = platform::BKCLCommContext; using CommContext = platform::BKCLCommContext;
#elif defined(PADDLE_WITH_CNCL)
using UniqueId = cnclCliqueId;
using Place = platform::MLUPlace;
using CommContext = platform::CNCLCommContext;
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should be compiled with GPU or XPU.")); "PaddlePaddle should be compiled with GPU or XPU or MLU."));
#endif #endif
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place) || platform::is_xpu_place(place), true, platform::is_gpu_place(place) || platform::is_xpu_place(place) ||
platform::errors::PreconditionNotMet( platform::is_mlu_place(place),
"CCommInitOp can run on gpu or xpu place only.")); true, platform::errors::PreconditionNotMet(
"CCommInitOp can run on gpu or xpu or mlu place only."));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL)
auto var = scope.FindVar(Input("X")); auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input con not be empty.")); var, platform::errors::InvalidArgument("Input con not be empty."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册