Unverified commit d4f43ad4, authored by Wen Sun, committed by GitHub

refactor: rename xccl files (#49127)

Parent 0b36655b
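This commit continues the snake_case renaming of the collective library: source files such as ProcessGroupGloo.cc, ProcessGroupMPI.cc, MPITools.cc, and Common.cc become process_group_gloo.cc, process_group_mpi.cc, mpi_tools.cc, and common.cc, and the matching CMake targets (processgroup, processgroup_gloo, processgroup_mpi, ...) become process_group, process_group_gloo, process_group_mpi, and so on, with all includes and DEPS updated to match.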
 cc_library(
-  processgroup
+  process_group
   SRCS ProcessGroup.cc
   DEPS dense_tensor)

 cc_library(
@@ -9,20 +9,20 @@ cc_library(
 cc_library(
   eager_reducer
   SRCS reducer.cc
-  DEPS eager_api processgroup process_group_stream phi_api string_helper)
+  DEPS eager_api process_group process_group_stream phi_api string_helper)
 if(WITH_DISTRIBUTE)
   cc_library(
-    processgroup_gloo
-    SRCS ProcessGroupGloo.cc
+    process_group_gloo
+    SRCS process_group_gloo.cc
     DEPS phi_api eager_api gloo_wrapper)
 endif()
 if(WITH_NCCL OR WITH_RCCL)
   cc_library(
     process_group_nccl
-    SRCS process_group_nccl.cc nccl_tools.cc Common.cc check.cc
-    DEPS processgroup
+    SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
+    DEPS process_group
          process_group_stream
          place
          enforce
@@ -34,23 +34,23 @@ endif()
 if(WITH_XPU_BKCL)
   cc_library(
-    processgroup_bkcl
-    SRCS ProcessGroupBKCL.cc BKCLTools.cc Common.cc
-    DEPS processgroup place enforce collective_helper device_context
+    process_group_bkcl
+    SRCS ProcessGroupBKCL.cc BKCLTools.cc common.cc
+    DEPS process_group place enforce collective_helper device_context
          dense_tensor)
 endif()
 if(WITH_MPI)
   cc_library(
-    processgroup_mpi
-    SRCS ProcessGroupMPI.cc MPITools.cc Common.cc
+    process_group_mpi
+    SRCS process_group_mpi.cc mpi_tools.cc common.cc
     DEPS collective_helper device_context)
 endif()
 if(WITH_CUSTOM_DEVICE)
   cc_library(
-    processgroup_custom
-    SRCS ProcessGroupCustom.cc CustomCCLTools.cc Common.cc
-    DEPS processgroup phi_backends place enforce collective_helper
+    process_group_custom
+    SRCS ProcessGroupCustom.cc CustomCCLTools.cc common.cc
+    DEPS process_group phi_backends place enforce collective_helper
          device_context)
 endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/HCCLTools.h"

#include "paddle/fluid/distributed/collective/Types.h"

namespace paddle {
namespace distributed {

HcclReduceOp ToHCCLRedType(ReduceOp reduction) {
  static const std::map<ReduceOp, HcclReduceOp> red_type = {
      {ReduceOp::MIN, HCCL_REDUCE_MIN},
      {ReduceOp::MAX, HCCL_REDUCE_MAX},
      {ReduceOp::SUM, HCCL_REDUCE_SUM},
      {ReduceOp::PRODUCT, HCCL_REDUCE_PROD},
  };
  auto it = red_type.find(reduction);
  PADDLE_ENFORCE_EQ(
      it != red_type.end(),
      true,
      platform::errors::InvalidArgument("Invalid hccl reduction. "
                                        "Must be Min | Max | Prod | Sum"));
  return it->second;
}

std::string SerializeHCCLUniqueId(const HcclRootInfo& hcclID) {
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&hcclID);
  std::ostringstream oss;
  for (size_t i = 0; i < sizeof(hcclID); ++i) {
    oss << std::hex << static_cast<int>(bytes[i]);
  }
  return oss.str();
}

}  // namespace distributed
}  // namespace paddle
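For context, SerializeHCCLUniqueId above simply hex-dumps the raw bytes of the opaque root-info struct so the id can be exchanged between ranks as a string. Below is a minimal, self-contained sketch of the same idea; FakeRootInfo is a hypothetical stand-in for the vendor-defined HcclRootInfo:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for the opaque, vendor-defined HcclRootInfo;
// only its raw bytes matter for serialization.
struct FakeRootInfo {
  uint8_t internal[8] = {0x00, 0x1f, 0xac, 0x03, 0x7e, 0x00, 0xff, 0x10};
};

// Byte-wise hex dump, mirroring SerializeHCCLUniqueId above. Note that,
// like the original, it does not zero-pad each byte, so the result is an
// identifier rather than a losslessly decodable encoding.
std::string SerializeId(const FakeRootInfo& id) {
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&id);
  std::ostringstream oss;
  for (size_t i = 0; i < sizeof(id); ++i) {
    oss << std::hex << static_cast<int>(bytes[i]);
  }
  return oss.str();
}

int main() {
  // Rank 0 would publish this string through a store (e.g. a TCP store)
  // so other ranks can look it up and join the same communicator.
  std::cout << SerializeId(FakeRootInfo{}) << std::endl;  // "01fac37e0ff10"
  return 0;
}
```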
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <error.h>

#include <string>

#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/variant.h"

namespace paddle {
namespace distributed {

class NPUEventManager {
 public:
  NPUEventManager() = default;

  ~NPUEventManager() {
    if (is_created_) {
      platform::NPUDeviceGuard guard(device_index_);
      platform::NPUEventDestroy(event_);
    }
  }

  NPUEventManager(const NPUEventManager&) = delete;
  NPUEventManager& operator=(const NPUEventManager&) = delete;

  NPUEventManager(NPUEventManager&& other) {
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
  }

  NPUEventManager& operator=(NPUEventManager&& other) {
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
    return *this;
  }

  bool IsCreated() const { return is_created_; }
  bool DeviceId() const { return device_index_; }
  aclrtEvent GetRawNPUEvent() const { return event_; }

  void Record(const paddle::platform::NPUDeviceContext& ctx) {
    auto device_index = ctx.GetPlace().device;
    if (!is_created_) {
      CreateEvent(device_index);
    }
    PADDLE_ENFORCE_EQ(device_index,
                      device_index_,
                      platform::errors::PreconditionNotMet(
                          "NPUDeviceContext's device %d does not match"
                          "Event's device %d",
                          device_index,
                          device_index_));

    platform::NPUDeviceGuard guard(device_index_);
    platform::NPUEventRecord(event_, ctx.stream());
  }

  bool Query() const {
    aclrtEventStatus status = ACL_EVENT_STATUS_COMPLETE;
    platform::NPUEventQuery(event_, &status);
    if (status == ACL_EVENT_STATUS_COMPLETE) {
      return true;
    }
    return false;
  }

  void Block(const paddle::platform::NPUDeviceContext& ctx) const {
    if (is_created_) {
      auto device_index = ctx.GetPlace().device;
      PADDLE_ENFORCE_EQ(device_index,
                        device_index_,
                        platform::errors::PreconditionNotMet(
                            "phi::GPUContext's device %d does not match"
                            "Event's device %d",
                            device_index,
                            device_index_));
      platform::NPUDeviceGuard guard(device_index_);
      platform::NPUStreamWaitEvent(ctx.stream(), event_);
    }
  }

 private:
  bool is_created_{false};
  aclrtEvent event_{};
  int8_t device_index_{0};

 private:
  void CreateEvent(int device_index) {
    device_index_ = device_index;
    platform::NPUDeviceGuard guard(device_index);
    platform::NPUEventCreate(&event_);
    is_created_ = true;
  }
};

class HCCLCommManager {
 public:
  explicit HCCLCommManager(HcclComm hcclComm) : hccl_comm_(hcclComm) {}

  HCCLCommManager() : HCCLCommManager(nullptr) {}

  ~HCCLCommManager() noexcept {
    std::unique_lock<std::mutex> lock(mutex_);
    if (hccl_comm_) {
      platform::dynload::HcclCommDestroy(hccl_comm_);
    }
  }

  static std::shared_ptr<HCCLCommManager> Create(int num_ranks,
                                                 int rank,
                                                 HcclRootInfo* comm_id,
                                                 HcclComm hccl_comm) {
    auto hccl_manager = std::make_shared<HCCLCommManager>();
    auto ret = platform::dynload::HcclCommInitRootInfo(
        num_ranks, comm_id, rank, &hccl_comm);
    using __NPU_STATUS_TYPE__ = decltype(ret);
    constexpr auto __success_type__ =
        platform::details::NPUStatusType<__NPU_STATUS_TYPE__>::kSuccess;
    if (UNLIKELY(ret != __success_type__)) {
      VLOG(0) << "Error: create hccl_id error.";
      exit(-1);
    }

    hccl_manager->hccl_id_ = comm_id;
    hccl_manager->rank_ = rank;
    hccl_manager->hccl_comm_ = hccl_comm;
    return hccl_manager;
  }

  HcclRootInfo* GetHcclId() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return hccl_id_;
  }

  HcclComm GetHcclComm() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return hccl_comm_;
  }

  HCCLCommManager(const HCCLCommManager&) = delete;
  HCCLCommManager& operator=(const HCCLCommManager&) = delete;
  HCCLCommManager& operator=(HCCLCommManager&& other) = delete;

  HCCLCommManager(HCCLCommManager&& other) {
    std::unique_lock<std::mutex> lock(other.mutex_);
    std::swap(hccl_comm_, other.hccl_comm_);
  }

 protected:
  HcclComm hccl_comm_;
  HcclRootInfo* hccl_id_;
  int rank_;
  mutable std::mutex mutex_;
};

HcclReduceOp ToHCCLRedType(ReduceOp reduction);
std::string SerializeHCCLUniqueId(const HcclRootInfo& hcclID);

}  // namespace distributed
}  // namespace paddle
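Both classes above follow the same move-only RAII convention for device and communicator handles: copying is deleted (a copied handle would be destroyed twice), and the move constructor swaps state so that exactly one owner releases the resource. A minimal sketch of that pattern, with a plain int standing in for aclrtEvent/HcclComm (illustrative names, not Paddle APIs):

```cpp
#include <iostream>
#include <utility>

class HandleOwner {
 public:
  HandleOwner() = default;
  explicit HandleOwner(int handle) : is_created_(true), handle_(handle) {}

  ~HandleOwner() {
    if (is_created_) {
      // The real classes release the device resource here, e.g.
      // platform::NPUEventDestroy(event_) or HcclCommDestroy(comm_).
      std::cout << "destroying handle " << handle_ << "\n";
    }
  }

  // Copying would lead to a double destroy, so it is forbidden.
  HandleOwner(const HandleOwner&) = delete;
  HandleOwner& operator=(const HandleOwner&) = delete;

  // Moving swaps state, leaving the source empty so its destructor
  // becomes a no-op for the transferred handle.
  HandleOwner(HandleOwner&& other) noexcept {
    std::swap(is_created_, other.is_created_);
    std::swap(handle_, other.handle_);
  }

 private:
  bool is_created_{false};
  int handle_{0};
};

int main() {
  HandleOwner a(42);
  HandleOwner b(std::move(a));  // ownership moves; handle 42 destroyed once
  return 0;
}
```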
@@ -15,7 +15,7 @@
 #include "paddle/fluid/distributed/collective/ProcessGroupBKCL.h"

 #include "paddle/fluid/distributed/collective/BKCLTools.h"
-#include "paddle/fluid/distributed/collective/Common.h"
+#include "paddle/fluid/distributed/collective/common.h"
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #include "paddle/fluid/platform/device/xpu/xpu_info.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -14,8 +14,8 @@
 #include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"

-#include "paddle/fluid/distributed/collective/Common.h"
 #include "paddle/fluid/distributed/collective/CustomCCLTools.h"
+#include "paddle/fluid/distributed/collective/common.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/distributed/collective/Common.h"
+#include "paddle/fluid/distributed/collective/common.h"

 namespace paddle {
 namespace distributed {
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/distributed/collective/MPITools.h"
-#include "paddle/fluid/distributed/collective/Common.h"
+#include "paddle/fluid/distributed/collective/mpi_tools.h"
 #include "paddle/fluid/distributed/collective/Types.h"
+#include "paddle/fluid/distributed/collective/common.h"

 namespace paddle {
 namespace distributed {
@@ -28,8 +28,8 @@
 #include <gloo/reduce.h>
 #include <gloo/scatter.h>

-#include "paddle/fluid/distributed/collective/Common.h"
-#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
+#include "paddle/fluid/distributed/collective/common.h"
+#include "paddle/fluid/distributed/collective/process_group_gloo.h"
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -400,6 +400,15 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
   }
 };

+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    bool sync_op) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return AllGather(in_wrapper, out_wrapper, true);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
@@ -120,6 +120,11 @@ class ProcessGroupGloo : public ProcessGroup {
       int64_t /*numel*/,  // for compatibility, no use now
       bool sync_op) override;

+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       phi::DenseTensor* out_tensor,
       const phi::DenseTensor& in_tensor,
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h"
+#include "paddle/fluid/distributed/collective/process_group_mpi.h"

 #include <chrono>

-#include "paddle/fluid/distributed/collective/Common.h"
+#include "paddle/fluid/distributed/collective/common.h"

 constexpr int64_t kWaitBlockTImeout = 10;

 namespace paddle {
@@ -30,7 +30,7 @@
 #include "paddle/fluid/platform/device_context.h"

 #if defined(PADDLE_WITH_MPI)
-#include "paddle/fluid/distributed/collective/MPITools.h"
+#include "paddle/fluid/distributed/collective/mpi_tools.h"
 #endif

 namespace paddle {
@@ -14,8 +14,8 @@
 #include "paddle/fluid/distributed/collective/process_group_nccl.h"

-#include "paddle/fluid/distributed/collective/Common.h"
 #include "paddle/fluid/distributed/collective/check.h"
+#include "paddle/fluid/distributed/collective/common.h"
 #include "paddle/fluid/distributed/collective/nccl_tools.h"
 #include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
@@ -155,21 +155,21 @@ if(WITH_CUSTOM_DEVICE)
 endif()

 if(WITH_PYTHON)
-  set(PYBIND_DEPS ${PYBIND_DEPS} processgroup eager_reducer)
+  set(PYBIND_DEPS ${PYBIND_DEPS} process_group eager_reducer)
   if(WITH_NCCL OR WITH_RCCL)
     set(PYBIND_DEPS ${PYBIND_DEPS} process_group_nccl)
   endif()
   if(WITH_XPU_BKCL)
-    set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_bkcl)
+    set(PYBIND_DEPS ${PYBIND_DEPS} process_group_bkcl)
   endif()
   if(WITH_GLOO)
-    set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo)
+    set(PYBIND_DEPS ${PYBIND_DEPS} process_group_gloo)
   endif()
   if(WITH_MPI)
-    set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_mpi)
+    set(PYBIND_DEPS ${PYBIND_DEPS} process_group_mpi)
   endif()
   if(WITH_CUSTOM_DEVICE)
-    set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_custom)
+    set(PYBIND_DEPS ${PYBIND_DEPS} process_group_custom)
   endif()
   if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
     set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
@@ -38,7 +38,7 @@ limitations under the License. */
 #endif

 #if defined(PADDLE_WITH_MPI)
-#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h"
+#include "paddle/fluid/distributed/collective/process_group_mpi.h"
 #endif

 #if defined(PADDLE_WITH_CUSTOM_DEVICE)
@@ -46,7 +46,7 @@ limitations under the License. */
 #endif

 #if defined(PADDLE_WITH_GLOO)
-#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
+#include "paddle/fluid/distributed/collective/process_group_gloo.h"
 #include "paddle/fluid/distributed/store/tcp_store.h"
 #endif
@@ -65,12 +65,12 @@ if(WITH_CUSTOM_DEVICE)
     DEPS phi_capi)
 endif()

-set(COMM_UTILS_DEPS processgroup)
+set(COMM_UTILS_DEPS process_group)

 if(WITH_NCCL OR WITH_RCCL)
   set(COMM_UTILS_DEPS ${PROCESS_GROUP_UTILS_DEPS} process_group_nccl)
 endif()

 if(WITH_CUSTOM_DEVICE)
-  set(COMM_UTILS_DEPS ${PROCESS_GROUP_UTILS_DEPS} processgroup_custom)
+  set(COMM_UTILS_DEPS ${PROCESS_GROUP_UTILS_DEPS} process_group_custom)
 endif()

 cc_library(
   processgroup_comm_utils
@@ -78,7 +78,7 @@ set(COMMON_KERNEL_DEPS
     gpc
     utf8proc)

-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup)
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
 if(WITH_NCCL OR WITH_RCCL)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl)
 endif()