diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc index 1cd6c3bb3f7458cdb512d53c52fcfecef43dce91..5d0fb95d8ea44272bc9c3275237256280e6a7876 100644 --- a/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc @@ -52,12 +52,20 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( } } -CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( - const paddle::Place& place) { +std::unordered_map< + std::string, + std::vector>>& +CustomDeviceStreamResourcePool::GetMap() { static std::unordered_map< std::string, std::vector>> pool; + return pool; +} + +CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( + const paddle::Place& place) { + auto& pool = GetMap(); PADDLE_ENFORCE_EQ( platform::is_custom_place(place), true, @@ -134,6 +142,16 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( } } +std::unordered_map>>& +CustomDeviceEventResourcePool::GetMap() { + static std::unordered_map< + std::string, + std::vector>> + pool; + return pool; +} + CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( const phi::Place& place) { static std::unordered_map< diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.h b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h index c643cff7b54512257e11df9ab4cd2cbfaf836c9c..749e696215a06f24bbc06c24c7de26c446c96a1e 100644 --- a/paddle/fluid/platform/device/custom/custom_device_resource_pool.h +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h @@ -31,6 +31,11 @@ using CustomDeviceEventObject = phi::event::Event; class CustomDeviceStreamResourcePool { public: + static std::unordered_map< + std::string, + std::vector>>& + GetMap(); + std::shared_ptr New(int dev_idx); static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place); @@ -48,6 +53,11 @@ class CustomDeviceEventResourcePool { public: std::shared_ptr New(int dev_idx); + static std::unordered_map< + std::string, + std::vector>>& + GetMap(); + static CustomDeviceEventResourcePool& Instance(const paddle::Place& place); private: diff --git a/paddle/fluid/platform/profiler/custom_device/custom_tracer.h b/paddle/fluid/platform/profiler/custom_device/custom_tracer.h index d70f92588d5bde8fcea2a58c72823fd34f5a0b95..5d5874eabe50ea5ae45e6374decda6afe5153f70 100644 --- a/paddle/fluid/platform/profiler/custom_device/custom_tracer.h +++ b/paddle/fluid/platform/profiler/custom_device/custom_tracer.h @@ -26,9 +26,15 @@ namespace platform { class CustomTracer : public TracerBase { public: - static CustomTracer& GetInstance(const std::string& device_type) { + static std::unordered_map>& + GetMap() { static std::unordered_map> instance; + return instance; + } + + static CustomTracer& GetInstance(const std::string& device_type) { + auto& instance = GetMap(); if (instance.find(device_type) == instance.cend()) { instance.insert( {device_type, std::make_shared(device_type)}); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index fe093c165adcd03aa1a7d66f42bc5292dceec306..9848985519d8e5f38aee9397644464ad53310f9b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -161,6 +161,8 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h" +#include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h" #include "paddle/phi/capi/capi.h" #endif @@ -1005,6 +1007,9 @@ PYBIND11_MODULE(libpaddle, m) { m.def("clear_device_manager", []() { #ifdef PADDLE_WITH_CUSTOM_DEVICE platform::XCCLCommContext::Release(); + platform::CustomTracer::GetMap().clear(); + platform::CustomDeviceEventResourcePool::GetMap().clear(); + platform::CustomDeviceStreamResourcePool::GetMap().clear(); phi::DeviceManager::Clear(); #endif }); diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 4c6c1b9a496b48fdc05afb5469efefe42f967f98..cd4148b94ebe6c3af52ed34b3122e56ba019ddab 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -97,8 +97,6 @@ class CustomDevice : public DeviceInterface { void Finalize() override { auto devices = GetDeviceList(); for (auto dev_id : devices) { - // SetDevice(dev_id); - // SynchronizeDevice(dev_id); DeInitDevice(dev_id); } diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 4d01b2aec4dcacc0a2f587a5bb4ced40c10f924e..902395a162045bd061a535ceddd4d3d93712a77e 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -671,8 +671,6 @@ DeviceManager& DeviceManager::Instance() { } void DeviceManager::Clear() { - // TODO(wangran16): fix coredump when using npu plugin - // Instance().device_map_.clear(); // Instance().device_impl_map_.clear(); } diff --git a/paddle/phi/kernels/funcs/top_k_function_cuda.h b/paddle/phi/kernels/funcs/top_k_function_cuda.h index 4b89bdb5b1b74822d485d66e2e5ebc4c8637cf2c..b6d6b0cffc667f6d91267fa176ad2dfc306c2b7a 100644 --- a/paddle/phi/kernels/funcs/top_k_function_cuda.h +++ b/paddle/phi/kernels/funcs/top_k_function_cuda.h @@ -55,7 +55,7 @@ template <> struct radix_key_codec_base : radix_key_codec_integral {}; -#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4 +#if HIP_VERSION >= 50400000 template <> struct float_bit_mask : float_bit_mask {}; diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu index 5942ffbc4289932227abc5d011e3a35d14231a73..5102594f98d1e02d86ec4a38015214a9e8340081 100644 --- a/paddle/phi/kernels/gpu/argsort_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -45,7 +45,7 @@ template <> struct radix_key_codec_base : radix_key_codec_integral {}; -#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4 +#if HIP_VERSION >= 50400000 template <> struct float_bit_mask : float_bit_mask {};