未验证 提交 e5725680 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] release device manager in py::atexit (#54932)

* [CustomDevice] release device manager in py::atexit

* fix ROCm version check: use the HIP_VERSION macro

* update

* update
上级 9c17e45e
...@@ -52,12 +52,20 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( ...@@ -52,12 +52,20 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
} }
} }
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( std::unordered_map<
const paddle::Place& place) { std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
CustomDeviceStreamResourcePool::GetMap() {
static std::unordered_map< static std::unordered_map<
std::string, std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>> std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>
pool; pool;
return pool;
}
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
auto& pool = GetMap();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_custom_place(place), platform::is_custom_place(place),
true, true,
...@@ -134,6 +142,16 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( ...@@ -134,6 +142,16 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
} }
} }
std::unordered_map<std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
CustomDeviceEventResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
pool;
return pool;
}
CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
const phi::Place& place) { const phi::Place& place) {
static std::unordered_map< static std::unordered_map<
......
...@@ -31,6 +31,11 @@ using CustomDeviceEventObject = phi::event::Event; ...@@ -31,6 +31,11 @@ using CustomDeviceEventObject = phi::event::Event;
class CustomDeviceStreamResourcePool { class CustomDeviceStreamResourcePool {
public: public:
// Returns a mutable reference to the function-local static map that holds
// the stream resource pools (std::string key -> vector of shared pool
// handles). Instance() looks pools up in this map, and exposing it lets
// callers empty it explicitly via GetMap().clear().
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
GetMap();
std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx); std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx);
static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place); static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place);
...@@ -48,6 +53,11 @@ class CustomDeviceEventResourcePool { ...@@ -48,6 +53,11 @@ class CustomDeviceEventResourcePool {
public: public:
std::shared_ptr<CustomDeviceEventObject> New(int dev_idx); std::shared_ptr<CustomDeviceEventObject> New(int dev_idx);
// Returns a mutable reference to a function-local static map of event
// resource pools (std::string key -> vector of shared pool handles), so that
// callers can empty it explicitly via GetMap().clear().
// NOTE(review): verify that Instance() uses this same map -- the visible
// part of its definition still declares a separate static map.
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
GetMap();
static CustomDeviceEventResourcePool& Instance(const paddle::Place& place); static CustomDeviceEventResourcePool& Instance(const paddle::Place& place);
private: private:
......
...@@ -26,9 +26,15 @@ namespace platform { ...@@ -26,9 +26,15 @@ namespace platform {
class CustomTracer : public TracerBase { class CustomTracer : public TracerBase {
public: public:
static CustomTracer& GetInstance(const std::string& device_type) { static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>&
GetMap() {
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>> static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>
instance; instance;
return instance;
}
static CustomTracer& GetInstance(const std::string& device_type) {
auto& instance = GetMap();
if (instance.find(device_type) == instance.cend()) { if (instance.find(device_type) == instance.cend()) {
instance.insert( instance.insert(
{device_type, std::make_shared<CustomTracer>(device_type)}); {device_type, std::make_shared<CustomTracer>(device_type)});
......
...@@ -161,6 +161,8 @@ limitations under the License. */ ...@@ -161,6 +161,8 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h"
#include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h"
#include "paddle/phi/capi/capi.h" #include "paddle/phi/capi/capi.h"
#endif #endif
...@@ -1005,6 +1007,9 @@ PYBIND11_MODULE(libpaddle, m) { ...@@ -1005,6 +1007,9 @@ PYBIND11_MODULE(libpaddle, m) {
m.def("clear_device_manager", []() { m.def("clear_device_manager", []() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::XCCLCommContext::Release(); platform::XCCLCommContext::Release();
platform::CustomTracer::GetMap().clear();
platform::CustomDeviceEventResourcePool::GetMap().clear();
platform::CustomDeviceStreamResourcePool::GetMap().clear();
phi::DeviceManager::Clear(); phi::DeviceManager::Clear();
#endif #endif
}); });
......
...@@ -97,8 +97,6 @@ class CustomDevice : public DeviceInterface { ...@@ -97,8 +97,6 @@ class CustomDevice : public DeviceInterface {
void Finalize() override { void Finalize() override {
auto devices = GetDeviceList(); auto devices = GetDeviceList();
for (auto dev_id : devices) { for (auto dev_id : devices) {
// SetDevice(dev_id);
// SynchronizeDevice(dev_id);
DeInitDevice(dev_id); DeInitDevice(dev_id);
} }
......
...@@ -671,8 +671,6 @@ DeviceManager& DeviceManager::Instance() { ...@@ -671,8 +671,6 @@ DeviceManager& DeviceManager::Instance() {
} }
void DeviceManager::Clear() { void DeviceManager::Clear() {
// TODO(wangran16): fix coredump when using npu plugin
// Instance().device_map_.clear(); // Instance().device_map_.clear();
// Instance().device_impl_map_.clear(); // Instance().device_impl_map_.clear();
} }
......
...@@ -55,7 +55,7 @@ template <> ...@@ -55,7 +55,7 @@ template <>
struct radix_key_codec_base<phi::dtype::bfloat16> struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4 #if HIP_VERSION >= 50400000
template <> template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {}; struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
......
...@@ -45,7 +45,7 @@ template <> ...@@ -45,7 +45,7 @@ template <>
struct radix_key_codec_base<phi::dtype::bfloat16> struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4 #if HIP_VERSION >= 50400000
template <> template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {}; struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册