提交 d68d4d1d 作者: Megvii Engine Team

perf(atlas): use async d2d

GitOrigin-RevId: 55914631cb63bc1057b2f2a124d8168b3e1e29cb
上级 d8ac6c70
...@@ -41,26 +41,22 @@ AtlasComputingContext::~AtlasComputingContext() { ...@@ -41,26 +41,22 @@ AtlasComputingContext::~AtlasComputingContext() {
void AtlasComputingContext::memcpy(void* dst, const void* src, void AtlasComputingContext::memcpy(void* dst, const void* src,
size_t size_in_bytes, size_t size_in_bytes,
megcoreMemcpyKind_t kind) { megcoreMemcpyKind_t kind) {
aclrtMemcpyKind atlas_kind;
switch (kind) { switch (kind) {
case megcoreMemcpyDeviceToHost: case megcoreMemcpyDeviceToHost:
atlas_kind = ACL_MEMCPY_DEVICE_TO_HOST; acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes,
ACL_MEMCPY_DEVICE_TO_HOST));
break; break;
case megcoreMemcpyHostToDevice: case megcoreMemcpyHostToDevice:
atlas_kind = ACL_MEMCPY_HOST_TO_DEVICE; acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
break; break;
case megcoreMemcpyDeviceToDevice: case megcoreMemcpyDeviceToDevice:
atlas_kind = ACL_MEMCPY_DEVICE_TO_DEVICE; acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE, m_ctx.stream));
break; break;
default: default:
megdnn_throw("bad atlas memcpy kind"); megdnn_throw("bad atlas memcpy kind");
} }
#if MGB_USE_ATLAS_ASYNC_API
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
atlas_kind, m_ctx.stream));
#else
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, atlas_kind));
#endif
} }
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
...@@ -69,11 +65,7 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { ...@@ -69,11 +65,7 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
} }
void AtlasComputingContext::synchronize() { void AtlasComputingContext::synchronize() {
#if MGB_USE_ATLAS_ASYNC_API
acl_check(aclrtSynchronizeStream(m_ctx.stream)); acl_check(aclrtSynchronizeStream(m_ctx.stream));
#else
return;
#endif
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -230,10 +230,10 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest, ...@@ -230,10 +230,10 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
auto&& src_env = m_env.atlas_env(); auto&& src_env = m_env.atlas_env();
activate(); activate();
if (dst_env.device == src_env.device) { if (dst_env.device == src_env.device) {
#if MGB_USE_ATLAS_ASYNC_API #if 1
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size, MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE, ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream)); dst_env.stream));
#else #else
MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size, MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE)); ACL_MEMCPY_DEVICE_TO_DEVICE));
......
...@@ -361,6 +361,7 @@ void AtlasRuntimeOpr::scn_do_execute() { ...@@ -361,6 +361,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
i, output(i)->cname()); i, output(i)->cname());
aclmdlAddDatasetBuffer(model_outputs, output_db); aclmdlAddDatasetBuffer(model_outputs, output_db);
} }
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs)); MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
for (size_t i = 0; i < nr_inputs; ++i) { for (size_t i = 0; i < nr_inputs; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册