diff --git a/paddle/fluid/distributed/collective/BKCLTools.cc b/paddle/fluid/distributed/collective/BKCLTools.cc new file mode 100644 index 0000000000000000000000000000000000000000..05372a3a758fc18cfdb8014b2ac2cea37054512b --- /dev/null +++ b/paddle/fluid/distributed/collective/BKCLTools.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/collective/BKCLTools.h" + +#include "paddle/fluid/distributed/collective/Types.h" + +namespace paddle { +namespace distributed { + +BKCLOp ToBKCLRedType(ReduceOp reduction) { + static const std::map red_type = { + {ReduceOp::SUM, BKCL_ADD}, + }; + auto it = red_type.find(reduction); + PADDLE_ENFORCE_EQ(it != red_type.end(), + true, + platform::errors::InvalidArgument( + "Invalid bkcl reduction. Must be BKCL_ADD")); + return it->second; +} + +std::string SerializeBKCLUniqueId(const BKCLUniqueId& bkclID) { + const uint8_t* bytes = reinterpret_cast(&bkclID); + std::ostringstream oss; + for (auto i = 0; i < BKCL_UNIQUE_ID_BYTES; ++i) { + oss << std::hex << static_cast(bytes[i]); + } + return oss.str(); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/BKCLTools.h b/paddle/fluid/distributed/collective/BKCLTools.h new file mode 100644 index 0000000000000000000000000000000000000000..e08bb61438c88fa9b3afae08607414bbc9f806d1 --- /dev/null +++ b/paddle/fluid/distributed/collective/BKCLTools.h @@ -0,0 +1,118 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/distributed/collective/Types.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" + +namespace paddle { +namespace distributed { +using XPUContext = phi::XPUContext; + +#define BKCLCHECK(cmd) \ + do { \ + BKCLResult_t r = cmd; \ + if (r != BKCL_SUCCESS) { \ + printf("Failed, BKCL error %s:%d '%d'\n", __FILE__, __LINE__, r); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +class XPUEventManager { + public: + XPUEventManager() {} + + ~XPUEventManager() { + if (is_created_) { + platform::XPUDeviceGuard guard(device_index_); + xpu_event_destroy(event_); + } + } + + XPUEventManager(const XPUEventManager&) = delete; + XPUEventManager& operator=(const XPUEventManager&) = delete; + + XPUEventManager(XPUEventManager&& other) { + std::swap(is_created_, other.is_created_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } + + XPUEventManager& operator=(XPUEventManager&& other) { + std::swap(is_created_, other.is_created_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + return *this; + } + + bool IsCreated() const { return is_created_; } + bool DeviceId() const { return device_index_; } + xpuEventHandle GetRawXpuEvent() const { return event_; } + + void Record(const XPUContext& ctx) { + auto device_index = ctx.GetPlace().device; + if (!is_created_) { + CreateEvent(device_index); + } + PADDLE_ENFORCE_EQ(device_index, + device_index_, + platform::errors::PreconditionNotMet( + "XPUContext's device %d does not match" + "Event's device %d", + device_index, + device_index_)); + + platform::XPUDeviceGuard guard(device_index_); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event_, ctx.stream())); + } + + void Block(const XPUContext& ctx) const { + if (is_created_) { + auto device_index = ctx.GetPlace().device; + PADDLE_ENFORCE_EQ(device_index, + device_index_, + platform::errors::PreconditionNotMet( + "XPUContext's device %d does not match" + "Event's device %d", + device_index, + device_index_)); + platform::XPUDeviceGuard guard(device_index_); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(ctx.stream(), event_)); + } + } + + private: + bool is_created_{false}; + xpuEventHandle event_{}; + int8_t device_index_{0}; + + private: + void CreateEvent(int device_index) { + device_index_ = device_index; + platform::XPUDeviceGuard guard(device_index); + + PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_create(&event_)); + + is_created_ = true; + } +}; + +BKCLOp ToBKCLRedType(ReduceOp reduction); +std::string SerializeBKCLUniqueId(const BKCLUniqueId& bkclId); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index e4d7b55d13cae04e5b77358c80ad495a93e9a89e..b57808d32a58517deb79daf57d9f9a436442b209 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -44,6 +44,14 @@ if(WITH_NCCL OR WITH_RCCL) endif() endif() +if(WITH_XPU_BKCL) + cc_library( + processgroup_bkcl + SRCS ProcessGroupBKCL.cc BKCLTools.cc Common.cc + DEPS processgroup place enforce collective_helper device_context + dense_tensor) +endif() + if(WITH_MPI) cc_library( processgroup_mpi diff --git a/paddle/fluid/distributed/collective/Common.cc b/paddle/fluid/distributed/collective/Common.cc index d968c99e479fb252be5badf150a4be8efbefe0c9..d5cac8ec687ad8ba46e65f68ecdb0f99b73d57a6 100644 --- a/paddle/fluid/distributed/collective/Common.cc +++ b/paddle/fluid/distributed/collective/Common.cc @@ -58,5 +58,12 @@ bool CheckTensorsInCustomPlace(const std::vector& tensors, }); } +bool CheckTensorsInXPUPlace(const std::vector& tensors) { + return std::all_of( + tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) { + return platform::is_xpu_place(t.place()); + }); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/Common.h b/paddle/fluid/distributed/collective/Common.h index 38a3100b6ebaa9a3033e0f955079400318010fc1..0cb1b6f0397a442c638ff15378a2841a67512801 100644 --- a/paddle/fluid/distributed/collective/Common.h +++ b/paddle/fluid/distributed/collective/Common.h @@ -33,5 +33,7 @@ bool CheckTensorsInCudaPlace(const std::vector& tensors); bool CheckTensorsInCustomPlace(const std::vector& tensors, const std::string& dev_type); +bool CheckTensorsInXPUPlace(const std::vector& tensors); + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc new file mode 100644 index 0000000000000000000000000000000000000000..40f2172b374ca2e882e2ea6028ff7cf4d29ef5ce --- /dev/null +++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc @@ -0,0 +1,523 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/collective/ProcessGroupBKCL.h" + +#include "paddle/fluid/distributed/collective/BKCLTools.h" +#include "paddle/fluid/distributed/collective/Common.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace distributed { +using XPUDeviceContext = paddle::platform::XPUDeviceContext; + +ProcessGroupBKCL::BKCLTask::BKCLTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream) + : TaskStream(rank, comm_type, sync_op, use_calc_stream), place_(place) { + comm_event_ = std::make_shared(); +} + +ProcessGroupBKCL::BKCLTask::~BKCLTask() {} + +bool ProcessGroupBKCL::BKCLTask::IsCompleted() { + LOG_FIRST_N(WARNING, 1) << "XPU do not support event query now."; + return true; +} + +// TODO(sheniang03): Add timeout for wait, now timeout unused +bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) { + // Warning here when use calc stream but also invoke waiting explicitly. + if (UseCalcStream()) { + VLOG(3) << "Warning: The communication is on calc stream, wait here is " + "useless."; + return true; + } + + const auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place_)); + comm_event_->Block(*calc_ctx); + + if (barrier_) { + // If we use the work to do barrier, we should block cpu + platform::XPUDeviceGuard guard(place_.GetDeviceId()); + xpu_wait(); + } + return true; +} + +// Same as Wait +void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); } + +ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr& store, + int rank, + int size, + const platform::Place& place, + int gid) + : ProcessGroupStream(rank, size, place, gid), store_(store) { + platform::SetXPUDeviceId(place_.device); +} + +void ProcessGroupBKCL::GroupStart() { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); +} + +void ProcessGroupBKCL::GroupEnd() { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end()); +} + +std::shared_ptr ProcessGroupBKCL::CreateTask( + const Place& place, + int rank, + CommType comm_type, + bool is_sync, + bool use_calc_stream) { + return std::make_shared( + place, rank, comm_type, is_sync, use_calc_stream); +} + +void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) { + auto key = "ProcessGroupBKCL/bkcl_ids/" + std::to_string(gid_) + "/0"; + if (rank_ == 0) { + auto id = std::vector( + reinterpret_cast(bkcl_id), + reinterpret_cast(bkcl_id) + BKCL_UNIQUE_ID_BYTES); + store_->set(key, id); + } else { + const auto& ret = store_->get(key); + std::memcpy(bkcl_id, ret.data(), ret.size()); + } +} + +void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, + const std::string& place_key) { + BKCLUniqueId bkcl_id; + if (rank_ == 0) { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id)); + } + BroadcastUniqueBKCLID(&bkcl_id); + + VLOG(3) << "init bkcl rank: " << rank_ << ", nranks: " << size_ + << ", place: " << place_key + << ", bkcl uniqueid: " << SerializeBKCLUniqueId(bkcl_id); + + calc_event_ = std::make_shared(); + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + // must use XPUDeviceContext here to make sure XPUContext::Init() is called + auto comm_ctx = std::make_unique(place); + BKCLContext_t bkcl_comm; + BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id)); + comm_ctx->SetBkclContext(bkcl_comm); + + place_to_calc_ctx_[place_key] = calc_ctx; + place_to_comm_ctx_[place_key] = std::move(comm_ctx); +} + +void ProcessGroupBKCL::SyncCalcStream(const Place& place) { + const std::string& key = GetKeyFromPlace(place); + const auto* calc_ctx = place_to_calc_ctx_[key]; + const auto* comm_ctx = place_to_comm_ctx_[key].get(); + calc_event_->Record(*calc_ctx); + calc_event_->Block(*comm_ctx); +} + +template +std::shared_ptr ProcessGroupBKCL::Collective( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + Fn fn, + CommType op_type, + bool sync_op, + bool use_calc_stream) { + const auto& place = in_tensor.place(); + const auto& key = GetKeyFromPlace(place); + + if (!calc_event_) { + CreateBKCLEnvCache(place, key); + } + + if (!use_calc_stream) { + SyncCalcStream(place); + } + + auto task = CreateTask(place, rank_, op_type, sync_op, use_calc_stream); + + const auto* calc_ctx = place_to_calc_ctx_[key]; + const auto& comm_ctx = place_to_comm_ctx_[key]; + auto bkcl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); + fn(out_tensor, in_tensor, comm_ctx->bkcl_context(), bkcl_stream); + + if (!use_calc_stream) { + task->comm_event_->Record(*comm_ctx.get()); + } + + return task; +} + +std::shared_ptr ProcessGroupBKCL::AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_reduce( + comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + }, + CommType::ALLREDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupBKCL::Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + int root = opts.source_rank + opts.source_root; + return bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + }, + CommType::BROADCAST, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupBKCL::AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_gather( + comm, + input.data(), + input.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + }, + CommType::ALLGATHER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupBKCL::Barrier( + const BarrierOptions& opts) { + auto allocator = std::unique_ptr( + new paddle::experimental::DefaultAllocator(place_)); + phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); + phi::DenseTensor barrier_tensor{allocator.get(), meta}; + + auto task = AllReduce(&barrier_tensor, + barrier_tensor, + {}, + /*sync_op*/ true, + /*use_calc_stream*/ false); + auto bkcl_task = dynamic_cast(task.get()); + bkcl_task->barrier_ = true; + return task; +} + +const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext( + const Place& place) const { + return GetDeviceContext(place, /*use_calc_stream*/ false); +} + +const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext( + const Place& place, bool use_calc_stream) const { + const std::string& key = GetKeyFromPlace(place); + if (use_calc_stream) { + const auto& iter = place_to_calc_ctx_.find(key); + return *iter->second; + } else { + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE(iter, + place_to_comm_ctx_.end(), + platform::errors::InvalidArgument( + "Cannot find device context in process group.")); + return *iter->second; + } +} + +// below are old apis +std::shared_ptr ProcessGroupBKCL::AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& opts) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_reduce( + comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + }, + CommType::ALLREDUCE, + /*sync_op*/ true, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupBKCL::AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& opts, + bool sync_op) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_reduce( + comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + ToBKCLRedType(opts.reduce_op), + stream); + }, + CommType::ALLREDUCE, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupBKCL::Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + const auto root = + opts.source_rank * in_tensors.size() + opts.source_root; + return bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + }, + CommType::BROADCAST, + /*sync_op*/ true, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupBKCL::Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts, + bool sync_op) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + const auto root = + opts.source_rank * in_tensors.size() + opts.source_root; + return bkcl_broadcast(comm, + input.data(), + output->data(), + input.numel(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + root, + stream); + }, + CommType::BROADCAST, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupBKCL::AllGather( + std::vector& in_tensors, + std::vector& out_tensors) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInXPUPlace(out_tensors), + true, + platform::errors::InvalidArgument("All outputs should be in XPUPlace.")); + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_gather( + comm, + input.data(), + input.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + }, + CommType::ALLGATHER, + /*sync_op*/ true, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupBKCL::AllGather( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op) { + PADDLE_ENFORCE_EQ( + in_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + out_tensors.size(), + 1, + platform::errors::InvalidArgument( + "BKCL only support single tensor collective communication.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInXPUPlace(out_tensors), + true, + platform::errors::InvalidArgument("All outputs should be in XPUPlace.")); + return Collective( + &out_tensors[0], + in_tensors[0], + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + BKCLContext_t comm, + const XPUStream& stream) { + return bkcl_all_gather( + comm, + input.data(), + input.numel(), + output->data(), + platform::ToBKCLDataType( + framework::TransToProtoVarType(input.type())), + stream); + }, + CommType::ALLGATHER, + sync_op, + /*use_calc_stream*/ false); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.h b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h new file mode 100644 index 0000000000000000000000000000000000000000..0041d903de78af7ea02c3efa89e5b8c506cd81ff --- /dev/null +++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h @@ -0,0 +1,179 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/collective/ProcessGroupStream.h" +#include "paddle/fluid/distributed/store/store.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/device_context.h" + +#if defined(PADDLE_WITH_XPU) +#include "paddle/fluid/distributed/collective/BKCLTools.h" +#endif + +constexpr const char* BKCL_BACKEND_NAME = "BKCL"; + +namespace paddle { +namespace distributed { + +using Place = paddle::platform::Place; + +// BKCL funcs use separate communication stream by default +class ProcessGroupBKCL : public ProcessGroupStream { + public: + class BKCLTask final : public ProcessGroupStream::TaskStream, + public std::enable_shared_from_this { + public: + BKCLTask(const Place& place, + int rank, + CommType CommType, + bool sync_op, + bool use_calc_stream); + virtual ~BKCLTask(); + + // TODO(zhangxiaoci): XPU do not support event query for now + bool IsCompleted() override; + bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; + void Synchronize() override; + + void SynchronizeStreams(); + + public: + bool barrier_{false}; + std::shared_ptr comm_event_; // event on comm stream + + private: + Place place_; + }; + + public: + ProcessGroupBKCL(const std::shared_ptr& store, + int rank, + int size, + const platform::Place& place, + int gid); + + std::string GetBackendName() const override { + return std::string(BKCL_BACKEND_NAME); + } + + const phi::DeviceContext& GetDeviceContext(const Place& place) const override; + + const phi::DeviceContext& GetDeviceContext( + const Place& place, bool use_calc_stream) const override; + + std::shared_ptr AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) override; + + static void GroupStart(); + + static void GroupEnd(); + + BKCLContext_t BKCLComm(const Place& place) const; + + // below are old apis + std::shared_ptr AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& = AllreduceOptions()) override; + + std::shared_ptr AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& options, + bool sync_op) override; + + std::shared_ptr Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& = BroadcastOptions()) override; + + std::shared_ptr Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions&, + bool sync_op) override; + + std::shared_ptr AllGather( + std::vector& in_tensors, + std::vector& out_tensors) override; + + std::shared_ptr AllGather( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op) override; + + private: + std::shared_ptr CreateTask(const Place& place, + int rank, + CommType op_type, + bool sync_op, + bool use_calc_stream); + + void BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id); // NOLINT + + void CreateBKCLEnvCache(const Place& place, const std::string& place_key); + + template + std::shared_ptr Collective( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + Fn fn, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + void SyncCalcStream(const Place& place); + + private: + std::shared_ptr store_; + std::mutex mutex_; + std::shared_ptr calc_event_; // event on calc stream + std::unordered_map place_to_calc_ctx_; + std::unordered_map> + place_to_comm_ctx_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index f04585ce1710f45201c3fadb87d321f72ecf11b5..f8850660640c3f229da0ee7a282b211bdc6d5bfe 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -299,6 +299,57 @@ static void SplitTensorsWithType(const DeviceContext &context, } } +#ifdef PADDLE_WITH_XPU_BKCL +// context is used to select the stream for concat +template <> +void ConcatTensorsWithType( + const platform::XPUDeviceContext &context, + const std::vector &dense_tensors_, + Tensor *p_dense_contents, + phi::DataType type) { + switch (type) { + case phi::DataType::FLOAT32: + ConcatTensorsForAllReduce()( + context, dense_tensors_, p_dense_contents); + break; + case phi::DataType::FLOAT16: + ConcatTensorsForAllReduce()( + context, dense_tensors_, p_dense_contents); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + type)); + } +} + +// context is used to select the stream for split +template <> +void SplitTensorsWithType( + const platform::XPUDeviceContext &context, + Tensor *p_dense_contents, + std::vector *p_dense_tensors, + phi::DataType type) { + switch (type) { + case phi::DataType::FLOAT32: + SplitTensorsForAllReduce()( + context, p_dense_contents, p_dense_tensors); + break; + case phi::DataType::FLOAT16: + SplitTensorsForAllReduce()( + context, p_dense_contents, p_dense_tensors); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + type)); + } +} +#endif + void EagerGroup::ConcatTensors(const platform::Place &place) { dense_contents_ = paddle::experimental::empty(IntArray({all_length_}), dtype_, place); @@ -325,6 +376,17 @@ void EagerGroup::ConcatTensors(const platform::Place &place) { "Paddle can't concat grad tensors since it's not compiled with " "CUSTOM_DEVICE," "Please recompile or reinstall Paddle with CUSTOM_DEVICE support.")); +#endif + } else if (platform::is_xpu_place(place)) { +#if defined(PADDLE_WITH_XPU_BKCL) + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + ConcatTensorsWithType( + *default_ctx, dense_tensors_, &dense_contents_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat grad tensors since it's not compiled with BKCL," + "Please recompile or reinstall Paddle with BKCL support.")); #endif } else if (platform::is_cpu_place(place)) { auto *default_ctx = static_cast( @@ -368,6 +430,17 @@ void EagerGroup::SplitTensorsDev(const platform::DeviceContext &context) { "Paddle can't split grad tensor since it's not compiled with " "CUSTOM_DEVICE," "Please recompile or reinstall Paddle with CUSTOM_DEVICE support.")); +#endif + } else if (platform::is_xpu_place(place)) { +#if defined(PADDLE_WITH_XPU_BKCL) + auto *default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + SplitTensorsWithType( + *default_ctx, &dense_contents_, &dense_tensors_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't split grad tensor since it's not compiled with BKCL," + "Please recompile or reinstall Paddle with BKCL support.")); #endif } else if (platform::is_cpu_place(place)) { SplitTensorsWithType(static_cast(context), diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index a74b345ec835fe8ad306a522a7e0b9a0a0e40159..7e5a62a275a1a1cc4f046a0c73c330657a012226 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -78,6 +78,7 @@ class ConcatFunctor { const std::vector& input, int axis, phi::DenseTensor* output) { + using XPUType = typename XPUTypeTrait::Type; int dev_id = context.GetPlace().GetDeviceId(); platform::XPUDeviceGuard guard(dev_id); @@ -93,13 +94,24 @@ class ConcatFunctor { xdims_list[i] = tmp_dims; } - std::vector ptrs; + std::vector ptrs; for (int i = 0; i < num; ++i) { - ptrs.push_back(input[i].data()); + if (input[i].place() != context.GetPlace()) { + // data not on xpu, probably on cpu. move it now + phi::DenseTensor tmp_data = input[i]; + context.template Alloc(&tmp_data); + ptrs.push_back(reinterpret_cast(tmp_data.data())); + } else { + ptrs.push_back(reinterpret_cast(input[i].data())); + } } + context.template Alloc(output); - auto r = xpu::concat( - context.x_context(), ptrs, output->data(), xdims_list, axis); + auto r = xpu::concat(context.x_context(), + ptrs, + reinterpret_cast(output->data()), + xdims_list, + axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, @@ -119,6 +131,7 @@ class SplitFunctor { const std::vector& ref_inputs, const int axis, std::vector* outputs) { + using XPUType = typename XPUTypeTrait::Type; int dev_id = context.GetPlace().GetDeviceId(); platform::XPUDeviceGuard guard(dev_id); @@ -140,17 +153,24 @@ class SplitFunctor { } xdims_list[axis] = total_length; - std::vector ptrs(num); + std::vector ptrs(num); for (int i = 0; i < num; ++i) { - ptrs[i] = outputs->at(i)->data(); + context.template Alloc(outputs->at(i)); + ptrs[i] = reinterpret_cast(outputs->at(i)->data()); + } + phi::DenseTensor tmp_data = input; + if (input.place() != context.GetPlace()) { + // data not on xpu, probably on cpu. move it now + context.template Alloc(&tmp_data); } - auto r = xpu::split(context.x_context(), - input.data(), - ptrs, - xdims_list, - split_list, - axis); + auto r = xpu::split( + context.x_context(), + reinterpret_cast(tmp_data.data()), + ptrs, + xdims_list, + split_list, + axis); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, @@ -347,6 +367,7 @@ FOR_ALL_TYPES(DEFINE_FUNCTOR); template class SplitFunctor; DEFINE_XPU_FUNCTOR(float) +DEFINE_XPU_FUNCTOR(platform::float16) #endif #ifdef PADDLE_WITH_ASCEND_CL diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 32244f4eb136cf802f08534f8f5f47afefb89574..4ffb76ad62bdebb4ac68ea0f9983ef8bc509a707 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -353,12 +353,6 @@ BKCLComm* BKCLCommContext::AssignBKCLComm( BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) { std::unique_ptr dev_ctx( new XPUDeviceContext(XPUPlace(dev_id))); - // used in BKCL as comm_stream, for every dev_id there is - // a comm_stream at each ring. this stream is passed as input var - // when calling collective comm commands like bkcl_all_reduce - XPUStream comm_stream; - PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream)); - dev_ctx->SetXPUStream(comm_stream); BKCLCommImpl* c = new BKCLCommImpl; c->set_ring_id(ring_id); diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index eef7eba7bc7d242ab82c114d2b6406ef5ba3d469..8773ae273a69e4521e8bee6a67536dbfd07cb5f8 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -535,6 +535,11 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"split", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"split_with_num", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 3b9b0a9ca6bca8db829cbf190dc81533d068b4bf..8cc5f4f4d2ee72fbec4b10412447bcca988884fe 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -137,6 +137,7 @@ set(PYBIND_SRCS generator_py.cc communication.cc cuda_streams_py.cc + xpu_streams_py.cc jit.cc auto_parallel_py.cc op_function1.cc @@ -161,6 +162,9 @@ if(WITH_PYTHON) set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter) endif() endif() + if(WITH_XPU_BKCL) + set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_bkcl) + endif() if(WITH_GLOO) set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo) endif() diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 153396a104239302d4e714e9f0cc764d4f90efa5..4cc1a0607e045d13b47772187a134f79c6b1de83 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -59,6 +59,10 @@ limitations under the License. */ #include "paddle/fluid/distributed/store/tcp_store.h" #endif +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/distributed/collective/ProcessGroupBKCL.h" +#endif + #include "paddle/phi/kernels/sync_batch_norm_kernel.h" namespace py = pybind11; @@ -1328,6 +1332,24 @@ void BindDistributed(py::module *m) { #endif +#if defined(PADDLE_WITH_XPU_BKCL) + auto processGroupBKCL = + py::class_>( + *m, "ProcessGroupBKCL", ProcessGroup) + .def(py::init &, + int, + int, + const platform::XPUPlace &, + int>(), + py::arg("store"), + py::arg("rank"), + py::arg("world_size"), + py::arg("place"), + py::arg("group_id") = 0, + py::call_guard()); +#endif + py::class_>(*m, "task") .def("is_completed", &distributed::ProcessGroup::Task::IsCompleted) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index eaa71a23b679b608474c2c38cb70220a99ebc939..b4d175efd2b5699f71342a1d1e333b5340b1cff9 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -92,6 +92,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/io.h" #include "paddle/fluid/pybind/jit.h" +#include "paddle/fluid/pybind/xpu_streams_py.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" @@ -609,6 +610,7 @@ PYBIND11_MODULE(libpaddle, m) { BindEager(&m); BindEagerStringTensor(&m); BindCudaStream(&m); + BindXpuStream(&m); BindJit(&m); // Not used, just make sure cpu_info.cc is linked. diff --git a/paddle/fluid/pybind/xpu_streams_py.cc b/paddle/fluid/pybind/xpu_streams_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..044b954ce6b655499f81c17127ffaa9875e61254 --- /dev/null +++ b/paddle/fluid/pybind/xpu_streams_py.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/xpu_streams_py.h" + +#include +#include + +#include "paddle/fluid/platform/device_event_base.h" +#include "paddle/fluid/platform/event.h" +#if defined(PADDLE_WITH_XPU) +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#endif + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +void BindXpuStream(py::module *m_ptr) { + auto &m = *m_ptr; + + // Bind Methods + m.def("_xpu_device_synchronize", [](int device_id) { +#if defined(PADDLE_WITH_XPU) + if (device_id == -1) { + device_id = paddle::platform::GetXPUCurrentDeviceId(); + } + int curr_device_id = paddle::platform::GetXPUCurrentDeviceId(); + paddle::platform::SetXPUDeviceId(device_id); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait()); + paddle::platform::SetXPUDeviceId(curr_device_id); +#else + PADDLE_THROW(platform::errors::Unavailable( + "Paddle is not compiled with XPU. Cannot visit device synchronize.")); +#endif + }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/xpu_streams_py.h b/paddle/fluid/pybind/xpu_streams_py.h new file mode 100644 index 0000000000000000000000000000000000000000..a88857192127fe4fbb39685b09c1060f6f23be39 --- /dev/null +++ b/paddle/fluid/pybind/xpu_streams_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindXpuStream(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc index 2735e2a4208bafe7b05c439003f26e9d9c8f7f91..7257f3f20b06b10a7d2aebddb9bbbd093f8795c4 100644 --- a/paddle/phi/backends/xpu/xpu_context.cc +++ b/paddle/phi/backends/xpu/xpu_context.cc @@ -17,6 +17,7 @@ #include #include "paddle/phi/api/ext/exception.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/common/place.h" #include "xpu/runtime.h" #include "xpu/runtime_ex.h" @@ -59,6 +60,12 @@ struct XPUContext::Impl { ~Impl() { if (owned_ && context_ != nullptr) { + backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); + // manually destroy XPUStream here until xpu::api integrates this work + // into Context dtor + xpu_wait(context_->xpu_stream); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_destroy(context_->xpu_stream)); + context_->xpu_stream = nullptr; xpu::destroy_context(context_); context_ = nullptr; } @@ -66,8 +73,6 @@ struct XPUContext::Impl { const Place& GetPlace() const { return place_; } - void SetStream(XPUStream stream) { context_->xpu_stream = stream; } - XPUStream stream() const { auto s = context_->xpu_stream; PD_CHECK(s != nullptr, "the xpu stream is nullptr."); @@ -85,7 +90,7 @@ struct XPUContext::Impl { } void Wait() const { - backends::xpu::SetXPUDeviceId(place_.GetDeviceId()); + backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); PD_CHECK(context_ != nullptr, "the xpu context is nullptr."); xpu_wait(context_->xpu_stream); } @@ -98,6 +103,7 @@ struct XPUContext::Impl { context_ = xpu::create_context(); xpu_version_ = backends::xpu::get_xpu_version(place_.device); SetL3Cache(); + PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream)); } void SetXContext(xpu::Context* context) { context_ = context; } @@ -123,8 +129,6 @@ XPUContext::~XPUContext() = default; const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); } -void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); } - XPUStream XPUContext::stream() const { return impl_->stream(); } backends::xpu::XPUVersion XPUContext::xpu_version() const { diff --git a/paddle/phi/backends/xpu/xpu_context.h b/paddle/phi/backends/xpu/xpu_context.h index 61967a9a58d39c9f509b16ea437dbd4c8ed868f8..1c12c7e5fe69a490b19e21ed8a1646422593a433 100644 --- a/paddle/phi/backends/xpu/xpu_context.h +++ b/paddle/phi/backends/xpu/xpu_context.h @@ -64,8 +64,6 @@ class XPUContext : public DeviceContext, void SetL3Cache(int l3_size = 14155776); - void SetXPUStream(XPUStream stream); - XPUStream stream() const; static const char* name() { return "XPUContext"; } diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.h b/paddle/phi/kernels/funcs/concat_and_split_functor.h index 4cb15fe539b66b8a6fddccf18d92b95976db2a65..55c48d566a1eaa1865ca0f228e6728905232ae01 100644 --- a/paddle/phi/kernels/funcs/concat_and_split_functor.h +++ b/paddle/phi/kernels/funcs/concat_and_split_functor.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/utils/data_type.h" diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 5f7524cde591bb46da220926a37f251970680855..756fd8782e96042539d6a37d4fbcdc6b06468bb1 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -168,7 +168,13 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, phi::DenseTensor* tensor, float value) { - PADDLE_THROW(phi::errors::Unimplemented("XPUPlace is not supported")); +#ifdef PADDLE_WITH_XPU + phi::VisitDataType( + tensor->dtype(), + TensorSetConstantXPU(tensor, value, tensor->place())); +#else + PADDLE_THROW(phi::errors::PreconditionNotMet("Not compiled with XPU!")); +#endif } template <> @@ -257,6 +263,8 @@ void set_constant(const paddle::platform::DeviceContext& context, #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // tensor->place().apply_visitor(func); paddle::platform::VisitPlace(tensor->place(), func); +#elif defined(PADDLE_WITH_XPU) + func(phi::XPUPlace()); #else func(phi::CPUPlace()); #endif diff --git a/paddle/phi/kernels/xpu/concat_and_split.cc b/paddle/phi/kernels/xpu/concat_and_split.cc new file mode 100644 index 0000000000000000000000000000000000000000..225f9555b02e6f312d2c4eed89292b16c5823483 --- /dev/null +++ b/paddle/phi/kernels/xpu/concat_and_split.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" + +namespace phi { +namespace funcs { + +using XPUDeviceGuard = phi::backends::xpu::XPUDeviceGuard; + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const XPUContext& context, + const std::vector& input, + int axis, + phi::DenseTensor* output) { + using XPUType = typename XPUTypeTrait::Type; + int dev_id = context.GetPlace().GetDeviceId(); + XPUDeviceGuard guard(dev_id); + + int num = input.size(); + auto input_dims = input[0].dims(); + + std::vector> xdims_list(num); + for (int i = 0; i < num; ++i) { + std::vector tmp_dims(input_dims.size()); + for (int j = 0; j < input_dims.size(); ++j) { + tmp_dims[j] = input[i].dims()[j]; + } + xdims_list[i] = tmp_dims; + } + + std::vector ptrs; + for (int i = 0; i < num; ++i) { + if (input[i].place() != context.GetPlace()) { + // data not on xpu, probably on cpu. move it now + phi::DenseTensor tmp_data = input[i]; + context.template Alloc(&tmp_data); + ptrs.push_back(reinterpret_cast(tmp_data.data())); + } else { + ptrs.push_back(reinterpret_cast(input[i].data())); + } + } + context.template Alloc(output); + + auto r = xpu::concat(context.x_context(), + ptrs, + reinterpret_cast(output->data()), + xdims_list, + axis); + PADDLE_ENFORCE_EQ( + r, + XPU_SUCCESS, + paddle::platform::errors::External( + "XPU API return wrong value[%d %s], please check whether " + "Baidu Kunlun Card is properly installed.", + r, + XPUAPIErrorMsg[r])); + } +}; + +template +class SplitFunctor { + public: + void operator()(const XPUContext& context, + const phi::DenseTensor& input, + const std::vector& ref_inputs, + const int axis, + std::vector* outputs) { + using XPUType = typename XPUTypeTrait::Type; + int dev_id = context.GetPlace().GetDeviceId(); + XPUDeviceGuard guard(dev_id); + + auto& ins = ref_inputs; + + int num = ins.size(); + auto input_dims = ins[0]->dims(); + std::vector split_list(num); + std::vector xdims_list(input_dims.size()); + int total_length = 0; + for (int i = 0; i < num; ++i) { + split_list[i] = ins[i]->dims()[axis]; + total_length += ins[i]->dims()[axis]; + } + + for (int i = 0; i < input_dims.size(); ++i) { + if (i == axis) continue; + xdims_list[i] = input_dims[i]; + } + xdims_list[axis] = total_length; + + std::vector ptrs(num); + for (int i = 0; i < num; ++i) { + context.template Alloc(outputs->at(i)); + ptrs[i] = reinterpret_cast(outputs->at(i)->data()); + } + phi::DenseTensor tmp_data = input; + if (input.place() != context.GetPlace()) { + // data not on xpu, probably on cpu. move it now + context.template Alloc(&tmp_data); + } + + auto r = xpu::split( + context.x_context(), + reinterpret_cast(tmp_data.data()), + ptrs, + xdims_list, + split_list, + axis); + PADDLE_ENFORCE_EQ( + r, + XPU_SUCCESS, + paddle::platform::errors::External( + "XPU API return wrong value[%d %s], please check whether " + "Baidu Kunlun Card is properly installed.", + r, + XPUAPIErrorMsg[r])); + } +}; + +#define DEFINE_XPU_FUNCTOR(type) \ + template class ConcatFunctor; \ + template class SplitFunctor; + +DEFINE_XPU_FUNCTOR(float) +DEFINE_XPU_FUNCTOR(phi::dtype::float16) + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc index 352d6f857c0d46eb0f0ab8cf285de2df8bac9756..674182620bccc5952d4ec32989d959b5c24f4173 100644 --- a/paddle/phi/kernels/xpu/split_kernel.cc +++ b/paddle/phi/kernels/xpu/split_kernel.cc @@ -25,22 +25,23 @@ void SplitKernel(const Context& dev_ctx, const IntArray& sections, const Scalar& axis_scalar, std::vector outs) { + using XPUType = typename XPUTypeTrait::Type; int axis = axis_scalar.to(); auto in_dims = x.dims(); auto input_shape = vectorize(in_dims); - std::vector out_ptrs; + std::vector out_ptrs; std::vector split_lists; for (size_t j = 0; j < outs.size(); ++j) { dev_ctx.template Alloc(outs[j]); - out_ptrs.push_back(outs[j]->data()); + out_ptrs.push_back(reinterpret_cast(outs[j]->data())); split_lists.push_back(outs[j]->dims()[axis]); } - int r = xpu::split(dev_ctx.x_context(), - x.data(), - out_ptrs, - input_shape, - split_lists, - axis); + int r = xpu::split(dev_ctx.x_context(), + reinterpret_cast(x.data()), + out_ptrs, + input_shape, + split_lists, + axis); PADDLE_ENFORCE_XDNN_SUCCESS(r, "split"); } @@ -62,6 +63,13 @@ void SplitWithNumKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int) {} PD_REGISTER_KERNEL( - split_with_num, XPU, ALL_LAYOUT, phi::SplitWithNumKernel, float, int) {} + split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int, phi::dtype::float16) { +} +PD_REGISTER_KERNEL(split_with_num, + XPU, + ALL_LAYOUT, + phi::SplitWithNumKernel, + float, + int, + phi::dtype::float16) {} diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index f8d5dbd8b9dbbdf26f8c696ccb10ce70bd5e1183..d961fbfdda7f08846401731ac2041ef4e7bf3235 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -22,6 +22,7 @@ from paddle.fluid.framework import is_compiled_with_cinn # noqa: F401 from paddle.fluid.framework import is_compiled_with_cuda # noqa: F401 from paddle.fluid.framework import is_compiled_with_rocm # noqa: F401 from . import cuda +from . import xpu __all__ = [ # noqa 'get_cudnn_version', diff --git a/python/paddle/device/xpu/__init__.py b/python/paddle/device/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a928a0f7c04052598d7074936aa28b4f09459bcb --- /dev/null +++ b/python/paddle/device/xpu/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid import core + +__all__ = [ + 'synchronize', +] + + +def synchronize(device=None): + ''' + Wait for the compute on the given XPU device to finish. + + Parameters: + device(paddle.XPUPlace()|int, optional): The device or the ID of the device. + If device is None, the device is the current device. Default: None. + + Examples: + .. code-block:: python + + # required: xpu + import paddle + + paddle.device.xpu.synchronize() + paddle.device.xpu.synchronize(0) + paddle.device.xpu.synchronize(paddle.XPUPlace(0)) + + ''' + + device_id = -1 + + if device is not None: + if isinstance(device, int): + device_id = device + elif isinstance(device, core.XPUPlace): + device_id = device.get_device_id() + else: + raise ValueError("device type must be int or paddle.XPUPlace") + + return core._xpu_device_synchronize(device_id) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index cd5b5ac91450b18c4519464eca8cd2c0de4cd746..91de7de4d45c84c1c1dad3bf1d3abc74acfe4246 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -67,7 +67,7 @@ _group_map_backend = {} # Name of the default group for init_parallel_env _default_group_name = "_default_pg" -_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl'] +_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl'] _default_store = None # the default tcp store _default_backend = None _default_timeout = datetime.timedelta(seconds=1800) @@ -170,6 +170,9 @@ def _new_process_group_impl( elif backend == "xccl": place = core.CustomPlace(genv.device_type, genv.device_id) pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id) + elif backend == "bkcl": + place = core.XPUPlace(genv.device_id) + pg = core.ProcessGroupBKCL(store, rank, world_size, place, group_id) elif backend == "heter": place = None if core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 590d8e951f128ee644b19f106e572b896702d05a..839bbc8026dc1b8bc645f8e6f57e598e03a2ce6b 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -300,6 +300,7 @@ try: from .libpaddle import _promote_types_if_complex_exists from .libpaddle import _set_cached_executor_build_strategy from .libpaddle import _device_synchronize + from .libpaddle import _xpu_device_synchronize from .libpaddle import _get_current_stream from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _set_current_stream diff --git a/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_dataparallel_with_pylayer.py b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_dataparallel_with_pylayer.py new file mode 100644 index 0000000000000000000000000000000000000000..18ac7c88e194793a462ff24e81a22f1c6ffab248 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_dataparallel_with_pylayer.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import numpy as np +import paddle.distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) + +batch = 5 +in_dim = 20 +out_dim = 10 + + +class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + (y,) = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + + +class SimpleNet(paddle.nn.Layer): + def __init__(self, train_id, model_id): + super(SimpleNet, self).__init__() + self.w = self.create_parameter(shape=[in_dim, batch], dtype="float32") + self.linear = paddle.nn.Linear(in_dim, out_dim) + self.tanh = paddle.tanh + + self.trainer_id = train_id + self.model_id = model_id + + def forward(self, inputs): + if self.model_id == 0: + inputs = cus_tanh.apply(inputs) + else: + inputs = self.tanh(inputs) + + inputs = paddle.matmul(self.w, inputs) + return self.linear(inputs) + + +class TestDistTraning(unittest.TestCase): + def test_multiple_xpus(self): + self.trainer_id = dist.get_rank() + dist.init_parallel_env() + + model_a = SimpleNet(self.trainer_id, 0) + model_b = SimpleNet(self.trainer_id, 1) + + state_dict = model_a.state_dict() + model_b.set_state_dict(state_dict) + + model_a = paddle.DataParallel(model_a) + model_b = paddle.DataParallel(model_b) + + for step in range(10): + x_data = np.random.randn(batch, in_dim).astype(np.float32) + x = paddle.to_tensor(x_data) + x.stop_gradient = False + + with model_a.no_sync(): + y_pred_a = model_a(x) + loss_a = y_pred_a.mean() + loss_a.backward() + fused_allreduce_gradients(list(model_a.parameters()), None) + + y_pred_b = model_b(x) + loss_b = y_pred_b.mean() + loss_b.backward() + + self.check_gradient(model_a.parameters()) + self.check_gradient(model_b.parameters()) + + self.check_acc(model_a._layers.w.grad, model_b._layers.w.grad) + + model_a.clear_gradients() + model_b.clear_gradients() + + def check_acc(self, grad, acc_grad): + grad = grad.numpy() if grad is not None else None + acc_grad = acc_grad.numpy() if acc_grad is not None else None + return np.testing.assert_allclose(grad, acc_grad, rtol=1e-6) + + def broadcast_param(self, param, root): + paddle.distributed.broadcast(param, root) + return param + + def check_gradient(self, params): + other_param = [] + for param in params: + if param.trainable and (param._grad_ivar() is not None): + grad = param._grad_ivar() + other_grad = self.broadcast_param(grad.clone(), root=1) + if self.trainer_id == 0: + np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check.py new file mode 100644 index 0000000000000000000000000000000000000000..b132e0e7e718f40475538aae34ff65f79b99308a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check.py @@ -0,0 +1,140 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import numpy as np +import paddle.distributed as dist +import paddle.fluid as fluid +from paddle.nn import Linear + +paddle.seed(1024) +np.random.seed(2021) + +batch = 5 +in_dim = 10 +out_dim = 20 + + +class SimpleNet(fluid.Layer): + def __init__(self, train_id): + super(SimpleNet, self).__init__() + self.w1 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32" + ) + self.w2 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32" + ) + self.share_net = Linear(out_dim, 10) + + self.unused_param = self.create_parameter( + shape=[out_dim, in_dim], dtype="float64" + ) + + # just for test sync_params_buffers + self.register_buffer("queue", paddle.randn([10, 5])) + self.queue = paddle.nn.functional.normalize(self.queue, axis=0) + self.register_buffer("queue_ptr", paddle.zeros([1], 'int64')) + + self.trainer_id = train_id + + def forward(self, x): + is_use = ( + paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0] + and self.trainer_id == 1 + ) + + if is_use: + tmp = paddle.matmul(x, self.w1) + else: + tmp = paddle.matmul(x, self.w2) + + return self.share_net(tmp) + + +class TestDistTraning(unittest.TestCase): + def test_multiple_xpus(self): + dist.init_parallel_env() + self.trainer_id = dist.get_rank() + + model_a = SimpleNet(self.trainer_id) + model_b = SimpleNet(self.trainer_id) + + state_dict = model_a.state_dict() + model_b.set_state_dict(state_dict) + + model_a = paddle.DataParallel(model_a, find_unused_parameters=True) + model_b = paddle.DataParallel(model_b, find_unused_parameters=True) + + ones_input = paddle.ones(shape=(batch, in_dim)) + ones_input.stop_gradient = True + + w1_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + + for step_id in range(5): + random_input = paddle.rand(shape=(batch, in_dim)) + random_input.stop_gradient = True + + if step_id % 2 == 0: + out_a = model_a(random_input) + out_b = model_b(random_input) + else: + out_a = model_a(ones_input) + out_b = model_b(ones_input) + + out_a.sum().backward() + out_b.sum().backward() + + self.check_gradient(model_a.parameters()) + self.check_gradient(model_b.parameters()) + + # test acc gradient + w1_grad_sum = self.check_acc( + model_a._layers.w1.grad, w1_grad_sum, model_b._layers.w1.grad + ) + w2_grad_sum = self.check_acc( + model_a._layers.w2.grad, w2_grad_sum, model_b._layers.w2.grad + ) + + model_a.clear_gradients() + + def check_acc(self, grad, grad_sum, acc_grad): + if grad is not None: + grad_sum = grad_sum + grad.numpy() + acc_grad = acc_grad.numpy() if acc_grad is not None else None + np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) + return grad_sum + + def print_trainer_0(self, *args): + if self.trainer_id == 0: + print(*args) + + def broadcast_param(self, param, root): + paddle.distributed.broadcast(param, root) + return param + + def check_gradient(self, params): + other_param = [] + for param in params: + if param.trainable and (param._grad_ivar() is not None): + grad = param._grad_ivar() + other_grad = self.broadcast_param(grad.clone(), root=1) + if self.trainer_id == 0: + np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check_in_eager_mode.py b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check_in_eager_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4edaf3f3b6d9d5d00bcfbf7d8633e32a13eb81 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/parallel_dygraph_gradient_check_in_eager_mode.py @@ -0,0 +1,150 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import numpy as np +import paddle.distributed as dist +import paddle.fluid as fluid +from paddle.nn import Linear +from paddle.fluid.framework import _test_eager_guard + +paddle.seed(1024) +np.random.seed(2021) + +batch = 5 +in_dim = 10 +out_dim = 20 + + +class SimpleNet(fluid.Layer): + def __init__(self, train_id): + super(SimpleNet, self).__init__() + self.w1 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32" + ) + self.w2 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32" + ) + self.share_net = Linear(out_dim, 10) + + self.unused_param = self.create_parameter( + shape=[out_dim, in_dim], dtype="float32" + ) + + # for test sync_params_buffers + self.register_buffer("queue", paddle.randn([10, 5])) + self.queue = paddle.nn.functional.normalize(self.queue, axis=0) + self.register_buffer("queue_ptr", paddle.zeros([1], 'int64')) + + self.trainer_id = train_id + + def forward(self, x): + is_use = ( + paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0] + and self.trainer_id == 1 + ) + + if is_use: + tmp = paddle.matmul(x, self.w1) + else: + tmp = paddle.matmul(x, self.w2) + + return self.share_net(tmp) + + +class TestDistTraning(unittest.TestCase): + def test_multiple_xpus(self): + self.trainer_id = dist.get_rank() + with _test_eager_guard(): + self.pg = dist.init_parallel_env() + + model_a = SimpleNet(self.trainer_id) + model_b = SimpleNet(self.trainer_id) + + state_dict = model_a.state_dict() + model_b.set_state_dict(state_dict) + + model_a = paddle.DataParallel( + model_a, find_unused_parameters=True, group=self.pg + ) + model_b = paddle.DataParallel( + model_b, find_unused_parameters=True, group=self.pg + ) + + ones_input = paddle.ones(shape=(batch, in_dim)) + ones_input.stop_gradient = True + + w1_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + + for step_id in range(5): + random_input = paddle.rand(shape=(batch, in_dim)) + random_input.stop_gradient = True + + if step_id % 2 == 0: + out_a = model_a(random_input) + out_b = model_b(random_input) + else: + out_a = model_a(ones_input) + out_b = model_b(ones_input) + + out_a.sum().backward() + out_b.sum().backward() + + self.check_gradient(model_a.parameters()) + self.check_gradient(model_b.parameters()) + + # test acc gradient + w1_grad_sum = self.check_acc( + model_a._layers.w1.grad, + w1_grad_sum, + model_b._layers.w1.grad, + ) + w2_grad_sum = self.check_acc( + model_a._layers.w2.grad, + w2_grad_sum, + model_b._layers.w2.grad, + ) + + model_a.clear_gradients() + + def check_acc(self, grad, grad_sum, acc_grad): + if grad is not None: + grad_sum = grad_sum + grad.numpy() + acc_grad = acc_grad.numpy() if acc_grad is not None else None + np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) + return grad_sum + + def print_trainer_0(self, *args): + if self.trainer_id == 0: + print(*args) + + def broadcast_param(self, param, root): + self.pg.process_group.broadcast(param, root) + return param + + def check_gradient(self, params): + other_param = [] + for param in params: + if param.trainable and (param.grad is not None): + grad = param.grad + other_grad = self.broadcast_param(grad, root=1) + if self.trainer_id == 0: + np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py b/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py new file mode 100644 index 0000000000000000000000000000000000000000..bb2cf6e1db7e00733b1b899d406cff7172a759e6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/process_group_bkcl.py @@ -0,0 +1,186 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import random +import numpy as np +import sys + +import paddle +from paddle.fluid.framework import _test_eager_guard +from paddle.fluid.dygraph.parallel import ParallelEnv +import paddle.distributed as dist + + +def init_process_group(strategy=None): + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + pg_group = dist.init_parallel_env() + + return pg_group.process_group + + +class TestProcessGroupFp32(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + random.seed(2022) + np.random.seed(2022) + self.config() + + def config(self): + self.dtype = "float32" + self.shape = (2, 10, 5) + + def test_create_process_group_bkcl(self): + with _test_eager_guard(): + paddle.set_device( + 'xpu:%d' % paddle.distributed.ParallelEnv().dev_id + ) + + pg = init_process_group() + sys.stdout.write( + "rank {}: size {} name {}\n".format( + pg.rank(), pg.size(), pg.name() + ) + ) + sys.stdout.write( + "rank {}: test new group api ok\n".format(pg.rank()) + ) + + # test allreduce sum + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + sum_result = tensor_x + tensor_y + if pg.rank() == 0: + task = dist.all_reduce(tensor_x) + assert np.array_equal(tensor_x, sum_result) + else: + task = dist.all_reduce(tensor_y) + assert np.array_equal(tensor_y, sum_result) + + sys.stdout.write( + "rank {}: test allreduce sum api ok\n".format(pg.rank()) + ) + + # TODO + # test allreduce max/min/prod + + # test broadcast + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + # rank 1 + y = np.random.random(self.shape).astype(self.dtype) + tensor_y = paddle.to_tensor(y) + + broadcast_result = paddle.assign(tensor_x) + if pg.rank() == 0: + # XPU don't support event query by now, so just use sync op here + task = dist.broadcast(tensor_x, 0) + paddle.device.xpu.synchronize() + assert np.array_equal(broadcast_result, tensor_x) + else: + task = dist.broadcast(tensor_y, 0) + paddle.device.xpu.synchronize() + assert np.array_equal(broadcast_result, tensor_y) + + sys.stdout.write( + "rank {}: test broadcast api ok\n".format(pg.rank()) + ) + + # test barrier + # rank 0 + if pg.rank() == 0: + dist.barrier() + # rank 1 + else: + task = pg.barrier() + task.wait() + + sys.stdout.write("rank {}: test barrier api ok\n".format(pg.rank())) + + # test allgather + # rank 0 + x = np.random.random(self.shape).astype(self.dtype) + y = np.random.random(self.shape).astype(self.dtype) + tensor_x = paddle.to_tensor(x) + tensor_y = paddle.to_tensor(y) + out_shape = list(self.shape) + out_shape[0] *= 2 + out = np.random.random(out_shape).astype(self.dtype) + tensor_out = paddle.to_tensor(out) + if pg.rank() == 0: + task = pg.all_gather(tensor_x, tensor_out) + task.wait() + paddle.device.xpu.synchronize() + # rank 1 + else: + tensor_out_list = [ + paddle.empty_like(tensor_x), + paddle.empty_like(tensor_x), + ] + task = dist.all_gather(tensor_out_list, tensor_y) + paddle.device.xpu.synchronize() + tensor_out = paddle.concat(tensor_out_list) + out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2]) + out_2 = paddle.slice( + tensor_out, [0], [out_shape[0] // 2], [out_shape[0]] + ) + assert np.array_equal(tensor_x, out_1) + assert np.array_equal(tensor_y, out_2) + sys.stdout.write( + "rank {}: test allgather api ok\n".format(pg.rank()) + ) + + if pg.rank() == 0: + task = pg.all_gather(tensor_x, tensor_out) + task.wait() + paddle.device.xpu.synchronize() + # rank 1 + else: + tensor_out_list = [] + task = dist.all_gather(tensor_out_list, tensor_y) + paddle.device.xpu.synchronize() + tensor_out = paddle.concat(tensor_out_list) + out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2]) + out_2 = paddle.slice( + tensor_out, [0], [out_shape[0] // 2], [out_shape[0]] + ) + assert np.array_equal(tensor_x, out_1) + assert np.array_equal(tensor_y, out_2) + sys.stdout.write( + "rank {}: test allgather api2 ok\n".format(pg.rank()) + ) + + +class TestProcessGroupFp16(TestProcessGroupFp32): + def setUp(self): + paddle.seed(2022) + random.seed(2022) + np.random.seed(2022) + self.config() + + def config(self): + self.dtype = "float16" + self.shape = (4, 20, 20) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_collective_process_group.py b/python/paddle/fluid/tests/unittests/xpu/test_collective_process_group.py new file mode 100644 index 0000000000000000000000000000000000000000..561522c9cae6ad23c005cdab1da05b40660dbaff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_collective_process_group.py @@ -0,0 +1,28 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from test_parallel_dygraph_dataparallel import TestMultipleXpus + + +class TestProcessGroup(TestMultipleXpus): + def test_process_group_bkcl(self): + self.run_mnist_2xpu('process_group_bkcl.py') + + +if __name__ == "__main__": + os.environ["BKCL_PCIE_RING"] = "1" + os.environ["BKCL_CCIX_RING"] = "0" + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/xpu/test_parallel_dygraph_dataparallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3c994ba72bc1d7aa59c6917fdb5c69fc362b7dc9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_parallel_dygraph_dataparallel.py @@ -0,0 +1,158 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import time +import paddle +import paddle.fluid as fluid +import copy +import os +import subprocess + +from paddle.distributed.utils.launch_utils import ( + find_free_ports, + watch_local_trainers, + get_cluster, + TrainerProc, +) + + +def get_cluster_from_args(selected_xpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_xpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_xpus) + + +def get_xpus(selected_xpus): + selected_xpus = [x.strip() for x in selected_xpus.split(',')] + return selected_xpus + + +def start_local_trainers( + cluster, + pod, + training_script, + training_script_args, + eager_mode=True, + log_dir=None, +): + current_env = copy.copy(os.environ.copy()) + # paddle broadcast ncclUniqueId use socket, and + # proxy maybe make trainers unreachable, so delete them. + # if we set them to "", grpc will log error message "bad uri" + # so just delete them. + current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for t in pod.trainers: + proc_env = { + "PADDLE_DISTRI_BACKEND": "bkcl", + "FLAGS_selected_xpus": "%s" % ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": "%d" % t.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + } + + if not eager_mode: + proc_env["FLAGS_enable_eager_mode"] = "%d" % 0 + + current_env.update(proc_env) + + print("trainer proc env:{}".format(current_env)) + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd = "python -m coverage run --branch -p " + training_script + else: + cmd = "python -u " + training_script + + print("start trainer proc:{} env:{}".format(cmd, proc_env)) + + fn = None + + proc = subprocess.Popen(cmd.split(" "), env=current_env) + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + +class TestMultipleXpus(unittest.TestCase): + def run_mnist_2xpu(self, target_file_name, eager_mode=True): + if ( + not fluid.core.is_compiled_with_xpu() + or fluid.core.get_xpu_device_count() == 0 + ): + return + + selected_xpus = get_xpus('0,1') + paddle.set_device("xpu") + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_xpus) + + procs = start_local_trainers( + cluster, + pod, + eager_mode=eager_mode, + training_script=target_file_name, + training_script_args=[], + ) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_endpoints()) + + if not alive: + print("Local procs complete, POD info:{}".format(pod)) + break + time.sleep(3) + + +class TestDataParallelWithPyLayer(TestMultipleXpus): + def test_parallel_dygraph_dataparallel_with_pylayer(self): + self.run_mnist_2xpu('parallel_dygraph_dataparallel_with_pylayer.py') + + +class TestGradientCheckInEagerMode(TestMultipleXpus): + def test_multiple_xpus_dynamic(self): + self.run_mnist_2xpu('parallel_dygraph_gradient_check_in_eager_mode.py') + + +if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" + os.environ["BKCL_PCIE_RING"] = "1" + os.environ["BKCL_CCIX_RING"] = "0" + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_xpu_stream_event.py b/python/paddle/fluid/tests/unittests/xpu/test_xpu_stream_event.py new file mode 100644 index 0000000000000000000000000000000000000000..00808c3c289f33ea8210132d0853c8deff8e7f11 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_xpu_stream_event.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.device import xpu +import paddle + +import unittest + + +class TestSynchronize(unittest.TestCase): + def test_synchronize(self): + if paddle.is_compiled_with_xpu(): + self.assertIsNone(xpu.synchronize()) + self.assertIsNone(xpu.synchronize(0)) + self.assertIsNone(xpu.synchronize(paddle.XPUPlace(0))) + + self.assertRaises(ValueError, xpu.synchronize, "xpu:0") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 5ab06a05e1284cb2e56f2236121d5aac0b87334a..08fcca5b2d8f496e55724a7637dc9852c365f6cf 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -413,6 +413,7 @@ packages=['paddle', 'paddle.autograd', 'paddle.device', 'paddle.device.cuda', + 'paddle.device.xpu', 'paddle.version', 'paddle.profiler', 'paddle.geometric',