From f438efc95b4249455a6f4d68113f9f104220ac1e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 3 Aug 2023 14:35:25 +0800 Subject: [PATCH] fix(src/atlas): add om model loading configuration GitOrigin-RevId: e0376c962b31103f779f75550c6155f86bdffa20 --- dnn/atlas-stub/include/acl/acl_mdl.h | 3 ++- src/opr/impl/atlas_runtime_op.cpp | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/dnn/atlas-stub/include/acl/acl_mdl.h b/dnn/atlas-stub/include/acl/acl_mdl.h index fb1112d5f..a5f9bff8b 100644 --- a/dnn/atlas-stub/include/acl/acl_mdl.h +++ b/dnn/atlas-stub/include/acl/acl_mdl.h @@ -76,7 +76,8 @@ typedef enum { ACL_MDL_INPUTQ_NUM_SIZET, ACL_MDL_INPUTQ_ADDR_PTR, /**< pointer to inputQ with shallow copy */ ACL_MDL_OUTPUTQ_NUM_SIZET, - ACL_MDL_OUTPUTQ_ADDR_PTR /**< pointer to outputQ with shallow copy */ + ACL_MDL_OUTPUTQ_ADDR_PTR, /**< pointer to outputQ with shallow copy */ + ACL_MDL_WORKSPACE_MEM_OPTIMIZE } aclmdlConfigAttr; typedef enum { diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index b042da26d..9e35db01d 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -178,8 +178,28 @@ AtlasRuntimeOpr::AtlasRuntimeOpr( add_input({i}); } if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) { - MGB_ATLAS_CHECK( - aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), &m_model_id)); + aclmdlConfigHandle* config_handle = aclmdlCreateConfigHandle(); + + size_t mdl_load_type = ACL_MDL_LOAD_FROM_MEM; + const void* mdl_mem_addr_ptr = m_buffer.data(); + size_t mdl_mem_size = m_buffer.size(); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_LOAD_TYPE_SIZET, &mdl_load_type, + sizeof(size_t))); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_MEM_ADDR_PTR, &mdl_mem_addr_ptr, + sizeof(const void*))); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_MEM_SIZET, &mdl_mem_size, sizeof(size_t))); + + size_t mem_optimize_mode = 1; + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, &mem_optimize_mode, + sizeof(size_t))); + + MGB_ATLAS_CHECK(aclmdlLoadWithConfig(config_handle, &m_model_id)); + MGB_ATLAS_CHECK(aclmdlDestroyConfigHandle(config_handle)); + m_model_desc = aclmdlCreateDesc(); MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id)); m_is_model_holder = true; -- GitLab