Unverified commit 114723c9 authored by Ruibiao Chen, committed by GitHub

Refactor DeviceContextPool (#42901)

* Refactor DeviceContextPool

* Adjust header file order
Parent 6197fbf6
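
The central change in this PR: members that used to be ad-hoc platform::DeviceContextPool instances (the fetch contexts of the SSA graph executors and the D2H/H2D contexts of StreamAnalyzer) become plain std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> containers filled by the new free function platform::EmplaceDeviceContexts(). A context is then fetched with ctxs[place].get().get(): the first get() resolves the shared future (creating the context lazily on first access), the second returns the raw pointer held by the unique_ptr. The following standalone sketch only illustrates that access pattern; FakeContext and the integer place keys are stand-ins, not Paddle types.

// Standalone illustration of the storage and lookup pattern used by this PR.
// FakeContext and the int "place" key are hypothetical stand-ins, not Paddle code.
#include <future>
#include <iostream>
#include <map>
#include <memory>

struct FakeContext {
  explicit FakeContext(int place) : place(place) {
    std::cout << "context for place " << place << " created\n";
  }
  int place;
};

int main() {
  std::map<int, std::shared_future<std::unique_ptr<FakeContext>>> ctxs;
  for (int place = 0; place < 2; ++place) {
    // Deferred launch: nothing is constructed until the first get().
    ctxs.emplace(place, std::async(std::launch::deferred, [place] {
                   return std::make_unique<FakeContext>(place);
                 }));
  }
  // Mirrors fetch_ctxs_[p].get().get() in the diff below:
  // shared_future::get() yields the unique_ptr, unique_ptr::get() the raw pointer.
  FakeContext* ctx = ctxs[1].get().get();
  FakeContext* again = ctxs[1].get().get();
  std::cout << (ctx == again ? "same context reused\n" : "unexpected\n");
  return 0;  // note: the context for place 0 was never created
}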
@@ -39,9 +39,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       local_exec_scopes_(local_exec_scopes),
       places_(places),
       graph_(graph),
-      fetch_ctxs_(places),
       // add one more thread for generate op_deps
       prepare_pool_(1) {
+  platform::EmplaceDeviceContexts(
+      &fetch_ctxs_, places,
+      /*disable_setting_default_stream_for_allocator=*/true);
   if (ir::IsTopologySortOperationsUnique(*graph_)) {
     VLOG(10)
         << "Change thread number to 1 because the toposort order is unique";
@@ -144,7 +146,7 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
       ClearFetchOp(graph_, &fetch_ops);

       for (auto &place : places_) {
-        fetch_ctxs_.Get(place)->Wait();
+        fetch_ctxs_[place].get().get()->Wait();
       }
     }
@@ -195,7 +197,7 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
     fetch_ops->emplace_back(op);

     for (auto &p : places_) {
-      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
+      op->SetDeviceContext(p, fetch_ctxs_[p].get().get());
     }
     for (auto *var : vars) {
...
@@ -54,7 +54,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::unordered_map<OpHandleBase *, int> op_deps_;
   std::vector<OpHandleBase *> bootstrap_ops_;
-  platform::DeviceContextPool fetch_ctxs_;
+  std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>
+      fetch_ctxs_;
   std::atomic<int> remaining_;
   std::future<
...
@@ -32,11 +32,14 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       local_scopes_(local_scopes),
       local_exec_scopes_(local_exec_scopes),
       places_(places),
-      fetch_ctxs_(places),
       strategy_(strategy),
       prepare_pool_(1),
       pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
                                        : nullptr) {
+  platform::EmplaceDeviceContexts(
+      &fetch_ctxs_, places,
+      /*disable_setting_default_stream_for_allocator=*/true);
   if (strategy_.num_iteration_per_run_ > 1) {
     int read_op_num = 0;
     for (auto *node : graph_->Nodes()) {
@@ -207,7 +210,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     fetch_ops->emplace_back(op);

     for (auto &p : places_) {
-      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
+      op->SetDeviceContext(p, fetch_ctxs_[p].get().get());
     }
     for (auto *var : vars) {
...
@@ -77,7 +77,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_exec_scopes_;
   std::vector<platform::Place> places_;
-  platform::DeviceContextPool fetch_ctxs_;
+  std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>
+      fetch_ctxs_;
   ExceptionHolder exception_holder_;
   std::unique_ptr<OpDependentData> op_deps_;
   std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
...
@@ -14,11 +14,31 @@
 #include "paddle/fluid/framework/new_executor/stream_analyzer.h"

+#include <future>
 #include <unordered_set>

+#include "paddle/fluid/platform/device_context.h"
+
 namespace paddle {
 namespace framework {

+StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) {
+  if (platform::is_gpu_place(place)) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    platform::EmplaceDeviceContexts(
+        &d2h_ctxs_, {place},
+        /*disable_setting_default_stream_for_allocator=*/true);
+    platform::EmplaceDeviceContexts(
+        &h2d_ctxs_, {place},
+        /*disable_setting_default_stream_for_allocator=*/true);
+#else
+    PADDLE_THROW(
+        platform::errors::Unimplemented("CUDAPlace is not supported. Please "
+                                        "re-compile with WITH_GPU option."));
+#endif
+  }
+}
+
 /*
  * Parse the var_ids that need to be associated with an event.
  * The caller should guarantee front_op and back_op satisfy the

@@ -137,10 +157,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
   auto* dev_ctx = op_func_node.dev_ctx_;
   if (op_type == interpreter::kMemcpyD2H) {
     VLOG(3) << "Get dev_ctx from d2h_context_pool_";
-    dev_ctx = d2h_ctx_pool_.Get(place_);
+    dev_ctx = d2h_ctxs_[place_].get().get();
   } else if (op_type == interpreter::kMemcpyH2D) {
     VLOG(3) << "Get dev_ctx from h2d_context_pool_";
-    dev_ctx = h2d_ctx_pool_.Get(place_);
+    dev_ctx = h2d_ctxs_[place_].get().get();
   }
   return dev_ctx;
...
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <future>
 #include <memory>
 #include <vector>

@@ -25,15 +26,17 @@ namespace framework {

 class StreamAnalyzer {
  public:
-  explicit StreamAnalyzer(const platform::Place& place)
-      : place_(place), d2h_ctx_pool_({place}), h2d_ctx_pool_({place}) {}
+  using Place = platform::Place;
+  using DeviceContext = platform::DeviceContext;
+
+  explicit StreamAnalyzer(const Place& place);

   ~StreamAnalyzer() {}

   void Schedule(const std::vector<size_t>& downstream_ops,
                 std::vector<Instruction>* instructions, size_t op_index);

-  platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
+  DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);

  private:
   std::vector<size_t> GetNeedEventVarIds(const Instruction& cur_instr,

@@ -42,16 +45,16 @@ class StreamAnalyzer {
   void ConstructEventForVar(const std::vector<size_t>& new_event_var_id,
                             Instruction* next_instr,
                             platform::DeviceType waiter_type,
-                            const platform::Place& place);
+                            const Place& place);

   bool IsDirectRun(Instruction& cur_instr,  // NOLINT
                    const Instruction& next_instr);

   platform::DeviceType GetWaiterType(const Instruction& instr);

-  platform::Place place_;
-  platform::DeviceContextPool d2h_ctx_pool_;
-  platform::DeviceContextPool h2d_ctx_pool_;
+  Place place_;
+  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> d2h_ctxs_;
+  std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> h2d_ctxs_;
   std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
 };
...
@@ -419,21 +419,12 @@ class AllocatorFacadePrivate {
     const std::shared_ptr<StreamSafeCUDAAllocator>& allocator =
         GetDefaultStreamSafeCUDAAllocator(place);

-    // NOTE(Ruibiao): The default stream will be set when the CUDADeviceContext
-    // created. Normally, the DeviceContextPool is a global singleton and one
-    // Place only correspond to one DeviceContext. However, to support
-    // multi-stream scheduling, standalone executor creates two extra
-    // DeviceContextPools for H2D and D2H stream in StreamAnalyzer, which make
-    // one Place correspond to multiple DeviceContext and unexpectedly reset the
-    // default stream in runtime. To avoid this behavior, we do not allow
-    // changing default stream after initially setting.
-    if (allocator->GetDefaultStream() != nullptr) {
-      VLOG(5) << "The default stream for StreamSafeCUDAAllocator("
-              << allocator.get() << ") in " << place << " has been set to "
-              << allocator->GetDefaultStream()
-              << " before, not allow to change now.";
-      return;
-    }
+    PADDLE_ENFORCE_EQ(
+        allocator->GetDefaultStream(), nullptr,
+        platform::errors::Unavailable(
+            "The default stream for StreamSafeCUDAAllocator(%p) in %s has been "
+            "set to %p, not allow to change it to %p.",
+            allocator.get(), place, allocator->GetDefaultStream(), stream));

     allocator->SetDefaultStream(stream);
     VLOG(8) << "Set default stream to " << stream
...
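
For context, the NOTE removed above explained why rebinding is dangerous: the extra fetch and D2H/H2D contexts created by the standalone executor must not reset the default stream that the main context already registered with the allocator. With this PR those extra contexts are created with disable_setting_default_stream_for_allocator=true, and a second SetDefaultStream call becomes a hard error instead of a silently ignored request. Below is a minimal standalone analogue of that "set once" invariant; StreamBoundAllocator and StreamHandle are made-up names, not Paddle APIs.

// Toy analogue of the invariant enforced by the PADDLE_ENFORCE_EQ above:
// the default stream may be bound exactly once; rebinding is an error.
#include <iostream>
#include <stdexcept>

using StreamHandle = const void*;  // stand-in for gpuStream_t

class StreamBoundAllocator {
 public:
  void SetDefaultStream(StreamHandle stream) {
    if (default_stream_ != nullptr) {
      throw std::runtime_error("default stream already set, refusing to change it");
    }
    default_stream_ = stream;
  }
  StreamHandle GetDefaultStream() const { return default_stream_; }

 private:
  StreamHandle default_stream_ = nullptr;
};

int main() {
  StreamBoundAllocator allocator;
  int dummy_stream = 0;
  allocator.SetDefaultStream(&dummy_stream);  // first binding succeeds
  try {
    allocator.SetDefaultStream(&dummy_stream);  // second binding is rejected
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}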
@@ -11,13 +11,21 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/platform/device_context.h"

 #include <functional>
 #include <memory>
 #include <set>
+
+#include "glog/logging.h"
+#include "paddle/fluid/framework/expect.h"
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/memory/allocation/allocator_facade.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/platform/profiler.h"
+#include "paddle/fluid/platform/profiler/event_tracing.h"
 #include "paddle/fluid/platform/stream/cuda_stream.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/allocator.h"

@@ -26,17 +34,11 @@ limitations under the License. */
 #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #endif
+
 #ifdef PADDLE_WITH_MLU
 #include "paddle/fluid/platform/device/mlu/device_context.h"
 #include "paddle/fluid/platform/device/mlu/device_context_allocator.h"
 #endif
-#include "glog/logging.h"
-#include "paddle/fluid/framework/expect.h"
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/memory/allocation/allocator_facade.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-#include "paddle/fluid/platform/profiler.h"
-#include "paddle/fluid/platform/profiler/event_tracing.h"

 namespace paddle {
 namespace memory {
@@ -178,75 +180,89 @@ void DeviceContextPool::SetDeviceContexts(
 }

 template <typename DevCtx>
-inline void EmplaceDeviceContext(
-    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
-        map_ptr,
-    platform::Place p) {
+std::unique_ptr<DeviceContext> CreateDeviceContext(
+    const platform::Place& p,
+    bool disable_setting_default_stream_for_allocator = false) {
   using PtrType = std::unique_ptr<DeviceContext>;
-  map_ptr->emplace(
-      p, std::async(std::launch::deferred, [=] {
-        // lazy evaluation. i.e., only create device context at
-        // first `Get`
-        auto* dev_ctx = new DevCtx(p);
-        if (is_gpu_place(p)) {
+  auto* dev_ctx = new DevCtx(p);
+  if (is_gpu_place(p)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-          auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
-          PADDLE_ENFORCE_NOT_NULL(
-              cuda_ctx,
-              platform::errors::InvalidArgument(
-                  "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
-          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
-                                    .GetAllocator(p)
-                                    .get());
-          dev_ctx->SetPinnedAllocator(
-              memory::allocation::AllocatorFacade::Instance()
-                  .GetAllocator(paddle::platform::CUDAPinnedPlace())
-                  .get());
-
-          cuda_ctx->PartialInitWithAllocator();
-          dev_ctx->SetGenerator(
-              framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
+    auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
+    PADDLE_ENFORCE_NOT_NULL(
+        cuda_ctx,
+        platform::errors::InvalidArgument(
+            "Failed to dynamic_cast dev_ctx into CUDADeviceContext."));
+
+    auto& instance = memory::allocation::AllocatorFacade::Instance();
+    if (!disable_setting_default_stream_for_allocator) {
+      instance.SetDefaultStream(CUDAPlace(p.GetDeviceId()), cuda_ctx->stream());
+    }
+    dev_ctx->SetAllocator(instance.GetAllocator(p).get());
+    dev_ctx->SetPinnedAllocator(
+        instance.GetAllocator(paddle::platform::CUDAPinnedPlace()).get());
+
+    cuda_ctx->PartialInitWithAllocator();
+    dev_ctx->SetGenerator(
+        framework::DefaultCUDAGenerator(p.GetDeviceId()).get());
 #endif
-        } else {
-          dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
-                                    .GetAllocator(p)
-                                    .get());
-          dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
-        }
-        dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
-        dev_ctx->SetHostAllocator(
-            memory::allocation::AllocatorFacade::Instance()
-                .GetAllocator(platform::CPUPlace())
-                .get());
-        dev_ctx->SetZeroAllocator(
-            memory::allocation::AllocatorFacade::Instance()
-                .GetZeroAllocator(p)
-                .get());
-        return PtrType(dev_ctx);
-      }));
+  } else {
+    dev_ctx->SetAllocator(
+        memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get());
+    dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
+  }
+  dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get());
+  dev_ctx->SetHostAllocator(memory::allocation::AllocatorFacade::Instance()
+                                .GetAllocator(platform::CPUPlace())
+                                .get());
+  dev_ctx->SetZeroAllocator(memory::allocation::AllocatorFacade::Instance()
+                                .GetZeroAllocator(p)
+                                .get());
+  return PtrType(dev_ctx);
 }

-DeviceContextPool::DeviceContextPool(
-    const std::vector<platform::Place>& places) {
+template <typename DevCtx>
+inline void EmplaceDeviceContext(
+    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
+        place_to_device_context,
+    platform::Place place, bool disable_setting_default_stream_for_allocator) {
+  // lazy evaluation. i.e., only create device context at first `Get`
+  place_to_device_context->emplace(
+      place, std::async(std::launch::deferred, CreateDeviceContext<DevCtx>,
+                        place, disable_setting_default_stream_for_allocator));
+}
+
+void EmplaceDeviceContexts(
+    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
+        place_to_device_context,
+    const std::vector<platform::Place>& places,
+    bool disable_setting_default_stream_for_allocator) {
   PADDLE_ENFORCE_GT(
       places.size(), 0,
       platform::errors::InvalidArgument("The number of platform places should "
                                         "be larger than 0. But received %d.",
                                         places.size()));
+
   std::set<Place> set;
   for (auto& p : places) {
     set.insert(p);
   }
+
   for (auto& p : set) {
     if (platform::is_cpu_place(p)) {
 #ifdef PADDLE_WITH_MKLDNN
-      EmplaceDeviceContext<MKLDNNDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<MKLDNNDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
-      EmplaceDeviceContext<CPUDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<CPUDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #endif
     } else if (platform::is_gpu_place(p)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      EmplaceDeviceContext<CUDADeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<CUDADeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(
           platform::errors::Unimplemented("CUDAPlace is not supported. Please "
@@ -254,7 +270,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_cuda_pinned_place(p)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      EmplaceDeviceContext<CUDAPinnedDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<CUDAPinnedDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(platform::errors::Unimplemented(
           "CUDAPlace is not supported. Please re-compile with WITH_GPU "
@@ -262,7 +280,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_xpu_place(p)) {
 #ifdef PADDLE_WITH_XPU
-      EmplaceDeviceContext<XPUDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<XPUDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(
           platform::errors::Unimplemented("XPUPlace is not supported. Please "
@@ -270,7 +290,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_mlu_place(p)) {
 #ifdef PADDLE_WITH_MLU
-      EmplaceDeviceContext<MLUDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<MLUDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(
           platform::errors::Unimplemented("MLUPlace is not supported. Please "
@@ -278,7 +300,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_ipu_place(p)) {
 #ifdef PADDLE_WITH_IPU
-      EmplaceDeviceContext<IPUDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<IPUDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(
           platform::errors::Unimplemented("IPUPlace is not supported. Please "
@@ -286,7 +310,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_npu_place(p)) {
 #ifdef PADDLE_WITH_ASCEND_CL
-      EmplaceDeviceContext<NPUDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<NPUDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(platform::errors::Unimplemented(
           "NPUPlace is not supported. Please "
@@ -294,7 +320,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_npu_pinned_place(p)) {
 #ifdef PADDLE_WITH_ASCEND_CL
-      EmplaceDeviceContext<NPUPinnedDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<NPUPinnedDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(platform::errors::Unimplemented(
           "NPUPinnedPlace is not supported. Please re-compile with "
@@ -303,7 +331,9 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_custom_place(p)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-      EmplaceDeviceContext<CustomDeviceContext>(&device_contexts_, p);
+      EmplaceDeviceContext<CustomDeviceContext>(
+          place_to_device_context, p,
+          disable_setting_default_stream_for_allocator);
 #else
       PADDLE_THROW(platform::errors::Unimplemented(
           "CustomPlace is not supported. Please re-compile with "
@@ -314,6 +344,12 @@ DeviceContextPool::DeviceContextPool(
   }
 }

+DeviceContextPool::DeviceContextPool(
+    const std::vector<platform::Place>& places) {
+  EmplaceDeviceContexts(&device_contexts_, places,
+                        /*disable_setting_default_stream_for_allocator=*/false);
+}
+
 CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() {
   phi::CPUContext::Init();
 }
@@ -556,10 +592,6 @@ CUDAContext::~CUDAContext() {
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
   phi::GPUContext::PartialInitWithoutAllocator();
   cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
-  auto& instance = memory::allocation::AllocatorFacade::Instance();
-  instance.SetDefaultStream(place, phi::GPUContext::stream());
-  workspace_.reset(new phi::DnnWorkspaceHandle(
-      instance.GetAllocator(place).get(), stream()));
 }

 CUDADeviceContext::~CUDADeviceContext() = default;
...
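
Taken together, the device_context.cc changes split context construction into three layers: CreateDeviceContext<DevCtx> builds and wires up one context (registering its stream as the allocator's default only when the new flag is false), EmplaceDeviceContext<DevCtx> wraps that construction in a deferred std::async so it still happens lazily on first access, and EmplaceDeviceContexts de-duplicates the place list and dispatches on the place kind; the DeviceContextPool constructor is now just a thin wrapper passing false for the flag. A condensed standalone sketch of that layering, with hypothetical stand-ins (GenericContext, CpuCtx, GpuCtx, MakeContextMap, an int Place) rather than the Paddle API:

// Condensed, standalone sketch of the factory layering introduced above.
// All names here are hypothetical stand-ins, not the Paddle API.
#include <future>
#include <map>
#include <memory>
#include <set>
#include <vector>

struct GenericContext { virtual ~GenericContext() = default; };
struct CpuCtx : GenericContext { explicit CpuCtx(int) {} };
struct GpuCtx : GenericContext { explicit GpuCtx(int) {} };

using Place = int;  // stand-in for platform::Place
using CtxMap = std::map<Place, std::shared_future<std::unique_ptr<GenericContext>>>;

// Layer 1: build and configure one context.
template <typename DevCtx>
std::unique_ptr<GenericContext> CreateContext(Place p, bool disable_default_stream) {
  auto ctx = std::make_unique<DevCtx>(p);
  // ... allocator / generator wiring would go here; the default stream is only
  // registered when disable_default_stream is false ...
  (void)disable_default_stream;
  return ctx;
}

// Layer 2: defer layer 1 until the first get() on the shared future.
template <typename DevCtx>
void EmplaceContext(CtxMap* m, Place p, bool disable_default_stream) {
  m->emplace(p, std::async(std::launch::deferred, CreateContext<DevCtx>, p,
                           disable_default_stream));
}

// Layer 3: de-duplicate places and dispatch on the place kind.
CtxMap MakeContextMap(const std::vector<Place>& places, bool disable_default_stream) {
  CtxMap m;
  for (Place p : std::set<Place>(places.begin(), places.end())) {
    if (p < 0) {
      EmplaceContext<CpuCtx>(&m, p, disable_default_stream);
    } else {
      EmplaceContext<GpuCtx>(&m, p, disable_default_stream);
    }
  }
  return m;
}

int main() {
  CtxMap ctxs = MakeContextMap({-1, 3, 3}, /*disable_default_stream=*/true);
  GenericContext* gpu = ctxs[3].get().get();  // built lazily here, duplicates collapsed
  return gpu != nullptr ? 0 : 1;
}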
@@ -645,7 +645,6 @@ class CUDADeviceContext : public phi::GPUContext {
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   std::unique_ptr<stream::CUDAStream> cuda_stream_;
-  std::unique_ptr<phi::DnnWorkspaceHandle> workspace_{nullptr};

   DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
@@ -883,11 +882,15 @@ struct DefaultDeviceContextType<platform::CustomPlace> {
 };
 #endif

+void EmplaceDeviceContexts(
+    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
+        place_to_device_context,
+    const std::vector<platform::Place>& places,
+    bool disable_setting_default_stream_for_allocator);
+
 /*! \brief device context pool singleton */
 class DeviceContextPool {
  public:
-  explicit DeviceContextPool(const std::vector<platform::Place>& places);
-
   static DeviceContextPool& Instance() {
     PADDLE_ENFORCE_NOT_NULL(pool,
                             platform::errors::PreconditionNotMet(
@@ -925,6 +928,8 @@ class DeviceContextPool {
           std::shared_future<std::unique_ptr<DeviceContext>>>*);

  private:
+  explicit DeviceContextPool(const std::vector<platform::Place>& places);
+
   static DeviceContextPool* pool;
   std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
       device_contexts_;
...