diff --git a/dnn/src/atlas/megcore/computing_context.cpp b/dnn/src/atlas/megcore/computing_context.cpp index 715dcc77dba5fb86c47483c504b348a5413bc698..69b4d3b620cc546666dcd0ca63e43a13e60dac14 100644 --- a/dnn/src/atlas/megcore/computing_context.cpp +++ b/dnn/src/atlas/megcore/computing_context.cpp @@ -41,26 +41,22 @@ AtlasComputingContext::~AtlasComputingContext() { void AtlasComputingContext::memcpy(void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) { - aclrtMemcpyKind atlas_kind; switch (kind) { case megcoreMemcpyDeviceToHost: - atlas_kind = ACL_MEMCPY_DEVICE_TO_HOST; + acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); break; case megcoreMemcpyHostToDevice: - atlas_kind = ACL_MEMCPY_HOST_TO_DEVICE; + acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); break; case megcoreMemcpyDeviceToDevice: - atlas_kind = ACL_MEMCPY_DEVICE_TO_DEVICE; + acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, m_ctx.stream)); break; default: megdnn_throw("bad atlas memcpy kind"); } -#if MGB_USE_ATLAS_ASYNC_API - acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, - atlas_kind, m_ctx.stream)); -#else - acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, atlas_kind)); -#endif } void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { @@ -69,11 +65,7 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { } void AtlasComputingContext::synchronize() { -#if MGB_USE_ATLAS_ASYNC_API acl_check(aclrtSynchronizeStream(m_ctx.stream)); -#else - return; -#endif } // vim: syntax=cpp.doxygen diff --git a/src/core/impl/comp_node/atlas/comp_node.cpp b/src/core/impl/comp_node/atlas/comp_node.cpp index 4b7a03d5b7d08a76ecc251d9addd4b450cd4d9b7..1f64bf0eb9990859478a2bee6142e034cc8cb492 100644 --- a/src/core/impl/comp_node/atlas/comp_node.cpp +++ b/src/core/impl/comp_node/atlas/comp_node.cpp @@ -230,10 +230,10 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest, auto&& src_env = m_env.atlas_env(); activate(); if (dst_env.device == src_env.device) { -#if MGB_USE_ATLAS_ASYNC_API - MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size, - ACL_MEMCPY_DEVICE_TO_DEVICE, - dst_env.stream)); +#if 1 + MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size, + ACL_MEMCPY_DEVICE_TO_DEVICE, + dst_env.stream)); #else MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size, ACL_MEMCPY_DEVICE_TO_DEVICE)); diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index 7853545a2ec9f28fc27fd9e814637ca04ec38246..051a2c8f2fd6692ecd8bf288f4561a531dbac744 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -361,6 +361,7 @@ void AtlasRuntimeOpr::scn_do_execute() { i, output(i)->cname()); aclmdlAddDatasetBuffer(model_outputs, output_db); } + MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs)); for (size_t i = 0; i < nr_inputs; ++i) {