未验证 提交 90d9e5ae 编写于 作者: Y Yu Yang 提交者: GitHub

feat(platform): lazy initialization of devicecontext in pool (#14067)

* feat(platform): lazy initialization of devicecontext in pool

Use std::async(std::launch::deferred, [=]{...}) to lazily initialize each DeviceContext in the Pool; the context is only constructed on the first `Get` for its place.

test=develop

* Add the <future> include required by std::async / std::shared_future

test=develop
上级 ab4351fe
...@@ -303,10 +303,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -303,10 +303,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
const auto dev_ctxs = for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().GetAllDeviceContexts(); platform::DeviceContextPool::Instance().Get(p)->Wait();
for (auto &dev_ctx : dev_ctxs) {
dev_ctx->Wait();
} }
if (member_->own_local_scope_) { if (member_->own_local_scope_) {
......
...@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { ...@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
"'Place' is not supported, Please re-compile with WITH_GPU " "'Place' is not supported, Please re-compile with WITH_GPU "
"option"); "option");
} }
return it->second.get(); return it->second.get().get();
} }
const std::vector<const DeviceContext*> template <typename DevCtx, typename PlaceType>
DeviceContextPool::GetAllDeviceContexts() const { inline void EmplaceDeviceContext(
std::vector<const DeviceContext*> all_device_ctx; std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
all_device_ctx.reserve(device_contexts_.size()); map_ptr,
for (auto& dev_ctx : device_contexts_) { platform::Place p) {
all_device_ctx.emplace_back(dev_ctx.second.get()); using PtrType = std::unique_ptr<DeviceContext>;
} map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
return all_device_ctx; // lazy evaluation. i.e., only create device context at
// first `Get`
return PtrType(new DevCtx(boost::get<PlaceType>(p)));
}));
} }
DeviceContextPool::DeviceContextPool( DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) { const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0); PADDLE_ENFORCE_GT(places.size(), 0);
using PtrType = std::unique_ptr<DeviceContext>;
std::set<Place> set; std::set<Place> set;
for (auto& p : places) { for (auto& p : places) {
set.insert(p); set.insert(p);
...@@ -57,16 +59,13 @@ DeviceContextPool::DeviceContextPool( ...@@ -57,16 +59,13 @@ DeviceContextPool::DeviceContextPool(
for (auto& p : set) { for (auto& p : set) {
if (platform::is_cpu_place(p)) { if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
device_contexts_.emplace( EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else #else
device_contexts_.emplace( EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif #endif
} else if (platform::is_gpu_place(p)) { } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
device_contexts_.emplace( EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else #else
PADDLE_THROW( PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
...@@ -74,9 +73,8 @@ DeviceContextPool::DeviceContextPool( ...@@ -74,9 +73,8 @@ DeviceContextPool::DeviceContextPool(
#endif #endif
} else if (platform::is_cuda_pinned_place(p)) { } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
device_contexts_.emplace( EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
p, &device_contexts_, p);
PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
#else #else
PADDLE_THROW( PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
......
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <future> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
...@@ -223,9 +224,6 @@ class DeviceContextPool { ...@@ -223,9 +224,6 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */ /*! \brief Return handle of single device context. */
platform::DeviceContext* Get(const platform::Place& place); platform::DeviceContext* Get(const platform::Place& place);
/*! \brief Return all the device contexts. */
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
template <typename Place> template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace( const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
const Place& place) { const Place& place) {
...@@ -237,7 +235,8 @@ class DeviceContextPool { ...@@ -237,7 +235,8 @@ class DeviceContextPool {
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_; std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool); DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册