未验证 提交 2133b45a 编写于 作者: B Baibaifan 提交者: GitHub

add aclcheck to c_comm_init (#33739)

上级 9b58cbf1
...@@ -69,10 +69,12 @@ class CCommInitOpAscend : public framework::OperatorBase { ...@@ -69,10 +69,12 @@ class CCommInitOpAscend : public framework::OperatorBase {
for (int32_t idx = 0; idx < size; idx++) { for (int32_t idx = 0; idx < size; idx++) {
input[idx] = 1.0; input[idx] = 1.0;
} }
aclrtMalloc(reinterpret_cast<void**>(&buff), size * sizeof(float), PADDLE_ENFORCE_NPU_SUCCESS(aclrtMalloc(reinterpret_cast<void**>(&buff),
ACL_MEM_MALLOC_HUGE_FIRST); size * sizeof(float),
aclrtMemcpy(reinterpret_cast<void*>(buff), size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
input.data(), size * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); PADDLE_ENFORCE_NPU_SUCCESS(aclrtMemcpy(
reinterpret_cast<void*>(buff), size * sizeof(float), input.data(),
size * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE));
VLOG(3) << "Build buff data successful."; VLOG(3) << "Build buff data successful.";
aclrtStream stream = nullptr; aclrtStream stream = nullptr;
...@@ -83,8 +85,8 @@ class CCommInitOpAscend : public framework::OperatorBase { ...@@ -83,8 +85,8 @@ class CCommInitOpAscend : public framework::OperatorBase {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream(); stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
} }
platform::dynload::HcclBroadcast(buff, size, HCCL_DATA_TYPE_FP32, 0, PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast(
comm->comm(), stream); buff, size, HCCL_DATA_TYPE_FP32, 0, comm->comm(), stream));
VLOG(3) << "Build connection successful."; VLOG(3) << "Build connection successful.";
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册