diff --git a/src/core/include/megbrain/comp_node_env.h b/src/core/include/megbrain/comp_node_env.h index 9e46450fc598895df77a54e0b43403468e03f993..ee5817f335a12f88cfc2e6d59e1218c3a42d8053 100644 --- a/src/core/include/megbrain/comp_node_env.h +++ b/src/core/include/megbrain/comp_node_env.h @@ -45,7 +45,7 @@ #endif //MGB_CUDA #if MGB_ATLAS -#include "acl/acl.h" +#include "megcore_atlas.h" #include #if MGB_ENABLE_LOGGING @@ -378,7 +378,16 @@ public: void activate() const { init(); - MGB_ATLAS_CHECK(aclrtSetDevice(device)); + int32_t device_id = -1; + auto err = aclrtGetDevice(&device_id); + if (err == ACL_ERROR_INVALID_DEVICE || device != device_id) { + MGB_ATLAS_CHECK(aclrtSetDevice(device)); + } else { + MGB_ATLAS_CHECK(err); + mgb_assert(err == ACL_ERROR_NONE, + "Failed to invoke aclrtGetDevice, get %s(%d)", + megcore::atlas::get_error_str(err), err); + } } };