提交 3bd0df8e 编写于 作者: M Megvii Engine Team

fix(rocm): enable var_releaser for rocm

GitOrigin-RevId: a42185aa6f3a21b79901a87facdc8c4bc0063557
上级 21c6c437
...@@ -125,7 +125,7 @@ StaticDeviceMemoryManager::make_default_impl() { ...@@ -125,7 +125,7 @@ StaticDeviceMemoryManager::make_default_impl() {
#endif // MGB_THREAD_SAFE #endif // MGB_THREAD_SAFE
/* ==================== AsyncVarReleaser ==================== */ /* ==================== AsyncVarReleaser ==================== */
#if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON || MGB_ROCM
class VarNodeMemManager::AsyncVarReleaser { class VarNodeMemManager::AsyncVarReleaser {
struct WaiterParam { struct WaiterParam {
CompNode cn; CompNode cn;
...@@ -248,7 +248,7 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() { ...@@ -248,7 +248,7 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() {
VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph) VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph)
: m_owner_graph(graph), : m_owner_graph(graph),
m_seq_mem_opt(graph) m_seq_mem_opt(graph)
#if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON || MGB_ROCM
,m_asyn_var_releaser(new AsyncVarReleaser) ,m_asyn_var_releaser(new AsyncVarReleaser)
#endif #endif
{ {
...@@ -256,7 +256,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph) ...@@ -256,7 +256,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph)
MGB_MARK_USED_VAR(ev); MGB_MARK_USED_VAR(ev);
// async release is only used for sync between multiple comp nodes, and // async release is only used for sync between multiple comp nodes, and
// does not wait for device to finish // does not wait for device to finish
#if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON || MGB_ROCM
m_asyn_var_releaser->wait_release_finish(); m_asyn_var_releaser->wait_release_finish();
#endif #endif
m_cpu_async_release_barrier.wait_zero(); m_cpu_async_release_barrier.wait_zero();
...@@ -298,7 +298,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph) ...@@ -298,7 +298,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph)
on_comp_seq_error); on_comp_seq_error);
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && \ #if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && \
(MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON ) (MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON || MGB_ROCM)
auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) { auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) {
m_asyn_var_releaser->wait_release_finish(); m_asyn_var_releaser->wait_release_finish();
}; };
......
...@@ -446,7 +446,7 @@ class VarNodeMemManager { ...@@ -446,7 +446,7 @@ class VarNodeMemManager {
SyncableCounter m_cpu_async_release_barrier; SyncableCounter m_cpu_async_release_barrier;
#if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON || MGB_ROCM
//! release dynamic var on after compnode event finishes //! release dynamic var on after compnode event finishes
class AsyncVarReleaser; class AsyncVarReleaser;
std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser; std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册