diff --git a/dnn/atlas-stub/include/acl/acl_mdl.h b/dnn/atlas-stub/include/acl/acl_mdl.h index fb1112d5f4af00b0760a441882078dfc31e6fc26..a5f9bff8b558cadcfa2702f5be80015cad62a6af 100644 --- a/dnn/atlas-stub/include/acl/acl_mdl.h +++ b/dnn/atlas-stub/include/acl/acl_mdl.h @@ -76,7 +76,8 @@ typedef enum { ACL_MDL_INPUTQ_NUM_SIZET, ACL_MDL_INPUTQ_ADDR_PTR, /**< pointer to inputQ with shallow copy */ ACL_MDL_OUTPUTQ_NUM_SIZET, - ACL_MDL_OUTPUTQ_ADDR_PTR /**< pointer to outputQ with shallow copy */ + ACL_MDL_OUTPUTQ_ADDR_PTR, /**< pointer to outputQ with shallow copy */ + ACL_MDL_WORKSPACE_MEM_OPTIMIZE } aclmdlConfigAttr; typedef enum { diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index b042da26d143e9018f9d3a7a9325a69b42df6369..9e35db01df80c38a7fa22b68e63e877c0800a59e 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -178,8 +178,28 @@ AtlasRuntimeOpr::AtlasRuntimeOpr( add_input({i}); } if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) { - MGB_ATLAS_CHECK( - aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), &m_model_id)); + aclmdlConfigHandle* config_handle = aclmdlCreateConfigHandle(); + + size_t mdl_load_type = ACL_MDL_LOAD_FROM_MEM; + const void* mdl_mem_addr_ptr = m_buffer.data(); + size_t mdl_mem_size = m_buffer.size(); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_LOAD_TYPE_SIZET, &mdl_load_type, + sizeof(size_t))); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_MEM_ADDR_PTR, &mdl_mem_addr_ptr, + sizeof(const void*))); + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_MEM_SIZET, &mdl_mem_size, sizeof(size_t))); + + size_t mem_optimize_mode = 1; + MGB_ATLAS_CHECK(aclmdlSetConfigOpt( + config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, &mem_optimize_mode, + sizeof(size_t))); + + MGB_ATLAS_CHECK(aclmdlLoadWithConfig(config_handle, &m_model_id)); + MGB_ATLAS_CHECK(aclmdlDestroyConfigHandle(config_handle)); + m_model_desc = aclmdlCreateDesc(); MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id)); m_is_model_holder = true;