Commit 933dd9a4 authored by Megvii Engine Team, committed by huangxinda

feat(mge/distributed): add cuda env check before forked thread

style(core/comp_node): reformat code

GitOrigin-RevId: 372452a8eb9e84a2e82d466074f80f78d70531e8
Parent 2a541961
@@ -165,6 +165,18 @@ def _get_device_count_worker(queue, device_type):
    queue.put(num)


def _check_device_initialized(device_type: str):
    try:
        test = Tensor(1, device=device_type)
        inited = False
        del test
    except:
        inited = True
    errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
    if inited:
        raise RuntimeError(errmsg)


def get_device_count_by_fork(device_type: str):
    """
    Get device count in fork thread.
...
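For context, a minimal sketch of how the new helper is meant to be called at the top of a freshly forked worker. The import path follows the `.helper` module touched in this diff; `"gpu0"` is only an illustrative device string, not something this commit prescribes.

```python
# Minimal sketch (assumption: the helper is importable exactly as added above;
# "gpu0" is just an example device string).
from megengine.distributed.helper import _check_device_initialized

def worker_entry():
    # Raises RuntimeError if the parent process already touched CUDA
    # (e.g. created a GPU tensor) before this worker was forked.
    _check_device_initialized("gpu0")
    # ...safe to init the process group and build the model from here
```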
@@ -15,7 +15,7 @@ from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server

WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
@@ -37,6 +37,7 @@ def _run_wrapped(
    queue: mp.Queue,
):
    """Init distributed process group and run wrapped function."""
    _check_device_initialized(device_type)
    init_process_group(
        master_ip=master_ip,
        port=port,
...
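The check runs before anything else in the worker because CUDA state does not survive `fork()`: if the parent process already initialized CUDA, the first CUDA call in the child fails. A hedged, standalone reproduction of that failure mode (not part of this commit, behaviour depends on the CUDA driver) might look like:

```python
# Standalone sketch of the failure the check guards against; treat it as
# illustrative rather than guaranteed on every CUDA setup.
import multiprocessing as mp
import megengine as mge

def child():
    # With CUDA state inherited from the parent, this first GPU call in the
    # forked child typically fails, which is what the new check detects early.
    mge.tensor(1, device="gpu0")

if __name__ == "__main__":
    mge.tensor(1, device="gpu0")  # parent initializes CUDA before forking
    proc = mp.get_context("fork").Process(target=child)
    proc.start()
    proc.join()
    print("child exit code:", proc.exitcode)  # usually non-zero in this case
```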
@@ -246,3 +246,16 @@ def test_io_remote(shape):
    val = np.random.random_sample(shape).astype("float32")
    worker(val, shape)


@pytest.mark.require_ngpu(2)
def test_cuda_init_before_fork():
    a = mge.tensor(1, device="gpu0")

    @dist.launcher(n_gpus=2)
    def worker():
        a += 1
        b = mge.tensor(2)

    with pytest.raises(AssertionError):
        worker()
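The new test mirrors the usage rule this commit enforces. A hedged sketch of the safe pattern outside of pytest (names follow the test above; it assumes the same `dist.launcher` API) is:

```python
# Sketch of the safe pattern: CUDA is first touched inside the forked workers,
# never in the parent process that calls the launcher.
import megengine as mge
import megengine.distributed as dist

# Unsafe: uncommenting this line initializes CUDA in the parent, so every
# forked worker would fail the new device check with a RuntimeError.
# a = mge.tensor(1, device="gpu0")

@dist.launcher(n_gpus=2)
def worker():
    b = mge.tensor(2, device="gpu0")  # safe: first CUDA use is inside the worker
    return (b + 1).numpy()

if __name__ == "__main__":
    worker()
```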
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./comp_node.h"
@@ -21,8 +22,8 @@ using namespace mgb;
#include "megbrain/comp_node/alloc.h"

#include <cctype>
#include <cstdio>
#include <thread>
@@ -31,26 +32,23 @@ using namespace mgb;
using CudaCompNodeImpl = CudaCompNode::CompNodeImpl;

namespace {
size_t get_min_system_memory(size_t available) {
    if (available < (1u << 31)) {
        // 225MiB
        return 225 * 1024 * 1024;
    } else {
        // max(300 MiB, 0.05 * available)
        return std::max<size_t>(300 * 1024 * 1024, available / 20);
    }
}
using CudaHostFunc = megdnn::thin_function<void()>;
void CUDART_CB cuda_host_func_caller(void* ud) {
    mgb_assert(ud);
    CudaHostFunc* func_ptr = reinterpret_cast<CudaHostFunc*>(ud);
    MGB_TRY { (*func_ptr)(); }
    MGB_FINALLY(delete func_ptr;);
}
}  // anonymous namespace

namespace mgb {
namespace mem_alloc {
@@ -103,7 +101,8 @@ class CudaHostAllocator : public RawAllocator {
public:
    void* alloc(size_t size) override {
        void* addr;
        cudaError_t cuda_error =
                cudaHostAlloc(&addr, size, cudaHostAllocDefault);
        if (cuda_error == cudaSuccess) {
            mgb_assert(addr);
            return addr;
@@ -162,7 +161,7 @@ std::unique_ptr<DevMemAlloc> DevMemAlloc::make_cuda_alloc() {
}  // namespace mgb

/* ===================== CudaCompNodeImpl ===================== */
class CudaCompNode::CompNodeImpl final : public CompNode::Impl {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
    friend class EventImpl;
@@ -170,7 +169,7 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
    struct DeviceInfo;
    struct StaticData;
    static StaticData* sd;
    static Spinlock sd_mtx;
#if !MGB_BUILD_SLIM_SERVING
    std::mutex m_update_mem;
@@ -180,17 +179,15 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
    //! failed
    bool m_initialized = false;
    Locator m_locator, m_locator_logical;
    mem_alloc::StreamMemAlloc* m_mem_alloc;
    DeviceInfo* m_device_info;

    std::unique_ptr<Event> m_sync_event;
    Spinlock m_sync_event_mtx;

    void activate() { m_env.cuda_env().activate(); }

    void init(const Locator& locator, const Locator& locator_logical);
    void fini();

    //! return whether global finalized, and print warning in such case
@@ -207,117 +204,111 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
        static_cast<CompNodeImpl*>(self)->free_host(ptr);
    }

public:
    CompNodeImpl() : Impl(static_free_device, static_free_host) {}

    void* alloc_device(size_t size) override {
        activate();
#if MGB_BUILD_SLIM_SERVING
        return m_mem_alloc->alloc(size);
#else
        void* ptr = m_mem_alloc->alloc(size);
        {
            MGB_LOCK_GUARD(m_update_mem);
            ptr2size[ptr] = size;
            m_used_mem += size;
        }
        return ptr;
#endif
    }

    void free_device(void* ptr);

    void* alloc_host(size_t size) override;

    void free_host(void* ptr);

    void copy_to_host(void* host_ptr, const void* device_ptr,
                      size_t size) override {
        activate();
        MGB_CUDA_CHECK(cudaMemcpyAsync(host_ptr, device_ptr, size,
                                       cudaMemcpyDeviceToHost,
                                       m_env.cuda_env().stream));
    }

    void copy_to_device(void* device_ptr, const void* host_ptr,
                        size_t size) override {
        activate();
        MGB_CUDA_CHECK(cudaMemcpyAsync(device_ptr, host_ptr, size,
                                       cudaMemcpyHostToDevice,
                                       m_env.cuda_env().stream));
    }

    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
                      size_t size) override;

    size_t get_mem_addr_alignment() override {
        return m_env.property().mem_alignment;
    }

    std::unique_ptr<Event> create_event(size_t flags) override;

    void sync() override;

    MemNode mem_node() override;

    std::pair<size_t, size_t> get_mem_status_bytes() override {
        // explicitly call cuda_env() to ensure async init is finished
        m_env.cuda_env().activate();
        size_t tot, free;
        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
        free += m_mem_alloc->get_free_memory_dev().tot;
        return {tot, free};
    }

#if !MGB_BUILD_SLIM_SERVING
    std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr,
                                                      size_t end_ptr) override {
        return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr);
    }
#endif

    Locator locator() override { return m_locator; }

    Locator locator_logical() override { return m_locator_logical; }

    void add_callback(CudaHostFunc&& cb) override {
#if CUDART_VERSION >= 10000
        activate();
        CudaHostFunc* func_ptr = new CudaHostFunc(std::move(cb));
        MGB_TRY {
            MGB_CUDA_CHECK(cudaLaunchHostFunc(m_env.cuda_env().stream,
                                              cuda_host_func_caller,
                                              static_cast<void*>(func_ptr)));
        }
        MGB_CATCH(..., {
            delete func_ptr;
            throw;
        });
#else
        MGB_MARK_USED_VAR(cb);
        MGB_MARK_USED_VAR(cuda_host_func_caller);
        mgb_throw(MegBrainError,
                  "add_callback only support in cuda10.0 and later version");
#endif
    }

    uint64_t get_uid() override { return m_uid; }

#if !MGB_BUILD_SLIM_SERVING
    size_t get_used_memory() override { return m_used_mem; }
#endif

private:
    uint64_t m_uid;
#if !MGB_BUILD_SLIM_SERVING
    std::unordered_map<void*, size_t> ptr2size;
    size_t m_used_mem = 0;
#endif
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CudaCompNode::CompNodeImpl);
@@ -326,15 +317,11 @@ struct CudaCompNodeImpl::DeviceInfo {
    int dev_num = -1;
    std::unique_ptr<mem_alloc::DevMemAlloc> mem_alloc;

    bool init_done() const { return mem_alloc.get(); }

    void init(const CompNodeEnv& env);

    void fini() { mem_alloc.reset(); }
};

struct CudaCompNodeImpl::StaticData {
@@ -347,21 +334,21 @@ struct CudaCompNodeImpl::StaticData {
    std::unique_ptr<mem_alloc::SimpleCachingAlloc> host_alloc;
    CudaCompNode::CompNodeImpl node[MAX_NR_COMP_NODE];
    DeviceInfo dev_info[MAX_NR_DEVICE];
    int nr_node = 0,          //!< number of loaded node[]
            nr_dev_used = 0;  //!< number of used dev_info[]

    StaticData()
            : host_alloc(mem_alloc::SimpleCachingAlloc::make(
                      std::make_unique<mem_alloc::CudaHostAllocator>())) {
        prealloc_config.max_overhead = 0;
        prealloc_config.alignment = 1;
        host_alloc->alignment(1);
    }

    ~StaticData() {
        for (int i = 0; i < nr_node; ++i)
            node[i].fini();
        for (int i = 0; i < nr_dev_used; ++i)
            dev_info[i].fini();
    }
@@ -382,21 +369,21 @@ struct CudaCompNodeImpl::StaticData {
CudaCompNodeImpl::StaticData* CudaCompNodeImpl::sd = nullptr;
Spinlock CudaCompNodeImpl::sd_mtx;

void CudaCompNodeImpl::init(const Locator& locator,
                            const Locator& locator_logical) {
    m_locator = locator;
    m_locator_logical = locator_logical;
    m_initialized = true;

#if defined(__linux__) || defined(TARGET_OS_MAC)
    FILE* fp;
    fp = fopen("/dev/urandom", "r");
    mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
    fclose(fp);
#else
    m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
                    std::chrono::system_clock::now().time_since_epoch())
                    .count();
#endif

    auto on_succ = [this](cudaStream_t stream) {
@@ -404,8 +391,8 @@ void CudaCompNodeImpl::init(
        log_comp_node_created(locator, m_locator_logical);

        MGB_LOCK_GUARD(sd->mtx);
        DeviceInfo* dev_info = nullptr;
        for (int i = 0; i < sd->nr_dev_used; ++i) {
            if (sd->dev_info[i].dev_num == locator.device) {
                dev_info = &sd->dev_info[i];
                break;
@@ -416,7 +403,7 @@ void CudaCompNodeImpl::init(
            dev_info = &sd->dev_info[sd->nr_dev_used];
            dev_info->init(m_env);
            // note: add nr_dev_used only after init succeeds
            ++sd->nr_dev_used;
        }
        m_device_info = dev_info;
        m_mem_alloc =
@@ -428,9 +415,8 @@ void CudaCompNodeImpl::init(
        m_initialized = false;
    };

    m_env.init_cuda_async(locator.device, make_comp_node_from_impl(this),
                          {on_succ, on_error});
}
void CudaCompNodeImpl::fini() {
@@ -444,7 +430,7 @@ void CudaCompNodeImpl::fini() {
    m_initialized = false;
}

void CudaCompNodeImpl::free_device(void* ptr) {
    if (check_global_finalized())
        return;
@@ -452,13 +438,13 @@ void CudaCompNodeImpl::free_device(void *ptr) {
#if !MGB_BUILD_SLIM_SERVING
    {
        MGB_LOCK_GUARD(m_update_mem);
        mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!",
                   ptr);
        m_used_mem -= ptr2size.at(ptr);
        ptr2size.erase(ptr);
    }
#endif

    m_mem_alloc->free(ptr);
}

void* CudaCompNodeImpl::alloc_host(size_t size) {
@@ -468,38 +454,37 @@ void* CudaCompNodeImpl::alloc_host(size_t size) {
}

void CudaCompNodeImpl::free_host(void* ptr) {
    if (check_global_finalized())
        return;
    sd->host_alloc->free(ptr);
}
void CudaCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
                                    const void* src, size_t size) {
    if (dest_impl->same_type<CudaCompNodeImpl>()) {
        auto&& dst_env =
                static_cast<CudaCompNodeImpl*>(dest_impl)->m_env.cuda_env();
        auto&& src_env = m_env.cuda_env();
        activate();
        if (dst_env.device == src_env.device) {
            MGB_CUDA_CHECK(cudaMemcpyAsync(
                    dest, src, size, cudaMemcpyDeviceToDevice, dst_env.stream));
        } else {
            enable_peer_access(src_env.device, dst_env.device);
            enable_peer_access(dst_env.device, src_env.device);
            MGB_CUDA_CHECK(cudaMemcpyPeerAsync(dest, dst_env.device, src,
                                               src_env.device, size,
                                               dst_env.stream));
        }
        return;
    }
    mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
               "cuda peer_copy_to only implemented for CPU");
    auto copy = [this, dest, src, size]() {
        auto stream = m_env.cuda_env().stream;
        m_env.cuda_env().activate();
        MGB_CUDA_CHECK(cudaMemcpyAsync(dest, src, size, cudaMemcpyDeviceToHost,
                                       stream));
        MGB_CUDA_CHECK(cudaStreamSynchronize(stream));
    };
    dest_impl->env().cpu_env().dispatch(copy);
@@ -514,13 +499,13 @@ MemNode CudaCompNodeImpl::mem_node() {
void CudaCompNodeImpl::sync() {
    activate();

    // do not use MGB_CUDA_CHECK(cudaStreamSynchronize(m_env->stream)) since
    // other threads may be adding operations into the stream, and we only care
    // about previous operations in current thread. However docs of
    // cudaStreamSynchronize did not describe details of such condition, so we
    // use manual event implementation
    Event* event;
    {
        MGB_LOCK_GUARD(m_sync_event_mtx);
        if (!m_sync_event)
@@ -532,8 +517,8 @@ void CudaCompNodeImpl::sync() {
}

void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    static bool already_enabled[StaticData::MAX_NR_DEVICE]
                               [StaticData::MAX_NR_DEVICE];
    if (already_enabled[dev0][dev1])
        return;
@@ -550,7 +535,8 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    auto err = cudaDeviceEnablePeerAccess(dev1, 0);
    if (err != cudaSuccess) {
        mgb_log_error("failed to enable peer access from %d to %d: %s(%d)",
                      dev0, dev1, cudaGetErrorString(err),
                      static_cast<int>(err));
        cudaGetLastError();
    }
}
@@ -563,33 +549,29 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    MGB_CUDA_CHECK(cudaMalloc(&dp0, sizeof(int)));
    MGB_CUDA_CHECK(cudaSetDevice(dev1));
    MGB_CUDA_CHECK(cudaMalloc(&dp1, sizeof(int)));
    MGB_CUDA_CHECK(cudaMemcpy(dp0, &v0, sizeof(int), cudaMemcpyHostToDevice));
    MGB_CUDA_CHECK(cudaMemcpy(dp1, &v1, sizeof(int), cudaMemcpyHostToDevice));
    MGB_CUDA_CHECK(cudaMemcpyPeer(dp1, dev1, dp0, dev0, sizeof(int)));
    int get = 0;
    MGB_CUDA_CHECK(cudaMemcpy(&get, dp1, sizeof(int), cudaMemcpyDeviceToHost));

    mgb_throw_if(get != 1, CudaError,
                 "P2P copy (%d => %d) check failed; consider disabling "
                 "Access Control Services(ACS) for the PCI device",
                 dev0, dev1);

    already_enabled[dev0][dev1] = true;
}

/* ===================== CudaCompNodeImpl::DeviceInfo ===================== */

void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv& env) {
    mgb_assert(!mem_alloc);
#if 0
    // forward cudaMalloc
    mem_alloc = mem_alloc::DevMemAlloc::make_cuda_alloc();
#else
    auto&& cuenv = env.cuda_env();
    cuenv.activate();
    dev_num = cuenv.device;
    auto reserve_size = StaticData::get_mem_reserve_size();
@@ -600,9 +582,10 @@ void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv &env) {
    mem_alloc->prealloc_config(sd->prealloc_config);
    auto align = env.property().mem_alignment;
    mem_alloc->alignment(align);
    mgb_log_debug(
            "cuda: gpu%d: name=`%s' dyn_mem_reserve=%.2fMiB alignment=0x%zx",
            dev_num, cuenv.device_prop.name, reserve_size / 1024.0 / 1024,
            align);
#endif
}
@@ -631,14 +614,14 @@ bool CudaCompNodeImpl::check_global_finalized() {
/* ===================== EventImpl ===================== */

class CudaCompNode::EventImpl final : public EventImplHelper {
    bool m_init_finished = false;
    CudaCompNodeImpl* const m_comp_node_impl;
    cudaEvent_t m_cuda_event;

    void do_record() override {
        m_comp_node_impl->activate();
        auto&& env = m_comp_node_impl->m_env.cuda_env();
        MGB_CUDA_CHECK(cudaEventRecord(m_cuda_event, env.stream));
    }
@@ -649,56 +632,51 @@ class CudaCompNode::EventImpl final: public EventImplHelper {
            return true;
        if (err == cudaErrorNotReady)
            return false;
        mgb_throw(CudaError, "failed to query event: %d: %s", int(err),
                  cudaGetErrorString(err));
    }

    void host_wait_cv() override {
        MGB_CUDA_CHECK(cudaEventSynchronize(m_cuda_event));
    }

    double do_elapsed_time_until(EventImplHelper& end) override {
        m_comp_node_impl->activate();
        float ret = 0.0;
        MGB_CUDA_CHECK(cudaEventElapsedTime(
                &ret, m_cuda_event, static_cast<EventImpl&>(end).m_cuda_event));
        return static_cast<double>(ret) * 1e-3;
    }

    void do_device_wait_by(Impl* cn_impl) override;

public:
    EventImpl(CudaCompNodeImpl* comp_node_impl, size_t create_flags)
            : EventImplHelper(comp_node_impl, create_flags),
              m_comp_node_impl{comp_node_impl} {
        m_comp_node_impl->activate();
        size_t cuda_flags = cudaEventDisableTiming;
        if (create_flags & NEED_TIMER)
            cuda_flags = 0;
        MGB_CUDA_CHECK(cudaEventCreateWithFlags(&m_cuda_event, cuda_flags));
        m_init_finished = true;
    }

    ~EventImpl() {
        if (m_init_finished) {
            MGB_TRY { MGB_CUDA_CHECK(cudaEventDestroy(m_cuda_event)); }
            MGB_CATCH(MegBrainError & exc, {
                mgb_log_error("failed to destroy cuda event: %s", exc.what());
            })
        }
    }
};

std::unique_ptr<CompNode::Event> CudaCompNodeImpl::create_event(size_t flags) {
    return std::make_unique<EventImpl>(this, flags);
}

void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
    if (cn_impl->dyn_typeinfo() == CudaCompNodeImpl::typeinfo()) {
        auto imp = static_cast<CudaCompNodeImpl*>(cn_impl);
        auto stream = imp->m_env.cuda_env().stream;
@@ -716,7 +694,6 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl *cn_impl) {
    mgb_throw(MegBrainError, "unimplemented event device_wait_by config");
}

/* ===================== CudaCompNode static methods ===================== */

bool CudaCompNode::available() {
@@ -729,7 +706,10 @@ bool CudaCompNode::available() {
        result = err == cudaSuccess && ndev > 0;
        if (!result) {
            mgb_log_warn("cuda unavailable: %s(%d) ndev=%d",
                         cudaGetErrorString(err), static_cast<int>(err), ndev);
        }
        if (err == cudaErrorInitializationError) {
            mgb_throw(std::runtime_error, "cuda initialization error.");
        }
    }
    return result;
@@ -769,7 +749,7 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
               "request gpu%d out of valid range [0, %d)", locator.device,
               nr_gpu);

    auto&& sdptr = CudaCompNodeImpl::sd;
    {
        MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx);
        if (!sdptr) {
@@ -777,17 +757,18 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
            // global finalize
            using T = CudaCompNodeImpl::StaticData;
            static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
            sdptr = new (&storage) T;
        }
    }
    auto&& sd = *sdptr;
    MGB_LOCK_GUARD(sd.mtx);

    CompNodeImpl* available_node = nullptr;
    for (int i = 0; i < sd.nr_node; ++i) {
        auto&& cur = sd.node[i];
        if (cur.m_initialized) {
            if (cur.m_locator == locator &&
                cur.m_locator_logical == locator_logical) {
                return &cur;
            }
        } else {
@@ -797,11 +778,10 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
    if (!available_node) {
        mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE,
                   "too many CompNode allocated");
        available_node = &sd.node[sd.nr_node++];
    }
    mgb_assert(locator.device < sd.MAX_NR_DEVICE, "device number too large");

    mgb_assert(!available_node->m_initialized);
    available_node->init(locator, locator_logical);
@@ -816,13 +796,13 @@ void CudaCompNode::try_coalesce_all_free_memory() {
        return;

    size_t size = 0;
    for (int i = 0; i < sd->nr_dev_used; ++i) {
        size += sd->dev_info[i]
                        .mem_alloc->gather_stream_free_blk_and_release_full();
    }
    if (size) {
        mgb_log_debug("%zu bytes freed by try_coalesce_all_free_memory()",
                      size);
    }
}

@@ -831,9 +811,9 @@ void CudaCompNode::sync_all() {
    if (!sd)
        return;

    for (int i = 0;; ++i) {
        // ensure async init finished
        CompNodeEnv* env;
        {
            MGB_LOCK_GUARD(sd->mtx);
            if (i >= sd->nr_node) {
@@ -851,12 +831,12 @@ void CudaCompNode::sync_all() {
    }
}

void CudaCompNode::foreach (thin_function<void(CompNode)> callback) {
    auto sd = CudaCompNodeImpl::sd;
    if (!sd)
        return;

    for (int i = 0;; ++i) {
        CompNode cur;
        {
            MGB_LOCK_GUARD(sd->mtx);
@@ -875,8 +855,9 @@ size_t CudaCompNode::get_device_count(bool warn) {
    if (cnt == -1) {
        auto err = cudaGetDeviceCount(&cnt);
        if (err != cudaSuccess) {
            if (warn)
                mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
                              cudaGetErrorString(err), int(err));
            cnt = 0;
        }
        mgb_assert(cnt >= 0);
@@ -884,26 +865,26 @@ size_t CudaCompNode::get_device_count(bool warn) {
    return cnt;
}
void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
                                       size_t max_overhead,
                                       double growth_factor) {
    auto&& sdptr = CudaCompNodeImpl::sd;
    {
        MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx);
        if (!sdptr) {
            using T = CudaCompNodeImpl::StaticData;
            static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
            sdptr = new (&storage) T;
            sdptr->prealloc_config.alignment = alignment;
            sdptr->prealloc_config.min_req = min_req;
            sdptr->prealloc_config.growth_factor = growth_factor;
            sdptr->prealloc_config.max_overhead = max_overhead;
        } else {
            mgb_log_warn(
                    "invalid call to set_prealloc_config, will fallback to "
                    "default config; "
                    "prealloc_config should be specified before any CUDA "
                    "memory allocation");
        }
    }
}
@@ -913,27 +894,23 @@ void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
bool CudaCompNode::available() {
    return false;
}
void CudaCompNode::try_coalesce_all_free_memory() {}
void CudaCompNode::foreach (thin_function<void(CompNode)>) {}
void CudaCompNode::finalize() {}
size_t CudaCompNode::get_device_count(bool warn) {
    return 0;
}
CudaCompNode::Impl* CudaCompNode::load_cuda(const Locator&, const Locator&) {
    mgb_throw(MegBrainError, "cuda disabled at compile time");
}
void CudaCompNode::sync_all() {}

void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
                                       size_t max_overhead,
                                       double growth_factor) {}

#undef err

#endif  // MGB_CUDA

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}