提交 b8ddca4c 编写于 作者: M Megvii Engine Team

fix(atlas): add MGB_USE_ATLAS_ASYNC_API to enable async api

GitOrigin-RevId: ab821f4966f67b5a4f67a68bf1e0dfe14e59d8bc
上级 6cab1dd7
......@@ -55,8 +55,12 @@ void AtlasComputingContext::memcpy(void* dst, const void* src,
default:
megdnn_throw("bad atlas memcpy kind");
}
#if MGB_USE_ATLAS_ASYNC_API
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes,
atlas_kind, m_ctx.stream));
#else
acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, atlas_kind));
#endif
}
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
......@@ -65,7 +69,11 @@ void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) {
}
/*!
 * \brief Synchronize the ACL stream owned by this computing context.
 *
 * When MGB_USE_ATLAS_ASYNC_API is enabled, aclrtSynchronizeStream is called
 * on m_ctx.stream and its status is passed to acl_check. When the macro is
 * disabled the body is empty and the call returns immediately.
 */
void AtlasComputingContext::synchronize() {
#if MGB_USE_ATLAS_ASYNC_API
    acl_check(aclrtSynchronizeStream(m_ctx.stream));
#endif
}
// vim: syntax=cpp.doxygen
......@@ -104,9 +104,14 @@ public:
/*!
 * \brief Copy \p size bytes from \p device_ptr to \p host_ptr using the
 *      ACL_MEMCPY_DEVICE_TO_HOST copy kind.
 *
 * activate() is called first. The ACL status of the copy is passed to
 * MGB_ATLAS_CHECK. Which ACL call is used depends on
 * MGB_USE_ATLAS_ASYNC_API (see the branches below).
 *
 * \param host_ptr destination buffer
 * \param device_ptr source buffer
 * \param size number of bytes; passed as both the destination capacity and
 *      the copy length
 */
void copy_to_host(void* host_ptr, const void* device_ptr,
size_t size) override {
activate();
#if MGB_USE_ATLAS_ASYNC_API
// Issue the copy on the stream held by m_env.atlas_env(). This function
// does not synchronize that stream itself.
// NOTE(review): presumably the caller must sync before reading host_ptr
// -- confirm.
MGB_ATLAS_CHECK(aclrtMemcpyAsync(host_ptr, size, device_ptr, size,
ACL_MEMCPY_DEVICE_TO_HOST,
m_env.atlas_env().stream));
#else
// Non-async build: use aclrtMemcpy, which takes no stream argument.
MGB_ATLAS_CHECK(aclrtMemcpy(host_ptr, size, device_ptr, size,
ACL_MEMCPY_DEVICE_TO_HOST));
#endif
}
void copy_to_device(void* device_ptr, const void* host_ptr,
......@@ -225,9 +230,14 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
auto&& src_env = m_env.atlas_env();
activate();
if (dst_env.device == src_env.device) {
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#if MGB_USE_ATLAS_ASYNC_API
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
dst_env.stream));
#else
MGB_ATLAS_CHECK(aclrtMemcpy(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_DEVICE));
#endif
} else {
mgb_throw(MegBrainError,
"Atlas does not support peer copy between differents "
......@@ -239,12 +249,18 @@ void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
"cuda peer_copy_to only implemented for CPU");
auto copy = [this, dest, src, size]() {
auto stream = m_env.atlas_env().stream;
m_env.atlas_env().activate();
#if MGB_USE_ATLAS_ASYNC_API
auto stream = m_env.atlas_env().stream;
MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
ACL_MEMCPY_DEVICE_TO_HOST,
m_env.atlas_env().stream));
MGB_ATLAS_CHECK(aclrtSynchronizeStream(stream));
#else
MGB_ATLAS_CHECK(
aclrtMemcpy(dest, size, src, size, ACL_MEMCPY_DEVICE_TO_HOST));
#endif
};
dest_impl->env().cpu_env().dispatch(copy);
......
......@@ -614,8 +614,12 @@ void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
#endif
#if MGB_ATLAS
case CompNode::DeviceType::ATLAS:
#if MGB_USE_ATLAS_ASYNC_API
MGB_ATLAS_CHECK(aclrtMemsetAsync(ptr, -1, val, size,
env.atlas_env().stream));
#else
MGB_ATLAS_CHECK(aclrtMemset(ptr, -1, val, size));
#endif
break;
#endif
case CompNode::DeviceType::CPU: {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册