未验证 提交 67f2c9f7 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix context pool sync init (#40787)

* fix context pool sync init

* add lock for insert
上级 10bab9f1
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <mutex>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"
......@@ -58,21 +60,22 @@ class DeviceContextPool {
public:
static DeviceContextPool& Instance();
const phi::DeviceContext* Get(const Place& place) const;
const phi::DeviceContext* Get(const Place& place);
phi::DeviceContext* GetMutable(const Place& place);
template <AllocationType T>
const typename DefaultDeviceContextType<T>::TYPE* Get(
const Place& place) const {
const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
Get(place));
}
private:
DeviceContextPool();
DeviceContextPool() = default;
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
context_map_;
std::mutex mutex_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
......
......@@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() {
return g_device_context_pool;
}
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const {
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) {
auto it = context_map_.find(place);
PADDLE_ENFORCE_NE(
it,
context_map_.end(),
phi::errors::NotFound("The DeviceContext of %s does not exists.", place));
if (it == context_map_.end()) {
// only when we need the specific DeviceContext, get and cache it
auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
{
std::lock_guard<std::mutex> lock(mutex_);
context_map_[place] = dev_ctx;
}
return dev_ctx;
}
return it->second;
}
......@@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
return const_cast<phi::DeviceContext*>(Get(place));
}
DeviceContextPool::DeviceContextPool() {
// We need to make sure that the correct value exists
// whenever we get the DeviceContext from DeviceContextPool
const auto& device_contexts =
paddle::platform::DeviceContextPool::Instance().device_contexts();
for (const auto& pair : device_contexts) {
// only get CPU and GPU DeviceContext now, add other DeviceContext type
// later if needed
if (platform::is_cpu_place(pair.first)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
||
platform::is_gpu_place(pair.first)) {
#else
) {
#endif
const phi::DeviceContext* dev_ctx = pair.second.get().get();
VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", "
<< dev_ctx << "}";
context_map_[pair.first] = dev_ctx;
}
}
}
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册