未验证 提交 73583f86 编写于 作者: L lilong12 提交者: GitHub

add the implementation of process group for hccl (#40228)

* add pg_hccl
上级 7024ade7
......@@ -7,3 +7,6 @@ cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup)
if(WITH_NCCL)
cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
endif()
if(WITH_ASCEND_CL)
cc_library(processgroup_hccl SRCS ProcessGroupHCCL.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <error.h>
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
class NPUEventManager {
public:
NPUEventManager() = default;
~NPUEventManager() {
if (is_created_) {
platform::NPUDeviceGuard guard(device_index_);
platform::NPUEventDestroy(event_);
}
}
NPUEventManager(const NPUEventManager&) = delete;
NPUEventManager& operator=(const NPUEventManager&) = delete;
NPUEventManager(NPUEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
NPUEventManager& operator=(NPUEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
bool DeviceId() const { return device_index_; }
aclrtEvent GetRawNPUEvent() const { return event_; }
void Record(const paddle::platform::NPUDeviceContext& ctx) {
auto device_index = ctx.GetPlace().device;
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index, device_index_,
platform::errors::PreconditionNotMet(
"NPUDeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
platform::NPUDeviceGuard guard(device_index_);
platform::NPUEventRecord(event_, ctx.stream());
}
bool Query() const {
aclrtEventStatus status = ACL_EVENT_STATUS_COMPLETE;
platform::NPUEventQuery(event_, &status);
if (status == ACL_EVENT_STATUS_COMPLETE) {
return true;
}
return false;
}
void Block(const paddle::platform::NPUDeviceContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index, device_index_,
platform::errors::PreconditionNotMet(
"CUDADeviceContext's device %d does not match"
"Event's device %d",
device_index, device_index_));
platform::NPUDeviceGuard guard(device_index_);
platform::NPUStreamWaitEvent(ctx.stream(), event_);
}
}
private:
bool is_created_{false};
aclrtEvent event_{};
int8_t device_index_{0};
private:
void CreateEvent(int device_index) {
device_index_ = device_index;
platform::NPUDeviceGuard guard(device_index);
platform::NPUEventCreate(&event_);
is_created_ = true;
}
};
class HCCLCommManager {
public:
explicit HCCLCommManager(HcclComm hcclComm) : hccl_comm_(hcclComm) {}
HCCLCommManager() : HCCLCommManager(nullptr) {}
~HCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (hccl_comm_) {
platform::dynload::HcclCommDestroy(hccl_comm_);
}
}
static std::shared_ptr<HCCLCommManager> Create(int num_ranks, int rank,
HcclRootInfo* comm_id,
HcclComm hccl_comm) {
auto hccl_manager = std::make_shared<HCCLCommManager>();
auto ret = platform::dynload::HcclCommInitRootInfo(num_ranks, comm_id, rank,
&hccl_comm);
using __NPU_STATUS_TYPE__ = decltype(ret);
constexpr auto __success_type__ =
platform::details::NPUStatusType<__NPU_STATUS_TYPE__>::kSuccess;
if (UNLIKELY(ret != __success_type__)) {
VLOG(0) << "Error: create hccl_id error.";
exit(-1);
}
hccl_manager->hccl_id_ = comm_id;
hccl_manager->rank_ = rank;
hccl_manager->hccl_comm_ = hccl_comm;
return hccl_manager;
}
HcclRootInfo* GetHcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return hccl_id_;
}
HcclComm GetHcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return hccl_comm_;
}
HCCLCommManager(const HCCLCommManager&) = delete;
HCCLCommManager& operator=(const HCCLCommManager&) = delete;
HCCLCommManager& operator=(HCCLCommManager&& other) = delete;
HCCLCommManager(HCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(hccl_comm_, other.hccl_comm_);
}
protected:
HcclComm hccl_comm_;
HcclRootInfo* hccl_id_;
int rank_;
mutable std::mutex mutex_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
DECLARE_bool(hccl_blocking_wait);
// DECLARE_bool(use_stream_safe_npu_allocator);
constexpr int64_t kWaitBlockTImeout = 10;
namespace paddle {
namespace distributed {
static HcclReduceOp ToHCCLRedType(ReduceOp reduction) {
static const std::map<ReduceOp, HcclReduceOp> red_type = {
{ReduceOp::MIN, HCCL_REDUCE_MIN},
{ReduceOp::MAX, HCCL_REDUCE_MAX},
{ReduceOp::SUM, HCCL_REDUCE_SUM},
{ReduceOp::PRODUCT, HCCL_REDUCE_PROD},
};
auto it = red_type.find(reduction);
PADDLE_ENFORCE_EQ(
it != red_type.end(), true,
platform::errors::InvalidArgument("Invalid hccl reduction. "
"Must be Min | Max | Prod | Sum"));
return it->second;
}
std::string SerializeHCCLUniqueId(const HcclRootInfo& hcclID) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&hcclID);
std::ostringstream oss;
for (size_t i = 0; i < sizeof(hcclID); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
// Get the list of devices from list of tensors
std::vector<Place> GetPlaceList(const std::vector<Tensor>& tensors) {
std::vector<Place> places;
places.reserve(tensors.size());
for (auto& tensor : tensors) {
places.push_back(tensor.inner_place());
}
return places;
}
// Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places) {
std::string placeList;
for (auto& place : places) {
std::stringstream tmp;
tmp << place;
if (placeList.empty()) {
placeList += tmp.str();
} else {
placeList += "," + tmp.str();
}
}
return placeList;
}
// bool CheckTensorsInNPUPlace(const std::vector<Tensor>& tensors) {
// return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
// return t.place() == platform::DeviceType::NPU;
// });
// }
void SyncDefaultStream(
const std::vector<Place>& places,
std::vector<NPUEventManager>& hcclEvents, // NOLINT
std::vector<std::unique_ptr<NPUDeviceContext>>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<platform::NPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
hcclEvents[i].Record(*dev_ctx[i]);
hcclEvents[i].Block(*default_ctx);
}
}
std::shared_ptr<ProcessGroupHCCL::HCCLTask> ProcessGroupHCCL::CreateTask(
std::vector<Place> places, int rank, CommType comm_type,
const std::vector<Tensor>& inputs) {
return std::make_shared<ProcessGroupHCCL::HCCLTask>(places, rank, comm_type,
inputs);
}
ProcessGroupHCCL::HCCLTask::HCCLTask(const std::vector<Place>& places, int rank,
CommType CommType,
const std::vector<Tensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
hcclComms_.resize(places.size());
}
ProcessGroupHCCL::HCCLTask::~HCCLTask() {}
void ProcessGroupHCCL::HCCLTask::SetOutputs(
std::vector<Tensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<Tensor>>(outputs);
}
void ProcessGroupHCCL::HCCLTask::SynchronizeStreams() {
for (size_t i = 0; i < places_.size(); ++i) {
auto* default_ctx = static_cast<platform::NPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places_[i]));
platform::NPUStreamWaitEvent(default_ctx->stream(),
control_events_[i].GetRawNPUEvent());
}
}
bool ProcessGroupHCCL::HCCLTask::IsCompleted() {
for (size_t i = 0; i < places_.size(); ++i) {
if (!control_events_[i].Query()) {
return false;
}
}
return true;
}
// TODO(sandyhouse): Add timeout for wait, now timeout unused
bool ProcessGroupHCCL::HCCLTask::Wait(std::chrono::milliseconds timeout) {
SynchronizeStreams();
if (FLAGS_hccl_blocking_wait) {
// NOTE(sandyhouse): It will block host for sync
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
return true;
}
// Same as Wait
void ProcessGroupHCCL::HCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupHCCL::ProcessGroupHCCL(const std::shared_ptr<Store>& store,
int rank, int size)
: ProcessGroup(rank, size), store_(store) {}
void ProcessGroupHCCL::BroadcastUniqueHCCLID(
std::vector<HcclRootInfo>& hccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < hccl_ids.size(); i++) {
auto key = "ProcessGroupHCCL/hccl_ids/" + std::to_string(i);
auto hccl_id = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&hccl_ids[i]),
reinterpret_cast<uint8_t*>(&hccl_ids[i]) + sizeof(HcclRootInfo));
store_->set(key, hccl_id);
}
} else {
for (size_t i = 0; i < hccl_ids.size(); i++) {
auto key = "ProcessGroupHCCL/hccl_ids/" + std::to_string(i);
auto ret = store_->get(key);
std::memcpy(&hccl_ids[i], ret.data(), ret.size());
}
}
}
// create HCCLManager cache for places_key
void ProcessGroupHCCL::CreateHCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), false,
platform::errors::PreconditionNotMet(
"Not able to create/get the HCCL Communicator since "
"the NPU place are not known"));
std::vector<std::shared_ptr<HCCLCommManager>> hccl_comms;
hccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<HcclRootInfo> hccl_ids;
hccl_ids.resize(1);
auto& hccl_id = hccl_ids.front();
if (rank_ == 0) {
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclGetRootInfo(&hccl_id));
}
BroadcastUniqueHCCLID(hccl_ids);
VLOG(3) << "init hccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
<< ", hccl uniqueid: " << SerializeHCCLUniqueId(hccl_id);
std::vector<std::unique_ptr<NPUDeviceContext>> dev_ctx;
dev_ctx.resize(places.size());
std::unique_ptr<HcclComm[]> comms(new HcclComm[places.size()]);
for (size_t i = 0; i < places.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
hccl_comms[i] = HCCLCommManager::Create(GetSize(), GetRank(), &hccl_id,
comms.get() + i);
dev_ctx[i].reset(new NPUDeviceContext(places[i]));
}
std::vector<NPUEventManager> events;
events.resize(places.size());
// These caches will be useful to process sync/wait/communicate
places_to_events_.emplace(places_key, std::move(events));
places_to_hcclcomm_.emplace(places_key, std::move(hccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
std::vector<Tensor>& inputs, std::vector<Tensor>& outputs, Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_hcclcomm_.find(key) == places_to_hcclcomm_.end()) {
CreateHCCLManagerCache(key, places);
}
}
auto& hccl_comms = places_to_hcclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < inputs.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for (size_t i = 0; i < inputs.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
const auto& hccl_stream = places_to_ctx_[key][i]->stream();
fn(inputs[i], outputs[i], hccl_comms[i]->GetHcclComm(), hccl_stream);
}
for (size_t i = 0; i < inputs.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::PointToPoint(
std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_hcclcomm_.find(key) == places_to_hcclcomm_.end()) {
CreateHCCLManagerCache(key, places);
}
}
auto& hccl_comms = places_to_hcclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, tensors);
// construct uninitialize guard for device
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < tensors.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for (size_t i = 0; i < tensors.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
const auto& hccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], hccl_comms[i]->GetHcclComm(), hccl_stream, dst_rank);
}
for (size_t i = 0; i < tensors.size(); ++i) {
platform::NPUDeviceGuard guard(places[i].GetDeviceId());
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// NPUPlace."));
return Collective(
tensors, tensors,
[&](const Tensor& input, Tensor& output, HcclComm comm,
const aclrtStream& stream) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::HcclAllReduce(
input_tensor->data(), output_tensor->data(), input_tensor->numel(),
platform::ToHCCLDataType(input.type()),
ToHCCLRedType(opts.reduce_op), comm, stream);
},
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Broadcast(
std::vector<Tensor>& tensors, const BroadcastOptions& opts) {
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// CudaPlace."));
return Collective(
tensors, tensors,
[&](Tensor& input, Tensor& output, HcclComm comm,
const aclrtStream& stream) {
const auto root = opts.source_rank * tensors.size() + opts.source_root;
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::HcclBroadcast(
input_tensor->data(), input_tensor->numel(),
platform::ToHCCLDataType(input.type()), root, comm, stream);
},
CommType::BROADCAST);
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/distributed/collective/HCCLTools.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
constexpr const char* HCCL_BACKEND_NAME = "HCCL";
namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using NPUStream = platform::stream::NPUStream;
using NPUDeviceContext = paddle::platform::NPUDeviceContext;
class ProcessGroupHCCL : public ProcessGroup {
public:
class HCCLTask : public ProcessGroup::Task,
public std::enable_shared_from_this<HCCLTask> {
public:
HCCLTask(const std::vector<Place>& places, int rank, CommType CommType,
const std::vector<Tensor>& inputs);
bool IsCompleted();
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<Tensor>& outputs); // NOLINT
virtual ~HCCLTask();
std::vector<NPUEventManager> control_events_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<HCCLCommManager>> hcclComms_;
std::shared_ptr<std::vector<Tensor>> outputs_;
private:
};
ProcessGroupHCCL(const std::shared_ptr<Store>& store, int rank, int size);
const std::string GetBackendName() const override {
return std::string(HCCL_BACKEND_NAME);
}
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(std::vector<Tensor>& tensors,
int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in, std::vector<Tensor>& out) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors, const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
std::vector<Tensor>& out_tensors,
const ScatterOptions&) override;
protected:
virtual std::shared_ptr<ProcessGroupHCCL::HCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
const std::vector<Tensor>& inputs);
std::shared_ptr<Store> store_;
std::shared_ptr<HCCLCommManager> hccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<HCCLCommManager>>>
places_to_hcclcomm_;
std::unordered_map<std::string, std::vector<NPUEventManager>>
places_to_events_;
std::unordered_map<std::string,
std::vector<std::unique_ptr<NPUDeviceContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastHCCLId(std::vector<HcclRootInfo>& hccl_ids, int root, // NOLINT
int server_fd);
void BroadcastUniqueHCCLID(std::vector<HcclRootInfo>& hccl_ids); // NOLINT
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<Tensor>& inputs, // NOLINT
std::vector<Tensor>& outputs, // NOLINT
Fn fn, CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<Tensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);
void CreateHCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
};
} // namespace distributed
} // namespace paddle
......@@ -53,6 +53,23 @@ inline HcclDataType ToHCCLDataType(framework::proto::VarType::Type type) {
}
}
inline HcclDataType ToHCCLDataType(experimental::DataType type) {
if (type == experimental::DataType::FLOAT32) {
return HCCL_DATA_TYPE_FP32;
} else if (type == experimental::DataType::FLOAT16) {
return HCCL_DATA_TYPE_FP16;
} else if (type == experimental::DataType::INT64) {
return HCCL_DATA_TYPE_INT64;
} else if (type == experimental::DataType::INT32) {
return HCCL_DATA_TYPE_INT32;
} else if (type == experimental::DataType::INT8) {
return HCCL_DATA_TYPE_INT8;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in hccl is not supported."));
}
}
// NOTE(minqiyang): according to the ncclGroupEnd documentations:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which will
......
......@@ -88,6 +88,9 @@ if(NOT ON_INFER)
if (WITH_GLOO)
set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo)
endif()
if(WITH_ASCEND)
set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_hccl)
endif()
set(PYBIND_SRCS ${PYBIND_SRCS} distributed_py.cc)
endif()
......
......@@ -35,6 +35,10 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
#endif
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
......@@ -201,6 +205,14 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>());
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
py::class_<distributed::ProcessGroupHCCL,
std::shared_ptr<distributed::ProcessGroupHCCL>>(
*m, "ProcessGroupHCCL", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int>(),
py::call_guard<py::gil_scoped_release>());
#endif
py::class_<distributed::ProcessGroup::Task,
std::shared_ptr<distributed::ProcessGroup::Task>>(*m, "task")
.def("is_completed", &distributed::ProcessGroup::Task::IsCompleted)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import random
import numpy as np
import os
import shutil
import paddle
from paddle.fluid import core
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
def init_process_group(strategy=None):
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
pg_group = core.ProcessGroupHCCL(store, rank, nranks)
return pg_group
class TestProcessGroupFp32(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()
def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)
def test_create_process_group_nccl(self):
with _test_eager_guard():
paddle.set_device('npu:%d' %
paddle.distributed.ParallelEnv().dev_id)
pg = init_process_group()
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
sum_result = tensor_x + tensor_y
if pg.rank() == 0:
task = pg.allreduce(tensor_x)
task.wait()
assert np.array_equal(tensor_x, sum_result)
else:
task = pg.allreduce(tensor_y)
task.wait()
assert np.array_equal(tensor_y, sum_result)
print("test allreduce sum api ok")
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
max_result = paddle.maximum(tensor_x, tensor_y)
if pg.rank() == 0:
task = pg.allreduce(tensor_x, core.ReduceOp.MAX)
task.wait()
assert np.array_equal(tensor_x, max_result)
else:
task = pg.allreduce(tensor_y, core.ReduceOp.MAX)
task.wait()
assert np.array_equal(tensor_y, max_result)
print("test allreduce max api ok")
# test broadcast
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
# rank 1
y = np.random.random(self.shape).astype(self.dtype)
tensor_y = paddle.to_tensor(y)
broadcast_result = paddle.assign(tensor_x)
if pg.rank() == 0:
task = pg.broadcast(tensor_x, 0)
task.synchronize()
paddle.device.cuda.synchronize()
assert task.is_completed()
assert np.array_equal(broadcast_result, tensor_x)
else:
task = pg.broadcast(tensor_y, 0)
task.synchronize()
paddle.device.cuda.synchronize()
assert task.is_completed()
assert np.array_equal(broadcast_result, tensor_y)
print("test broadcast api ok")
# test barrier
# rank 0
if pg.rank() == 0:
task = pg.barrier()
task.wait()
# rank 1
else:
task = pg.barrier()
task.wait()
print("test barrier api ok\n")
exit(0)
# test allgather
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
out_shape = list(self.shape)
out_shape[0] *= 2
out = np.random.random(out_shape).astype(self.dtype)
tensor_out = paddle.to_tensor(out)
if pg.rank() == 0:
task = pg.all_gather(tensor_x, tensor_out)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
task = pg.all_gather(tensor_y, tensor_out)
task.wait()
paddle.device.cuda.synchronize()
out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
[out_shape[0]])
assert np.array_equal(tensor_x, out_1)
assert np.array_equal(tensor_y, out_2)
print("test allgather api ok\n")
# test alltoall
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
out1 = np.random.random(self.shape).astype(self.dtype)
out2 = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
tensor_out1 = paddle.to_tensor(out1)
tensor_out2 = paddle.to_tensor(out2)
raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
[self.shape[0]])
raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
[self.shape[0] // 2])
if pg.rank() == 0:
task = pg.alltoall(tensor_x, tensor_out1)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
task = pg.alltoall(tensor_y, tensor_out2)
task.wait()
paddle.device.cuda.synchronize()
out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
[self.shape[0]])
out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
if pg.rank() == 0:
assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
else:
assert np.array_equal(out2_1, raw_tensor_x_2)
print("test alltoall api ok\n")
# test Reduce
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
sum_result = tensor_x + tensor_y
if pg.rank() == 0:
task = pg.reduce(tensor_x, 0)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
task = pg.reduce(tensor_y, 0)
task.wait()
paddle.device.cuda.synchronize()
if pg.rank() == 0:
assert np.array_equal(tensor_x, sum_result)
print("test reduce sum api ok\n")
# test Scatter
# rank 0
in_shape = list(self.shape)
in_shape[0] *= 2
x = np.random.random(in_shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
if pg.rank() == 0:
task = pg.scatter(tensor_x, tensor_y, 0)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
task = pg.scatter(tensor_x, tensor_y, 0)
task.wait()
paddle.device.cuda.synchronize()
out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
out2 = paddle.slice(tensor_x, [0], [self.shape[0]],
[self.shape[0] * 2])
if pg.rank() == 0:
assert np.array_equal(tensor_y, out1)
else:
assert np.array_equal(tensor_y, out2)
print("test scatter api ok\n")
class TestProcessGroupFp16(TestProcessGroupFp32):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()
def config(self):
self.dtype = "float16"
self.shape = (4, 20, 20)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestProcessGroup(TestMultipleGpus):
def test_process_group_nccl(self):
self.run_mnist_2gpu('process_group_hccl.py')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册