From 0e1191f43168ace2492b65d551052f08ccb0a4fc Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 21 Mar 2022 11:04:47 +0800 Subject: [PATCH] [Phi] Add phi device context pool (#40635) * add phi device context pool * change year * fix compile error * fix operator = error * refine init impl * polish details * refine init impl --- paddle/fluid/platform/CMakeLists.txt | 2 +- paddle/fluid/platform/device_context.h | 5 ++ paddle/phi/api/include/context_pool.h | 81 ++++++++++++++++++++++++++ paddle/phi/api/lib/CMakeLists.txt | 3 +- paddle/phi/api/lib/context_pool.cc | 65 +++++++++++++++++++++ paddle/phi/api/lib/kernel_dispatch.cc | 5 +- paddle/phi/common/place.cc | 16 +++++ paddle/phi/common/place.h | 43 +++++++------- 8 files changed, 194 insertions(+), 26 deletions(-) create mode 100644 paddle/phi/api/include/context_pool.h create mode 100644 paddle/phi/api/lib/context_pool.cc diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 04c8a329e5e..de09860fd26 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -117,7 +117,7 @@ endif() cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) # seperate init from device_context to avoid cycle dependencies -cc_library(init SRCS init.cc DEPS device_context custom_kernel) +cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool) # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e104170ca24..2c5f24d28c6 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -916,6 +916,11 @@ class DeviceContextPool { size_t size() const { return device_contexts_.size(); } + const std::map>>& + device_contexts() const { + return device_contexts_; + } + private: static DeviceContextPool* pool; std::map>> diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h new file mode 100644 index 00000000000..754833a2dda --- /dev/null +++ b/paddle/phi/api/include/context_pool.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/macros.h" +#include "paddle/utils/flat_hash_map.h" + +namespace phi { +class DeviceContext; +class CPUContext; +class GPUContext; +} // namespace phi + +namespace paddle { +namespace experimental { + +template +struct DefaultDeviceContextType; + +template <> +struct DefaultDeviceContextType { + using TYPE = phi::CPUContext; +}; + +template <> +struct DefaultDeviceContextType { + using TYPE = phi::GPUContext; +}; + +/** + * The DeviceContextPool here is just a mirror of the DeviceContextPool in + * fluid, and does not manage the life cycle of the DeviceContext. + * It is mainly used for external custom operator calls and high-performance + * C++ APIs. + * + * Since DeviceContextPool in fluid is a global singleton, it always exists + * in program running, so DeviceContextPool here can always access the correct + * DeviceContext pointer. + * + * In order not to depend on the fluid's DeviceContextPool, + * the DeviceContextPool here needs to be initialized in the fluid, and cannot + * be initialized by itself. + */ +class DeviceContextPool { + public: + static DeviceContextPool& Instance(); + + const phi::DeviceContext* Get(const Place& place) const; + + phi::DeviceContext* GetMutable(const Place& place); + + template + const typename DefaultDeviceContextType::TYPE* Get( + const Place& place) const { + return reinterpret_cast::TYPE*>( + Get(place)); + } + + private: + DeviceContextPool(); + paddle::flat_hash_map + context_map_; + + DISABLE_COPY_AND_ASSIGN(DeviceContextPool); +}; + +} // namespace experimental +} // namespace paddle diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 4cbca072362..50c267f6535 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -135,8 +135,9 @@ add_custom_command( cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw) cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi) +cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place) -cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory) +cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool) cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor) cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform) cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc new file mode 100644 index 00000000000..d1408a88d6f --- /dev/null +++ b/paddle/phi/api/lib/context_pool.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/api/include/context_pool.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { +namespace experimental { + +DeviceContextPool& DeviceContextPool::Instance() { + static DeviceContextPool g_device_context_pool; + return g_device_context_pool; +} + +const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { + auto it = context_map_.find(place); + PADDLE_ENFORCE_NE( + it, + context_map_.end(), + phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); + return it->second; +} + +phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { + return const_cast(Get(place)); +} + +DeviceContextPool::DeviceContextPool() { + // We need to make sure that the correct value exists + // whenever we get the DeviceContext from DeviceContextPool + const auto& device_contexts = + paddle::platform::DeviceContextPool::Instance().device_contexts(); + for (const auto& pair : device_contexts) { + // only get CPU and GPU DeviceContext now, add other DeviceContext type + // later if needed + if (platform::is_cpu_place(pair.first) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + || + platform::is_gpu_place(pair.first)) { +#else + ) { +#endif + const phi::DeviceContext* dev_ctx = pair.second.get().get(); + VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " + << dev_ctx << "}"; + context_map_[pair.first] = dev_ctx; + } + } +} + +} // namespace experimental +} // namespace paddle diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 0e3ca1af496..5e334b9b727 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/core/compat/convert_utils.h" namespace paddle { @@ -52,8 +53,8 @@ std::size_t CountLeadingZeros(uint64_t val) { } // namespace detail phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend) { - auto& pool = paddle::platform::DeviceContextPool::Instance(); - return pool.Get(phi::TransToPhiPlace(backend)); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + return pool.GetMutable(phi::TransToPhiPlace(backend)); } DataType ParseDataType(DataType dtype) { return dtype; } diff --git a/paddle/phi/common/place.cc b/paddle/phi/common/place.cc index 644bf3679af..2b5254d3d5f 100644 --- a/paddle/phi/common/place.cc +++ b/paddle/phi/common/place.cc @@ -92,4 +92,20 @@ std::string GetGlobalDeviceType(size_t device_type_id) { return global_registered_device_type[device_type_id]; } +constexpr static int kAllocationTypeBitLength = 8; +constexpr static int kDeviceTypeIDBitLength = 8; +constexpr static int kDeviceIDBitLength = 8; + +uint32_t Place::Hash::operator()(const Place &place) const { + uint32_t hash_value = 0; + // |----31-24------|-----23-16------|-----15-08----|---7-0----| + // | For extension | AllocationType | DeviceTypeID | DeviceID | + hash_value |= (static_cast(place.alloc_type_) + << (kDeviceIDBitLength + kDeviceTypeIDBitLength)); + hash_value |= + (static_cast(place.device_type_id_) << kDeviceIDBitLength); + hash_value |= static_cast(place.device); + return hash_value; +} + } // namespace phi diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index 36fb910cad6..53ddd499a7e 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -73,31 +73,23 @@ class Place { std::string DebugString() const; + struct Hash { + // Note: Now the number of bits we need does not exceed 32 bits, so there is + // no need to use 64 bits. If needed in the future, it can be expanded, + // but now we don’t over-design. + uint32_t operator()(const Place& place) const; + }; + + uint32_t HashValue() const { return Hash()(*this); } + inline bool operator==(const Place& rhs) const { - if (alloc_type_ != rhs.GetType()) { - return false; - } - if (alloc_type_ == AllocationType::CPU || - alloc_type_ == AllocationType::GPUPINNED || - alloc_type_ == AllocationType::NPUPINNED) { - return true; - } - if (alloc_type_ == AllocationType::CUSTOM) { - return device_type_id_ == rhs.device_type_id_ && - device == rhs.GetDeviceId(); - } - return device == rhs.GetDeviceId(); + return HashValue() == rhs.HashValue(); + } + inline bool operator!=(const Place& rhs) const { + return HashValue() != rhs.HashValue(); } - inline bool operator!=(const Place& rhs) const { return !(*this == rhs); } inline bool operator<(const Place& rhs) const { - if (alloc_type_ != rhs.GetType()) { - return static_cast(alloc_type_) < static_cast(rhs.GetType()); - } - if (alloc_type_ == AllocationType::CUSTOM && - device_type_id_ != rhs.device_type_id_) { - return device_type_id_ < rhs.device_type_id_; - } - return device < rhs.GetDeviceId(); + return HashValue() < rhs.HashValue(); } public: @@ -206,3 +198,10 @@ class CustomPlace : public Place { std::ostream& operator<<(std::ostream&, const Place&); } // namespace phi + +namespace paddle { +namespace experimental { +using AllocationType = phi::AllocationType; +using Place = phi::Place; +} // namespace experimental +} // namespace paddle -- GitLab