/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/context_pool.h"

#include <future>
#include <map>
#include <memory>
#include <vector>

#include "glog/logging.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/enforce.h"

// NOTE(review): the template argument lists used below
// (std::vector<phi::Place> and
// std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>) were
// reconstructed from how the values are used in this file; confirm they match
// the declarations in context_pool.h.

namespace phi {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Global switch reported by AllowTF32Cublas(); starts out enabled.
bool allow_tf32_cublas = true;

/*! \brief Overwrite the cuBLAS TF32 switch with `active`. */
void SetAllowTF32Cublas(bool active) { allow_tf32_cublas = active; }

/*! \brief Return the current value of the cuBLAS TF32 switch. */
bool AllowTF32Cublas() { return allow_tf32_cublas; }

// Global switch reported by AllowTF32Cudnn(); starts out enabled.
bool allow_tf32_cudnn = true;

/*! \brief Overwrite the cuDNN TF32 switch with `active`. */
void SetAllowTF32Cudnn(bool active) { allow_tf32_cudnn = active; }

/*! \brief Return the current value of the cuDNN TF32 switch. */
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
#endif  // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

// The pool handed out by Instance(). Starts out null; assigned by Init() and
// SetPool().
static DeviceContextPool* pool = nullptr;

/*!
 * \brief Return the global pool.
 *
 * Enforces that a pool has already been installed; otherwise
 * PADDLE_ENFORCE_NOT_NULL reports a PreconditionNotMet error.
 */
DeviceContextPool& DeviceContextPool::Instance() {
  PADDLE_ENFORCE_NOT_NULL(pool,
                          phi::errors::PreconditionNotMet(
                              "Need to Create DeviceContextPool firstly!"));
  return *pool;
}

/*!
 * \brief Create the global pool for `places` if none exists yet.
 *
 * When a pool is already installed, `places` is ignored and the existing pool
 * is returned. The check-then-create sequence is not synchronized here.
 */
DeviceContextPool& DeviceContextPool::Init(
    const std::vector<phi::Place>& places) {
  if (pool == nullptr) {
    pool = new DeviceContextPool(places);
  }
  return *pool;
}

/*! \brief Report whether a global pool is currently installed. */
bool DeviceContextPool::IsInitialized() { return pool != nullptr; }

/*!
 * \brief Replace the global pool pointer with `dev_pool`.
 *
 * The previous pointer is simply overwritten; it is not deleted here.
 */
void DeviceContextPool::SetPool(DeviceContextPool* dev_pool) {
  pool = dev_pool;
}

// Per-thread override map consulted before the pool's own `device_contexts_`.
// Null means "no override on this thread".
thread_local const std::map<Place,
                            std::shared_future<std::unique_ptr<DeviceContext>>>*
    DeviceContextPool::external_device_contexts_ = nullptr;

/*!
 * \brief Look up the device context registered for `place`.
 *
 * The calling thread's external map takes precedence when it is set and
 * contains `place`; otherwise the pool's own map is searched. If neither has
 * an entry, PADDLE_THROW raises an Unimplemented error.
 *
 * \return Non-owning pointer to the context stored for `place`.
 */
phi::DeviceContext* DeviceContextPool::Get(const phi::Place& place) {
  VLOG(6) << "DeviceContextPool Get: " << place;

  // One find() per map replaces the former count()-then-find() pair.
  if (external_device_contexts_ != nullptr) {
    const auto external_it = external_device_contexts_->find(place);
    if (external_it != external_device_contexts_->end()) {
      // First get() resolves the stored shared_future, second get() exposes
      // the raw pointer held by the unique_ptr.
      return external_it->second.get().get();
    }
  }

  const auto own_it = device_contexts_.find(place);
  if (own_it == device_contexts_.end()) {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option "
        "or check "
        "that your train process set the correct device id if you use "
        "Executor.",
        place));
  }
  return own_it->second.get().get();
}

/*!
 * \brief Number of entries in the active context map.
 *
 * Counts the calling thread's external map when one is set, otherwise the
 * pool's own map.
 */
size_t DeviceContextPool::Size() const {
  if (external_device_contexts_) {
    return external_device_contexts_->size();
  }
  return device_contexts_.size();
}

/*!
 * \brief Return the active context map.
 *
 * Yields the calling thread's external map when one is set, otherwise the
 * pool's own map.
 */
const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
DeviceContextPool::device_contexts() const {
  if (external_device_contexts_) {
    return *external_device_contexts_;
  }
  return device_contexts_;
}

/*!
 * \brief Set the calling thread's external context map.
 *
 * The pointer is stored as-is (the map is not copied); passing nullptr
 * removes the override for this thread.
 */
void DeviceContextPool::SetDeviceContexts(
    const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        dev_ctxs) {
  external_device_contexts_ = dev_ctxs;
}

/*!
 * \brief Build a pool by filling `device_contexts_` for each entry of
 *        `places` via phi::memory_utils::EmplaceDeviceContexts.
 */
DeviceContextPool::DeviceContextPool(const std::vector<phi::Place>& places) {
  phi::memory_utils::EmplaceDeviceContexts(
      &device_contexts_,
      places,
      /*disable_setting_default_stream_for_allocator=*/false,
      /*stream_priority=*/0);
}

}  // namespace phi