未验证 提交 67f2c9f7 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix context pool sync init (#40787)

* fix context pool sync init

* add lock for insert
上级 10bab9f1
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <mutex>
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
...@@ -58,21 +60,22 @@ class DeviceContextPool { ...@@ -58,21 +60,22 @@ class DeviceContextPool {
public: public:
static DeviceContextPool& Instance(); static DeviceContextPool& Instance();
const phi::DeviceContext* Get(const Place& place) const; const phi::DeviceContext* Get(const Place& place);
phi::DeviceContext* GetMutable(const Place& place); phi::DeviceContext* GetMutable(const Place& place);
template <AllocationType T> template <AllocationType T>
const typename DefaultDeviceContextType<T>::TYPE* Get( const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
const Place& place) const {
return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>( return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
Get(place)); Get(place));
} }
private: private:
DeviceContextPool(); DeviceContextPool() = default;
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash> paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
context_map_; context_map_;
std::mutex mutex_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool); DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
}; };
......
...@@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() { ...@@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() {
return g_device_context_pool; return g_device_context_pool;
} }
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { const phi::DeviceContext* DeviceContextPool::Get(const Place& place) {
auto it = context_map_.find(place); auto it = context_map_.find(place);
PADDLE_ENFORCE_NE( if (it == context_map_.end()) {
it, // only when we need the specific DeviceContext, get and cache it
context_map_.end(), auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); {
std::lock_guard<std::mutex> lock(mutex_);
context_map_[place] = dev_ctx;
}
return dev_ctx;
}
return it->second; return it->second;
} }
...@@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { ...@@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
return const_cast<phi::DeviceContext*>(Get(place)); return const_cast<phi::DeviceContext*>(Get(place));
} }
DeviceContextPool::DeviceContextPool() {
// We need to make sure that the correct value exists
// whenever we get the DeviceContext from DeviceContextPool
const auto& device_contexts =
paddle::platform::DeviceContextPool::Instance().device_contexts();
for (const auto& pair : device_contexts) {
// only get CPU and GPU DeviceContext now, add other DeviceContext type
// later if needed
if (platform::is_cpu_place(pair.first)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
||
platform::is_gpu_place(pair.first)) {
#else
) {
#endif
const phi::DeviceContext* dev_ctx = pair.second.get().get();
VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", "
<< dev_ctx << "}";
context_map_[pair.first] = dev_ctx;
}
}
}
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册