未验证 提交 e5725680 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] release device manager in py::atexit (#54932)

* [CustomDevice] release device manager in py::atexit

* fix hip_version macro

* update

* update
上级 9c17e45e
......@@ -52,12 +52,20 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
}
}
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
CustomDeviceStreamResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>
pool;
return pool;
}
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
auto& pool = GetMap();
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
......@@ -134,6 +142,16 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
}
}
std::unordered_map<std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
CustomDeviceEventResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
pool;
return pool;
}
CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
const phi::Place& place) {
static std::unordered_map<
......
......@@ -31,6 +31,11 @@ using CustomDeviceEventObject = phi::event::Event;
class CustomDeviceStreamResourcePool {
public:
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
GetMap();
std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx);
static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place);
......@@ -48,6 +53,11 @@ class CustomDeviceEventResourcePool {
public:
std::shared_ptr<CustomDeviceEventObject> New(int dev_idx);
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
GetMap();
static CustomDeviceEventResourcePool& Instance(const paddle::Place& place);
private:
......
......@@ -26,9 +26,15 @@ namespace platform {
class CustomTracer : public TracerBase {
public:
static CustomTracer& GetInstance(const std::string& device_type) {
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>&
GetMap() {
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>
instance;
return instance;
}
static CustomTracer& GetInstance(const std::string& device_type) {
auto& instance = GetMap();
if (instance.find(device_type) == instance.cend()) {
instance.insert(
{device_type, std::make_shared<CustomTracer>(device_type)});
......
......@@ -161,6 +161,8 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/custom/custom_device_resource_pool.h"
#include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h"
#include "paddle/phi/capi/capi.h"
#endif
......@@ -1005,6 +1007,9 @@ PYBIND11_MODULE(libpaddle, m) {
// Registers the Python-visible function `clear_device_manager`.
// When PADDLE_WITH_CUSTOM_DEVICE is not defined the lambda body is empty and
// the call does nothing.
m.def("clear_device_manager", []() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Release the XCCL communication context first.
platform::XCCLCommContext::Release();
// Empty the static registries exposed through GetMap(). Clearing drops the
// shared_ptr entries each map holds; objects are destroyed only if nothing
// else still owns them.
platform::CustomTracer::GetMap().clear();
platform::CustomDeviceEventResourcePool::GetMap().clear();
platform::CustomDeviceStreamResourcePool::GetMap().clear();
// NOTE(review): the device manager is cleared last, presumably so that the
// tracers/events/streams above are released while their devices are still
// registered -- confirm before reordering.
phi::DeviceManager::Clear();
#endif
});
......
......@@ -97,8 +97,6 @@ class CustomDevice : public DeviceInterface {
void Finalize() override {
auto devices = GetDeviceList();
for (auto dev_id : devices) {
// SetDevice(dev_id);
// SynchronizeDevice(dev_id);
DeInitDevice(dev_id);
}
......
......@@ -671,8 +671,6 @@ DeviceManager& DeviceManager::Instance() {
}
// Currently a no-op: the two clear() calls that would empty device_map_ and
// device_impl_map_ are commented out (see the TODO below), so calling this
// leaves both maps untouched.
void DeviceManager::Clear() {
// TODO(wangran16): fix coredump when using npu plugin
// Disabled until the coredump above is resolved:
// Instance().device_map_.clear();
// Instance().device_impl_map_.clear();
}
......
......@@ -55,7 +55,7 @@ template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4
#if HIP_VERSION >= 50400000
// Explicit specialization of float_bit_mask for phi::dtype::float16. It adds
// no members of its own and inherits everything from the rocprim::half
// specialization.
template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
......
......@@ -45,7 +45,7 @@ template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
#if ROCM_VERSION_MAJOR >= 5 && ROCM_VERSION_MINOR >= 4
#if HIP_VERSION >= 50400000
// Explicit specialization of float_bit_mask for phi::dtype::float16. It adds
// no members of its own and inherits everything from the rocprim::half
// specialization.
template <>
struct float_bit_mask<phi::dtype::float16> : float_bit_mask<rocprim::half> {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册