提交 15dd5e1a 编写于 作者: M Megvii Engine Team

fix(mgb/core): fix memory management release cambricon var issue

GitOrigin-RevId: abf881978c8fd85b3cfa823fea13045ff06fb88e
上级 38ea5f1b
...@@ -513,6 +513,7 @@ if(MGE_WITH_ATLAS) ...@@ -513,6 +513,7 @@ if(MGE_WITH_ATLAS)
set(MGB_ATLAS ${MGE_WITH_ATLAS}) set(MGB_ATLAS ${MGE_WITH_ATLAS})
endif() endif()
find_program(CCACHE_BIN ccache) find_program(CCACHE_BIN ccache)
if(CCACHE_BIN) if(CCACHE_BIN)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_BIN}) set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_BIN})
...@@ -688,6 +689,7 @@ if(MGE_ARCH STREQUAL "aarch64") ...@@ -688,6 +689,7 @@ if(MGE_ARCH STREQUAL "aarch64")
set(MEGDNN_AARCH64 1) set(MEGDNN_AARCH64 1)
set(MEGDNN_64_BIT 1) set(MEGDNN_64_BIT 1)
set(MARCH "-march=armv8-a") set(MARCH "-march=armv8-a")
set(MGB_AARCH64 1)
if(MGE_ARMV8_2_FEATURE_FP16) if(MGE_ARMV8_2_FEATURE_FP16)
message(STATUS "Enable fp16 feature support in armv8.2") message(STATUS "Enable fp16 feature support in armv8.2")
if(NOT ${MGE_DISABLE_FLOAT16}) if(NOT ${MGE_DISABLE_FLOAT16})
......
...@@ -177,11 +177,11 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { ...@@ -177,11 +177,11 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
dev_type = DeviceType::CAMBRICON; dev_type = DeviceType::CAMBRICON;
ptr += 9; ptr += 9;
} else if (ptr[0] == 'm') { } else if (ptr[0] == 'm') {
if (strncmp(ptr, "multithread", 11)) { if (strncmp(ptr, "multithread", 11)) {
err(); err();
} }
dev_type = DeviceType::MULTITHREAD; dev_type = DeviceType::MULTITHREAD;
ptr += 11; ptr += 11;
} else { } else {
if (ptr[1] != 'p' || ptr[2] != 'u') { if (ptr[1] != 'p' || ptr[2] != 'u') {
err(); err();
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "megcore_atlas.h" #include "megcore_atlas.h"
#endif #endif
using namespace mgb; using namespace mgb;
/* =================== MegDNNHandle =================== */ /* =================== MegDNNHandle =================== */
...@@ -101,6 +102,7 @@ MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) { ...@@ -101,6 +102,7 @@ MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) {
} }
#endif #endif
if (env.property().type == CompNode::DeviceType::CPU) { if (env.property().type == CompNode::DeviceType::CPU) {
megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU); megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU);
megcoreCreateComputingHandleWithCPUDispatcher(&m_comp_hdl, m_dev_hdl, megcoreCreateComputingHandleWithCPUDispatcher(&m_comp_hdl, m_dev_hdl,
...@@ -254,6 +256,7 @@ void CompNodeEnv::init_atlas(CompNode comp_node, const AtlasEnv& env) { ...@@ -254,6 +256,7 @@ void CompNodeEnv::init_atlas(CompNode comp_node, const AtlasEnv& env) {
#endif #endif
#if MGB_ROCM #if MGB_ROCM
void mgb::_on_hip_error(const char* expr, hipError_t err, const char* file, void mgb::_on_hip_error(const char* expr, hipError_t err, const char* file,
......
...@@ -77,6 +77,7 @@ AtlasError::AtlasError(const std::string &msg): ...@@ -77,6 +77,7 @@ AtlasError::AtlasError(const std::string &msg):
} }
ROCmError::ROCmError(const std::string &msg): ROCmError::ROCmError(const std::string &msg):
SystemError(msg) SystemError(msg)
{ {
......
...@@ -125,7 +125,7 @@ StaticDeviceMemoryManager::make_default_impl() { ...@@ -125,7 +125,7 @@ StaticDeviceMemoryManager::make_default_impl() {
#endif // MGB_THREAD_SAFE #endif // MGB_THREAD_SAFE
/* ==================== AsyncVarReleaser ==================== */ /* ==================== AsyncVarReleaser ==================== */
#if MGB_CUDA || MGB_ATLAS #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON
class VarNodeMemManager::AsyncVarReleaser { class VarNodeMemManager::AsyncVarReleaser {
struct WaiterParam { struct WaiterParam {
CompNode cn; CompNode cn;
...@@ -245,18 +245,18 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() { ...@@ -245,18 +245,18 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() {
} }
/* ==================== VarNodeMemManager ==================== */ /* ==================== VarNodeMemManager ==================== */
VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl* graph)
m_owner_graph(graph), : m_owner_graph(graph),
m_seq_mem_opt(graph) m_seq_mem_opt(graph)
#if MGB_CUDA || MGB_ATLAS #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON
,m_asyn_var_releaser(new AsyncVarReleaser) ,m_asyn_var_releaser(new AsyncVarReleaser)
#endif #endif
{ {
auto on_comp_seq_finish = [this](const event::CompSeqExecFinished& ev) { auto on_comp_seq_finish = [this](const event::CompSeqExecFinished& ev) {
MGB_MARK_USED_VAR(ev); MGB_MARK_USED_VAR(ev);
// async release is only used for sync between multiple comp nodes, and // async release is only used for sync between multiple comp nodes, and
// does not wait for device to finish // does not wait for device to finish
#if MGB_CUDA || MGB_ATLAS #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON
m_asyn_var_releaser->wait_release_finish(); m_asyn_var_releaser->wait_release_finish();
#endif #endif
m_cpu_async_release_barrier.wait_zero(); m_cpu_async_release_barrier.wait_zero();
...@@ -297,7 +297,8 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): ...@@ -297,7 +297,8 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph):
graph->event().register_receiver_permanent<event::CompSeqExecError>( graph->event().register_receiver_permanent<event::CompSeqExecError>(
on_comp_seq_error); on_comp_seq_error);
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && (MGB_CUDA || MGB_ATLAS) #if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && \
(MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON )
auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) { auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) {
m_asyn_var_releaser->wait_release_finish(); m_asyn_var_releaser->wait_release_finish();
}; };
...@@ -1448,6 +1449,13 @@ void VarNodeMemManager::decr_var_mem_refcnt( ...@@ -1448,6 +1449,13 @@ void VarNodeMemManager::decr_var_mem_refcnt(
m_asyn_var_releaser->add(dispatch_cn, var); m_asyn_var_releaser->add(dispatch_cn, var);
break; break;
} }
#endif
#if MGB_CAMBRICON
case DT::CAMBRICON:
{
m_asyn_var_releaser->add(dispatch_cn, var);
break;
}
#endif #endif
default: default:
mgb_throw(MegBrainError, mgb_throw(MegBrainError,
......
...@@ -446,7 +446,7 @@ class VarNodeMemManager { ...@@ -446,7 +446,7 @@ class VarNodeMemManager {
SyncableCounter m_cpu_async_release_barrier; SyncableCounter m_cpu_async_release_barrier;
#if MGB_CUDA || MGB_ATLAS #if MGB_CUDA || MGB_ATLAS || MGB_CAMBRICON
//! release dynamic var on after compnode event finishes //! release dynamic var on after compnode event finishes
class AsyncVarReleaser; class AsyncVarReleaser;
std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser; std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser;
......
...@@ -90,6 +90,7 @@ ...@@ -90,6 +90,7 @@
#endif // MGB_ATLAS #endif // MGB_ATLAS
#if MGB_ROCM #if MGB_ROCM
#include "hcc_detail/hcc_defs_prologue.h" #include "hcc_detail/hcc_defs_prologue.h"
#include "megcore_rocm.h" #include "megcore_rocm.h"
...@@ -194,6 +195,7 @@ namespace mgb { ...@@ -194,6 +195,7 @@ namespace mgb {
const char* file, const char* func, int line); const char* file, const char* func, int line);
#endif #endif
#if MGB_CUDA #if MGB_CUDA
[[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, [[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err,
const char* file, const char* func, int line); const char* file, const char* func, int line);
...@@ -325,6 +327,7 @@ public: ...@@ -325,6 +327,7 @@ public:
} }
#endif #endif
} }
/*! /*!
...@@ -426,6 +429,8 @@ public: ...@@ -426,6 +429,8 @@ public:
void init_atlas(CompNode comp_node, const AtlasEnv& env); void init_atlas(CompNode comp_node, const AtlasEnv& env);
#endif #endif
#if MGB_ROCM #if MGB_ROCM
struct ROCmEnv { struct ROCmEnv {
int device = -1; int device = -1;
...@@ -485,9 +490,7 @@ public: ...@@ -485,9 +490,7 @@ public:
}; };
static InitStatus init_status; static InitStatus init_status;
static void init() { static void init() { init_status.init(); }
init_status.init();
}
void activate() const { void activate() const {
init(); init();
......
...@@ -62,6 +62,7 @@ TEST(TestCompNode, Parse) { ...@@ -62,6 +62,7 @@ TEST(TestCompNode, Parse) {
ASSERT_EQ(L::parse("multithread:default:2"), ASSERT_EQ(L::parse("multithread:default:2"),
make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2));
ASSERT_THROW(L::parse("apu"), MegBrainError); ASSERT_THROW(L::parse("apu"), MegBrainError);
ASSERT_THROW(L::parse("fpgbx"), MegBrainError); ASSERT_THROW(L::parse("fpgbx"), MegBrainError);
ASSERT_THROW(L::parse("cab0"), MegBrainError); ASSERT_THROW(L::parse("cab0"), MegBrainError);
...@@ -149,6 +150,7 @@ TEST(TestCompNode, Load) { ...@@ -149,6 +150,7 @@ TEST(TestCompNode, Load) {
auto atlas1 = CompNode::load("atlas1"); auto atlas1 = CompNode::load("atlas1");
ASSERT_NE(atlas0, atlas1); ASSERT_NE(atlas0, atlas1);
#endif #endif
} }
TEST(TestCompNode, FreeAfterFinalize) { TEST(TestCompNode, FreeAfterFinalize) {
...@@ -762,6 +764,7 @@ TEST(TestCompNodeAtlas, D2DCopy) { ...@@ -762,6 +764,7 @@ TEST(TestCompNodeAtlas, D2DCopy) {
} }
#endif #endif
namespace { namespace {
class CompNodeDepedentObjectInst final : public CompNodeDepedentObject { class CompNodeDepedentObjectInst final : public CompNodeDepedentObject {
int *m_dst, *m_timer; int *m_dst, *m_timer;
......
...@@ -33,7 +33,6 @@ ...@@ -33,7 +33,6 @@
#cmakedefine01 MGB_ENABLE_OPR_MM #cmakedefine01 MGB_ENABLE_OPR_MM
#cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION #cmakedefine01 MGB_ENABLE_FBS_SERIALIZATION
#cmakedefine01 MGB_IS_DEV #cmakedefine01 MGB_IS_DEV
// DNN related flags // DNN related flags
// Platform macro's // Platform macro's
#cmakedefine01 MEGDNN_WITH_CUDA #cmakedefine01 MEGDNN_WITH_CUDA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册