未验证 提交 0e1191f4 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Add phi device context pool (#40635)

* add phi device context pool

* change year

* fix compile error

* fix operator = error

* refine init impl

* polish details

* refine init impl
上级 276017bb
...@@ -117,7 +117,7 @@ endif() ...@@ -117,7 +117,7 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# seperate init from device_context to avoid cycle dependencies # seperate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context custom_kernel) cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool)
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
......
...@@ -916,6 +916,11 @@ class DeviceContextPool { ...@@ -916,6 +916,11 @@ class DeviceContextPool {
size_t size() const { return device_contexts_.size(); } size_t size() const { return device_contexts_.size(); }
const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
device_contexts() const {
return device_contexts_;
}
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
namespace phi {
class DeviceContext;
class CPUContext;
class GPUContext;
} // namespace phi
namespace paddle {
namespace experimental {
template <AllocationType T>
struct DefaultDeviceContextType;
template <>
struct DefaultDeviceContextType<AllocationType::CPU> {
using TYPE = phi::CPUContext;
};
template <>
struct DefaultDeviceContextType<AllocationType::GPU> {
using TYPE = phi::GPUContext;
};
/**
* The DeviceContextPool here is just a mirror of the DeviceContextPool in
* fluid, and does not manage the life cycle of the DeviceContext.
* It is mainly used for external custom operator calls and high-performance
* C++ APIs.
*
* Since DeviceContextPool in fluid is a global singleton, it always exists
* in program running, so DeviceContextPool here can always access the correct
* DeviceContext pointer.
*
* In order not to depend on the fluid's DeviceContextPool,
* the DeviceContextPool here needs to be initialized in the fluid, and cannot
* be initialized by itself.
*/
class DeviceContextPool {
public:
static DeviceContextPool& Instance();
const phi::DeviceContext* Get(const Place& place) const;
phi::DeviceContext* GetMutable(const Place& place);
template <AllocationType T>
const typename DefaultDeviceContextType<T>::TYPE* Get(
const Place& place) const {
return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
Get(place));
}
private:
DeviceContextPool();
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
context_map_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
} // namespace experimental
} // namespace paddle
...@@ -135,8 +135,9 @@ add_custom_command( ...@@ -135,8 +135,9 @@ add_custom_command(
cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw) cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi) cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi)
cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place)
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory) cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool)
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor) cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform) cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace experimental {
DeviceContextPool& DeviceContextPool::Instance() {
static DeviceContextPool g_device_context_pool;
return g_device_context_pool;
}
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const {
auto it = context_map_.find(place);
PADDLE_ENFORCE_NE(
it,
context_map_.end(),
phi::errors::NotFound("The DeviceContext of %s does not exists.", place));
return it->second;
}
phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
return const_cast<phi::DeviceContext*>(Get(place));
}
DeviceContextPool::DeviceContextPool() {
// We need to make sure that the correct value exists
// whenever we get the DeviceContext from DeviceContextPool
const auto& device_contexts =
paddle::platform::DeviceContextPool::Instance().device_contexts();
for (const auto& pair : device_contexts) {
// only get CPU and GPU DeviceContext now, add other DeviceContext type
// later if needed
if (platform::is_cpu_place(pair.first)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
||
platform::is_gpu_place(pair.first)) {
#else
) {
#endif
const phi::DeviceContext* dev_ctx = pair.second.get().get();
VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", "
<< dev_ctx << "}";
context_map_[pair.first] = dev_ctx;
}
}
}
} // namespace experimental
} // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
namespace paddle { namespace paddle {
...@@ -52,8 +53,8 @@ std::size_t CountLeadingZeros(uint64_t val) { ...@@ -52,8 +53,8 @@ std::size_t CountLeadingZeros(uint64_t val) {
} // namespace detail } // namespace detail
phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend) { phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend) {
auto& pool = paddle::platform::DeviceContextPool::Instance(); auto& pool = paddle::experimental::DeviceContextPool::Instance();
return pool.Get(phi::TransToPhiPlace(backend)); return pool.GetMutable(phi::TransToPhiPlace(backend));
} }
DataType ParseDataType(DataType dtype) { return dtype; } DataType ParseDataType(DataType dtype) { return dtype; }
......
...@@ -92,4 +92,20 @@ std::string GetGlobalDeviceType(size_t device_type_id) { ...@@ -92,4 +92,20 @@ std::string GetGlobalDeviceType(size_t device_type_id) {
return global_registered_device_type[device_type_id]; return global_registered_device_type[device_type_id];
} }
constexpr static int kAllocationTypeBitLength = 8;
constexpr static int kDeviceTypeIDBitLength = 8;
constexpr static int kDeviceIDBitLength = 8;
uint32_t Place::Hash::operator()(const Place &place) const {
uint32_t hash_value = 0;
// |----31-24------|-----23-16------|-----15-08----|---7-0----|
// | For extension | AllocationType | DeviceTypeID | DeviceID |
hash_value |= (static_cast<uint8_t>(place.alloc_type_)
<< (kDeviceIDBitLength + kDeviceTypeIDBitLength));
hash_value |=
(static_cast<uint8_t>(place.device_type_id_) << kDeviceIDBitLength);
hash_value |= static_cast<uint8_t>(place.device);
return hash_value;
}
} // namespace phi } // namespace phi
...@@ -73,31 +73,23 @@ class Place { ...@@ -73,31 +73,23 @@ class Place {
std::string DebugString() const; std::string DebugString() const;
struct Hash {
// Note: Now the number of bits we need does not exceed 32 bits, so there is
// no need to use 64 bits. If needed in the future, it can be expanded,
// but now we don’t over-design.
uint32_t operator()(const Place& place) const;
};
uint32_t HashValue() const { return Hash()(*this); }
inline bool operator==(const Place& rhs) const { inline bool operator==(const Place& rhs) const {
if (alloc_type_ != rhs.GetType()) { return HashValue() == rhs.HashValue();
return false; }
} inline bool operator!=(const Place& rhs) const {
if (alloc_type_ == AllocationType::CPU || return HashValue() != rhs.HashValue();
alloc_type_ == AllocationType::GPUPINNED ||
alloc_type_ == AllocationType::NPUPINNED) {
return true;
}
if (alloc_type_ == AllocationType::CUSTOM) {
return device_type_id_ == rhs.device_type_id_ &&
device == rhs.GetDeviceId();
}
return device == rhs.GetDeviceId();
} }
inline bool operator!=(const Place& rhs) const { return !(*this == rhs); }
inline bool operator<(const Place& rhs) const { inline bool operator<(const Place& rhs) const {
if (alloc_type_ != rhs.GetType()) { return HashValue() < rhs.HashValue();
return static_cast<int>(alloc_type_) < static_cast<int>(rhs.GetType());
}
if (alloc_type_ == AllocationType::CUSTOM &&
device_type_id_ != rhs.device_type_id_) {
return device_type_id_ < rhs.device_type_id_;
}
return device < rhs.GetDeviceId();
} }
public: public:
...@@ -206,3 +198,10 @@ class CustomPlace : public Place { ...@@ -206,3 +198,10 @@ class CustomPlace : public Place {
std::ostream& operator<<(std::ostream&, const Place&); std::ostream& operator<<(std::ostream&, const Place&);
} // namespace phi } // namespace phi
namespace paddle {
namespace experimental {
using AllocationType = phi::AllocationType;
using Place = phi::Place;
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册