提交 933dd9a4 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mge/distributed): add cuda env check before forked thread

style(core/comp_node): reformat code

GitOrigin-RevId: 372452a8eb9e84a2e82d466074f80f78d70531e8
上级 2a541961
......@@ -165,6 +165,18 @@ def _get_device_count_worker(queue, device_type):
queue.put(num)
def _check_device_initialized(device_type: str):
try:
test = Tensor(1, device=device_type)
inited = False
del test
except:
inited = True
errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
if inited:
raise RuntimeError(errmsg)
def get_device_count_by_fork(device_type: str):
"""
Get device count in fork thread.
......
......@@ -15,7 +15,7 @@ from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
......@@ -37,6 +37,7 @@ def _run_wrapped(
queue: mp.Queue,
):
"""Init distributed process group and run wrapped function."""
_check_device_initialized(device_type)
init_process_group(
master_ip=master_ip,
port=port,
......
......@@ -246,3 +246,16 @@ def test_io_remote(shape):
val = np.random.random_sample(shape).astype("float32")
worker(val, shape)
@pytest.mark.require_ngpu(2)
def test_cuda_init_before_fork():
a = mge.tensor(1, device="gpu0")
@dist.launcher(n_gpus=2)
def worker():
a += 1
b = mge.tensor(2)
with pytest.raises(AssertionError):
worker()
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./comp_node.h"
......@@ -21,8 +22,8 @@ using namespace mgb;
#include "megbrain/comp_node/alloc.h"
#include <cstdio>
#include <cctype>
#include <cstdio>
#include <thread>
......@@ -31,26 +32,23 @@ using namespace mgb;
using CudaCompNodeImpl = CudaCompNode::CompNodeImpl;
namespace {
size_t get_min_system_memory(size_t available) {
if (available < (1u << 31)) {
// 225MiB
return 225 * 1024 * 1024;
} else {
// max(300 MiB, 0.05 * available)
return std::max<size_t>(300 * 1024 * 1024, available / 20);
}
size_t get_min_system_memory(size_t available) {
if (available < (1u << 31)) {
// 225MiB
return 225 * 1024 * 1024;
} else {
// max(300 MiB, 0.05 * available)
return std::max<size_t>(300 * 1024 * 1024, available / 20);
}
using CudaHostFunc = megdnn::thin_function<void()>;
void CUDART_CB cuda_host_func_caller(void* ud) {
mgb_assert(ud);
CudaHostFunc* func_ptr = reinterpret_cast<CudaHostFunc*>(ud);
MGB_TRY {
(*func_ptr)();
} MGB_FINALLY(
delete func_ptr;
);
}
} // anonymous namespace
}
using CudaHostFunc = megdnn::thin_function<void()>;
void CUDART_CB cuda_host_func_caller(void* ud) {
mgb_assert(ud);
CudaHostFunc* func_ptr = reinterpret_cast<CudaHostFunc*>(ud);
MGB_TRY { (*func_ptr)(); }
MGB_FINALLY(delete func_ptr;);
}
} // anonymous namespace
namespace mgb {
namespace mem_alloc {
......@@ -103,7 +101,8 @@ class CudaHostAllocator : public RawAllocator {
public:
void* alloc(size_t size) override {
void* addr;
cudaError_t cuda_error = cudaHostAlloc(&addr, size, cudaHostAllocDefault);
cudaError_t cuda_error =
cudaHostAlloc(&addr, size, cudaHostAllocDefault);
if (cuda_error == cudaSuccess) {
mgb_assert(addr);
return addr;
......@@ -162,7 +161,7 @@ std::unique_ptr<DevMemAlloc> DevMemAlloc::make_cuda_alloc() {
} // namespace mgb
/* ===================== CudaCompNodeImpl ===================== */
class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
class CudaCompNode::CompNodeImpl final : public CompNode::Impl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
friend class EventImpl;
......@@ -170,7 +169,7 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
struct DeviceInfo;
struct StaticData;
static StaticData *sd;
static StaticData* sd;
static Spinlock sd_mtx;
#if !MGB_BUILD_SLIM_SERVING
std::mutex m_update_mem;
......@@ -180,17 +179,15 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
//! failed
bool m_initialized = false;
Locator m_locator, m_locator_logical;
mem_alloc::StreamMemAlloc *m_mem_alloc;
DeviceInfo *m_device_info;
mem_alloc::StreamMemAlloc* m_mem_alloc;
DeviceInfo* m_device_info;
std::unique_ptr<Event> m_sync_event;
Spinlock m_sync_event_mtx;
void activate() {
m_env.cuda_env().activate();
}
void activate() { m_env.cuda_env().activate(); }
void init(const Locator &locator, const Locator &locator_logical);
void init(const Locator& locator, const Locator& locator_logical);
void fini();
//! return whether global finalized, and print warning in such case
......@@ -207,117 +204,111 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
static_cast<CompNodeImpl*>(self)->free_host(ptr);
}
public:
CompNodeImpl() : Impl(static_free_device, static_free_host) {}
public:
CompNodeImpl() : Impl(static_free_device, static_free_host) {}
void* alloc_device(size_t size) override {
activate();
void* alloc_device(size_t size) override {
activate();
#if MGB_BUILD_SLIM_SERVING
return m_mem_alloc->alloc(size);
return m_mem_alloc->alloc(size);
#else
void* ptr = m_mem_alloc->alloc(size);
{
MGB_LOCK_GUARD(m_update_mem);
ptr2size[ptr] = size;
m_used_mem += size;
}
return ptr;
#endif
void* ptr = m_mem_alloc->alloc(size);
{
MGB_LOCK_GUARD(m_update_mem);
ptr2size[ptr] = size;
m_used_mem += size;
}
return ptr;
#endif
}
void free_device(void *ptr);
void free_device(void* ptr);
void *alloc_host(size_t size) override;
void* alloc_host(size_t size) override;
void free_host(void *ptr);
void free_host(void* ptr);
void copy_to_host(void *host_ptr,
const void *device_ptr, size_t size) override {
activate();
MGB_CUDA_CHECK(cudaMemcpyAsync(host_ptr, device_ptr, size,
cudaMemcpyDeviceToHost, m_env.cuda_env().stream));
}
void copy_to_host(void* host_ptr, const void* device_ptr,
size_t size) override {
activate();
MGB_CUDA_CHECK(cudaMemcpyAsync(host_ptr, device_ptr, size,
cudaMemcpyDeviceToHost,
m_env.cuda_env().stream));
}
void copy_to_device(void *device_ptr,
const void *host_ptr, size_t size) override {
activate();
MGB_CUDA_CHECK(cudaMemcpyAsync(device_ptr, host_ptr, size,
cudaMemcpyHostToDevice, m_env.cuda_env().stream));
}
void copy_to_device(void* device_ptr, const void* host_ptr,
size_t size) override {
activate();
MGB_CUDA_CHECK(cudaMemcpyAsync(device_ptr, host_ptr, size,
cudaMemcpyHostToDevice,
m_env.cuda_env().stream));
}
void peer_copy_to(
Impl *dest_impl, void *dest,
const void *src, size_t size) override;
void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override;
size_t get_mem_addr_alignment() override {
return m_env.property().mem_alignment;
}
size_t get_mem_addr_alignment() override {
return m_env.property().mem_alignment;
}
std::unique_ptr<Event> create_event(size_t flags) override;
std::unique_ptr<Event> create_event(size_t flags) override;
void sync() override;
void sync() override;
MemNode mem_node() override;
MemNode mem_node() override;
std::pair<size_t, size_t> get_mem_status_bytes() override {
// explicitly call cuda_env() to ensure async init is finished
m_env.cuda_env().activate();
size_t tot, free;
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
free += m_mem_alloc->get_free_memory_dev().tot;
return {tot, free};
}
std::pair<size_t, size_t> get_mem_status_bytes() override {
// explicitly call cuda_env() to ensure async init is finished
m_env.cuda_env().activate();
size_t tot, free;
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
free += m_mem_alloc->get_free_memory_dev().tot;
return {tot, free};
}
#if !MGB_BUILD_SLIM_SERVING
std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr, size_t end_ptr) override {
return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr);
}
std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr,
size_t end_ptr) override {
return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr);
}
#endif
Locator locator() override {
return m_locator;
}
Locator locator() override { return m_locator; }
Locator locator_logical() override {
return m_locator_logical;
}
Locator locator_logical() override { return m_locator_logical; }
void add_callback(CudaHostFunc&& cb) override {
void add_callback(CudaHostFunc&& cb) override {
#if CUDART_VERSION >= 10000
activate();
CudaHostFunc* func_ptr = new CudaHostFunc(std::move(cb));
MGB_TRY {
MGB_CUDA_CHECK(cudaLaunchHostFunc(m_env.cuda_env().stream,
cuda_host_func_caller, static_cast<void*>(func_ptr)));
} MGB_CATCH(..., {
delete func_ptr;
throw;
});
activate();
CudaHostFunc* func_ptr = new CudaHostFunc(std::move(cb));
MGB_TRY {
MGB_CUDA_CHECK(cudaLaunchHostFunc(m_env.cuda_env().stream,
cuda_host_func_caller,
static_cast<void*>(func_ptr)));
}
MGB_CATCH(..., {
delete func_ptr;
throw;
});
#else
MGB_MARK_USED_VAR(cb);
MGB_MARK_USED_VAR(cuda_host_func_caller);
mgb_throw(
MegBrainError,
"add_callback only support in cuda10.0 and later version");
MGB_MARK_USED_VAR(cb);
MGB_MARK_USED_VAR(cuda_host_func_caller);
mgb_throw(MegBrainError,
"add_callback only support in cuda10.0 and later version");
#endif
}
}
uint64_t get_uid() override {
return m_uid;
}
uint64_t get_uid() override { return m_uid; }
#if !MGB_BUILD_SLIM_SERVING
size_t get_used_memory() override {
return m_used_mem;
}
size_t get_used_memory() override { return m_used_mem; }
#endif
private:
uint64_t m_uid;
private:
uint64_t m_uid;
#if !MGB_BUILD_SLIM_SERVING
std::unordered_map<void*, size_t> ptr2size;
size_t m_used_mem = 0;
std::unordered_map<void*, size_t> ptr2size;
size_t m_used_mem = 0;
#endif
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CudaCompNode::CompNodeImpl);
......@@ -326,15 +317,11 @@ struct CudaCompNodeImpl::DeviceInfo {
int dev_num = -1;
std::unique_ptr<mem_alloc::DevMemAlloc> mem_alloc;
bool init_done() const {
return mem_alloc.get();
}
bool init_done() const { return mem_alloc.get(); }
void init(const CompNodeEnv &env);
void init(const CompNodeEnv& env);
void fini() {
mem_alloc.reset();
}
void fini() { mem_alloc.reset(); }
};
struct CudaCompNodeImpl::StaticData {
......@@ -347,21 +334,21 @@ struct CudaCompNodeImpl::StaticData {
std::unique_ptr<mem_alloc::SimpleCachingAlloc> host_alloc;
CudaCompNode::CompNodeImpl node[MAX_NR_COMP_NODE];
DeviceInfo dev_info[MAX_NR_DEVICE];
int nr_node = 0, //!< number of loaded node[]
nr_dev_used = 0; //!< number of used dev_info[]
int nr_node = 0, //!< number of loaded node[]
nr_dev_used = 0; //!< number of used dev_info[]
StaticData() : host_alloc(
mem_alloc::SimpleCachingAlloc::make(
std::make_unique<mem_alloc::CudaHostAllocator>())) {
StaticData()
: host_alloc(mem_alloc::SimpleCachingAlloc::make(
std::make_unique<mem_alloc::CudaHostAllocator>())) {
prealloc_config.max_overhead = 0;
prealloc_config.alignment = 1;
host_alloc->alignment(1);
}
~StaticData() {
for (int i = 0; i < nr_node; ++ i)
for (int i = 0; i < nr_node; ++i)
node[i].fini();
for (int i = 0; i < nr_dev_used; ++ i)
for (int i = 0; i < nr_dev_used; ++i)
dev_info[i].fini();
}
......@@ -382,21 +369,21 @@ struct CudaCompNodeImpl::StaticData {
CudaCompNodeImpl::StaticData* CudaCompNodeImpl::sd = nullptr;
Spinlock CudaCompNodeImpl::sd_mtx;
void CudaCompNodeImpl::init(
const Locator &locator, const Locator &locator_logical) {
void CudaCompNodeImpl::init(const Locator& locator,
const Locator& locator_logical) {
m_locator = locator;
m_locator_logical = locator_logical;
m_initialized = true;
#if defined(__linux__) || defined(TARGET_OS_MAC)
FILE *fp;
FILE* fp;
fp = fopen("/dev/urandom", "r");
mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
fclose(fp);
#else
m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch()
).count();
std::chrono::system_clock::now().time_since_epoch())
.count();
#endif
auto on_succ = [this](cudaStream_t stream) {
......@@ -404,8 +391,8 @@ void CudaCompNodeImpl::init(
log_comp_node_created(locator, m_locator_logical);
MGB_LOCK_GUARD(sd->mtx);
DeviceInfo *dev_info = nullptr;
for (int i = 0; i < sd->nr_dev_used; ++ i) {
DeviceInfo* dev_info = nullptr;
for (int i = 0; i < sd->nr_dev_used; ++i) {
if (sd->dev_info[i].dev_num == locator.device) {
dev_info = &sd->dev_info[i];
break;
......@@ -416,7 +403,7 @@ void CudaCompNodeImpl::init(
dev_info = &sd->dev_info[sd->nr_dev_used];
dev_info->init(m_env);
// note: add nr_dev_used only after init succeeds
++ sd->nr_dev_used;
++sd->nr_dev_used;
}
m_device_info = dev_info;
m_mem_alloc =
......@@ -428,9 +415,8 @@ void CudaCompNodeImpl::init(
m_initialized = false;
};
m_env.init_cuda_async(
locator.device, make_comp_node_from_impl(this),
{on_succ, on_error});
m_env.init_cuda_async(locator.device, make_comp_node_from_impl(this),
{on_succ, on_error});
}
void CudaCompNodeImpl::fini() {
......@@ -444,7 +430,7 @@ void CudaCompNodeImpl::fini() {
m_initialized = false;
}
void CudaCompNodeImpl::free_device(void *ptr) {
void CudaCompNodeImpl::free_device(void* ptr) {
if (check_global_finalized())
return;
......@@ -452,13 +438,13 @@ void CudaCompNodeImpl::free_device(void *ptr) {
#if !MGB_BUILD_SLIM_SERVING
{
MGB_LOCK_GUARD(m_update_mem);
mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!", ptr);
mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!",
ptr);
m_used_mem -= ptr2size.at(ptr);
ptr2size.erase(ptr);
}
#endif
m_mem_alloc->free(ptr);
}
void* CudaCompNodeImpl::alloc_host(size_t size) {
......@@ -468,38 +454,37 @@ void* CudaCompNodeImpl::alloc_host(size_t size) {
}
void CudaCompNodeImpl::free_host(void* ptr) {
if (check_global_finalized()) return;
if (check_global_finalized())
return;
sd->host_alloc->free(ptr);
}
void CudaCompNodeImpl::peer_copy_to(
Impl *dest_impl, void *dest, const void *src, size_t size) {
void CudaCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
const void* src, size_t size) {
if (dest_impl->same_type<CudaCompNodeImpl>()) {
auto &&dst_env = static_cast<CudaCompNodeImpl*>(
dest_impl)->m_env.cuda_env();
auto &&src_env = m_env.cuda_env();
auto&& dst_env =
static_cast<CudaCompNodeImpl*>(dest_impl)->m_env.cuda_env();
auto&& src_env = m_env.cuda_env();
activate();
if (dst_env.device == src_env.device) {
MGB_CUDA_CHECK(cudaMemcpyAsync(dest, src, size,
cudaMemcpyDeviceToDevice,
dst_env.stream));
MGB_CUDA_CHECK(cudaMemcpyAsync(
dest, src, size, cudaMemcpyDeviceToDevice, dst_env.stream));
} else {
enable_peer_access(src_env.device, dst_env.device);
enable_peer_access(dst_env.device, src_env.device);
MGB_CUDA_CHECK(cudaMemcpyPeerAsync(
dest, dst_env.device,
src, src_env.device, size,
dst_env.stream));
MGB_CUDA_CHECK(cudaMemcpyPeerAsync(dest, dst_env.device, src,
src_env.device, size,
dst_env.stream));
}
return;
}
mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
"cuda peer_copy_to only implemented for CPU");
"cuda peer_copy_to only implemented for CPU");
auto copy = [this, dest, src, size]() {
auto stream = m_env.cuda_env().stream;
m_env.cuda_env().activate();
MGB_CUDA_CHECK(cudaMemcpyAsync(
dest, src, size, cudaMemcpyDeviceToHost, stream));
MGB_CUDA_CHECK(cudaMemcpyAsync(dest, src, size, cudaMemcpyDeviceToHost,
stream));
MGB_CUDA_CHECK(cudaStreamSynchronize(stream));
};
dest_impl->env().cpu_env().dispatch(copy);
......@@ -514,13 +499,13 @@ MemNode CudaCompNodeImpl::mem_node() {
void CudaCompNodeImpl::sync() {
activate();
// do not use MGB_CUDA_CHECK(cudaStreamSynchronize(m_env->stream)) since other
// threads may be adding operations into the stream, and we only care about
// previous operations in current thread. However docs of
// do not use MGB_CUDA_CHECK(cudaStreamSynchronize(m_env->stream)) since
// other threads may be adding operations into the stream, and we only care
// about previous operations in current thread. However docs of
// cudaStreamSynchronize did not describe details of such condition, so we
// use manual event implementation
Event *event;
Event* event;
{
MGB_LOCK_GUARD(m_sync_event_mtx);
if (!m_sync_event)
......@@ -532,8 +517,8 @@ void CudaCompNodeImpl::sync() {
}
void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
static bool already_enabled[
StaticData::MAX_NR_DEVICE][StaticData::MAX_NR_DEVICE];
static bool already_enabled[StaticData::MAX_NR_DEVICE]
[StaticData::MAX_NR_DEVICE];
if (already_enabled[dev0][dev1])
return;
......@@ -550,7 +535,8 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
auto err = cudaDeviceEnablePeerAccess(dev1, 0);
if (err != cudaSuccess) {
mgb_log_error("failed to enable peer access from %d to %d: %s(%d)",
dev0, dev1, cudaGetErrorString(err), static_cast<int>(err));
dev0, dev1, cudaGetErrorString(err),
static_cast<int>(err));
cudaGetLastError();
}
}
......@@ -563,33 +549,29 @@ void CudaCompNodeImpl::enable_peer_access(int dev0, int dev1) {
MGB_CUDA_CHECK(cudaMalloc(&dp0, sizeof(int)));
MGB_CUDA_CHECK(cudaSetDevice(dev1));
MGB_CUDA_CHECK(cudaMalloc(&dp1, sizeof(int)));
MGB_CUDA_CHECK(cudaMemcpy(dp0, &v0, sizeof(int),
cudaMemcpyHostToDevice));
MGB_CUDA_CHECK(cudaMemcpy(dp1, &v1, sizeof(int),
cudaMemcpyHostToDevice));
MGB_CUDA_CHECK(cudaMemcpy(dp0, &v0, sizeof(int), cudaMemcpyHostToDevice));
MGB_CUDA_CHECK(cudaMemcpy(dp1, &v1, sizeof(int), cudaMemcpyHostToDevice));
MGB_CUDA_CHECK(cudaMemcpyPeer(dp1, dev1, dp0, dev0, sizeof(int)));
int get = 0;
MGB_CUDA_CHECK(cudaMemcpy(&get, dp1, sizeof(int),
cudaMemcpyDeviceToHost));
MGB_CUDA_CHECK(cudaMemcpy(&get, dp1, sizeof(int), cudaMemcpyDeviceToHost));
mgb_throw_if(get != 1, CudaError,
"P2P copy (%d => %d) check failed; consider disabling "
"Access Control Services(ACS) for the PCI device",
dev0, dev1);
"P2P copy (%d => %d) check failed; consider disabling "
"Access Control Services(ACS) for the PCI device",
dev0, dev1);
already_enabled[dev0][dev1] = true;
}
/* ===================== CudaCompNodeImpl::DeviceInfo ===================== */
void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv &env) {
void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv& env) {
mgb_assert(!mem_alloc);
#if 0
// forward cudaMalloc
mem_alloc = mem_alloc::DevMemAlloc::make_cuda_alloc();
#else
auto &&cuenv = env.cuda_env();
auto&& cuenv = env.cuda_env();
cuenv.activate();
dev_num = cuenv.device;
auto reserve_size = StaticData::get_mem_reserve_size();
......@@ -600,9 +582,10 @@ void CudaCompNodeImpl::DeviceInfo::init(const CompNodeEnv &env) {
mem_alloc->prealloc_config(sd->prealloc_config);
auto align = env.property().mem_alignment;
mem_alloc->alignment(align);
mgb_log_debug("cuda: gpu%d: name=`%s' dyn_mem_reserve=%.2fMiB alignment=0x%zx",
dev_num, cuenv.device_prop.name,
reserve_size / 1024.0 / 1024, align);
mgb_log_debug(
"cuda: gpu%d: name=`%s' dyn_mem_reserve=%.2fMiB alignment=0x%zx",
dev_num, cuenv.device_prop.name, reserve_size / 1024.0 / 1024,
align);
#endif
}
......@@ -631,14 +614,14 @@ bool CudaCompNodeImpl::check_global_finalized() {
/* ===================== EventImpl ===================== */
class CudaCompNode::EventImpl final: public EventImplHelper {
class CudaCompNode::EventImpl final : public EventImplHelper {
bool m_init_finished = false;
CudaCompNodeImpl * const m_comp_node_impl;
CudaCompNodeImpl* const m_comp_node_impl;
cudaEvent_t m_cuda_event;
void do_record() override {
m_comp_node_impl->activate();
auto &&env = m_comp_node_impl->m_env.cuda_env();
auto&& env = m_comp_node_impl->m_env.cuda_env();
MGB_CUDA_CHECK(cudaEventRecord(m_cuda_event, env.stream));
}
......@@ -649,56 +632,51 @@ class CudaCompNode::EventImpl final: public EventImplHelper {
return true;
if (err == cudaErrorNotReady)
return false;
mgb_throw(CudaError, "failed to query event: %d: %s",
int(err), cudaGetErrorString(err));
mgb_throw(CudaError, "failed to query event: %d: %s", int(err),
cudaGetErrorString(err));
}
void host_wait_cv() override {
MGB_CUDA_CHECK(cudaEventSynchronize(m_cuda_event));
}
double do_elapsed_time_until(EventImplHelper &end) override {
double do_elapsed_time_until(EventImplHelper& end) override {
m_comp_node_impl->activate();
float ret = 0.0;
MGB_CUDA_CHECK(cudaEventElapsedTime(&ret, m_cuda_event,
static_cast<EventImpl&>(end).m_cuda_event));
MGB_CUDA_CHECK(cudaEventElapsedTime(
&ret, m_cuda_event, static_cast<EventImpl&>(end).m_cuda_event));
return static_cast<double>(ret) * 1e-3;
}
void do_device_wait_by(Impl *cn_impl) override;
public:
void do_device_wait_by(Impl* cn_impl) override;
EventImpl(CudaCompNodeImpl *comp_node_impl, size_t create_flags):
EventImplHelper(comp_node_impl, create_flags),
m_comp_node_impl{comp_node_impl}
{
m_comp_node_impl->activate();
size_t cuda_flags = cudaEventDisableTiming;
if (create_flags & NEED_TIMER)
cuda_flags = 0;
MGB_CUDA_CHECK(cudaEventCreateWithFlags(&m_cuda_event, cuda_flags));
m_init_finished = true;
}
~EventImpl() {
if (m_init_finished) {
MGB_TRY {
MGB_CUDA_CHECK(cudaEventDestroy(m_cuda_event));
} MGB_CATCH(MegBrainError &exc, {
mgb_log_error("failed to destroy cuda event: %s",
exc.what());
})
}
public:
EventImpl(CudaCompNodeImpl* comp_node_impl, size_t create_flags)
: EventImplHelper(comp_node_impl, create_flags),
m_comp_node_impl{comp_node_impl} {
m_comp_node_impl->activate();
size_t cuda_flags = cudaEventDisableTiming;
if (create_flags & NEED_TIMER)
cuda_flags = 0;
MGB_CUDA_CHECK(cudaEventCreateWithFlags(&m_cuda_event, cuda_flags));
m_init_finished = true;
}
~EventImpl() {
if (m_init_finished) {
MGB_TRY { MGB_CUDA_CHECK(cudaEventDestroy(m_cuda_event)); }
MGB_CATCH(MegBrainError & exc, {
mgb_log_error("failed to destroy cuda event: %s", exc.what());
})
}
}
};
std::unique_ptr<CompNode::Event>
CudaCompNodeImpl::create_event(size_t flags) {
std::unique_ptr<CompNode::Event> CudaCompNodeImpl::create_event(size_t flags) {
return std::make_unique<EventImpl>(this, flags);
}
void CudaCompNode::EventImpl::do_device_wait_by(Impl *cn_impl) {
void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
if (cn_impl->dyn_typeinfo() == CudaCompNodeImpl::typeinfo()) {
auto imp = static_cast<CudaCompNodeImpl*>(cn_impl);
auto stream = imp->m_env.cuda_env().stream;
......@@ -716,7 +694,6 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl *cn_impl) {
mgb_throw(MegBrainError, "unimplemented event device_wait_by config");
}
/* ===================== CudaCompNode static methods ===================== */
bool CudaCompNode::available() {
......@@ -729,7 +706,10 @@ bool CudaCompNode::available() {
result = err == cudaSuccess && ndev > 0;
if (!result) {
mgb_log_warn("cuda unavailable: %s(%d) ndev=%d",
cudaGetErrorString(err), static_cast<int>(err), ndev);
cudaGetErrorString(err), static_cast<int>(err), ndev);
}
if (err == cudaErrorInitializationError) {
mgb_throw(std::runtime_error, "cuda initialization error.");
}
}
return result;
......@@ -769,7 +749,7 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
"request gpu%d out of valid range [0, %d)", locator.device,
nr_gpu);
auto &&sdptr = CudaCompNodeImpl::sd;
auto&& sdptr = CudaCompNodeImpl::sd;
{
MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx);
if (!sdptr) {
......@@ -777,17 +757,18 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
// global finalize
using T = CudaCompNodeImpl::StaticData;
static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
sdptr = new(&storage)T;
sdptr = new (&storage) T;
}
}
auto &&sd = *sdptr;
auto&& sd = *sdptr;
MGB_LOCK_GUARD(sd.mtx);
CompNodeImpl *available_node = nullptr;
for (int i = 0; i < sd.nr_node; ++ i) {
auto &&cur = sd.node[i];
CompNodeImpl* available_node = nullptr;
for (int i = 0; i < sd.nr_node; ++i) {
auto&& cur = sd.node[i];
if (cur.m_initialized) {
if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) {
if (cur.m_locator == locator &&
cur.m_locator_logical == locator_logical) {
return &cur;
}
} else {
......@@ -797,11 +778,10 @@ CompNode::Impl* CudaCompNode::load_cuda(const Locator& locator,
if (!available_node) {
mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE,
"too many CompNode allocated");
available_node = &sd.node[sd.nr_node ++];
"too many CompNode allocated");
available_node = &sd.node[sd.nr_node++];
}
mgb_assert(locator.device < sd.MAX_NR_DEVICE,
"device number too large");
mgb_assert(locator.device < sd.MAX_NR_DEVICE, "device number too large");
mgb_assert(!available_node->m_initialized);
available_node->init(locator, locator_logical);
......@@ -816,13 +796,13 @@ void CudaCompNode::try_coalesce_all_free_memory() {
return;
size_t size = 0;
for (int i = 0; i < sd->nr_dev_used; ++ i) {
size += sd->dev_info[i].mem_alloc->
gather_stream_free_blk_and_release_full();
for (int i = 0; i < sd->nr_dev_used; ++i) {
size += sd->dev_info[i]
.mem_alloc->gather_stream_free_blk_and_release_full();
}
if (size) {
mgb_log_debug("%zu bytes freed by try_coalesce_all_free_memory()",
size);
size);
}
}
......@@ -831,9 +811,9 @@ void CudaCompNode::sync_all() {
if (!sd)
return;
for (int i = 0; ; ++ i) {
for (int i = 0;; ++i) {
// ensure async init finished
CompNodeEnv *env;
CompNodeEnv* env;
{
MGB_LOCK_GUARD(sd->mtx);
if (i >= sd->nr_node) {
......@@ -851,12 +831,12 @@ void CudaCompNode::sync_all() {
}
}
void CudaCompNode::foreach(thin_function<void(CompNode)> callback) {
void CudaCompNode::foreach (thin_function<void(CompNode)> callback) {
auto sd = CudaCompNodeImpl::sd;
if (!sd)
return;
for (int i = 0; ; ++ i) {
for (int i = 0;; ++i) {
CompNode cur;
{
MGB_LOCK_GUARD(sd->mtx);
......@@ -875,8 +855,9 @@ size_t CudaCompNode::get_device_count(bool warn) {
if (cnt == -1) {
auto err = cudaGetDeviceCount(&cnt);
if (err != cudaSuccess) {
if (warn) mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
cudaGetErrorString(err), int(err));
if (warn)
mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
cudaGetErrorString(err), int(err));
cnt = 0;
}
mgb_assert(cnt >= 0);
......@@ -884,26 +865,26 @@ size_t CudaCompNode::get_device_count(bool warn) {
return cnt;
}
void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
size_t max_overhead,
double growth_factor) {
auto &&sdptr = CudaCompNodeImpl::sd;
auto&& sdptr = CudaCompNodeImpl::sd;
{
MGB_LOCK_GUARD(CudaCompNodeImpl::sd_mtx);
if (!sdptr) {
using T = CudaCompNodeImpl::StaticData;
static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
sdptr = new(&storage)T;
sdptr = new (&storage) T;
sdptr->prealloc_config.alignment = alignment;
sdptr->prealloc_config.min_req = min_req;
sdptr->prealloc_config.growth_factor = growth_factor;
sdptr->prealloc_config.max_overhead = max_overhead;
} else {
mgb_log_warn(
"invalid call to set_prealloc_config, will fallback to "
"default config; "
"prealloc_config should be specified before any CUDA "
"memory allocation");
"invalid call to set_prealloc_config, will fallback to "
"default config; "
"prealloc_config should be specified before any CUDA "
"memory allocation");
}
}
}
......@@ -913,27 +894,23 @@ void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
bool CudaCompNode::available() {
return false;
}
void CudaCompNode::try_coalesce_all_free_memory() {
}
void CudaCompNode::foreach(thin_function<void(CompNode)>) {
}
void CudaCompNode::finalize() {
}
void CudaCompNode::try_coalesce_all_free_memory() {}
void CudaCompNode::foreach (thin_function<void(CompNode)>) {}
void CudaCompNode::finalize() {}
size_t CudaCompNode::get_device_count(bool warn) {
return 0;
}
CudaCompNode::Impl* CudaCompNode::load_cuda(const Locator&, const Locator&) {
mgb_throw(MegBrainError, "cuda disabled at compile time");
}
void CudaCompNode::sync_all() {
}
void CudaCompNode::sync_all() {}
void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
void CudaCompNode::set_prealloc_config(size_t alignment, size_t min_req,
size_t max_overhead,
double growth_factor) {}
#undef err
#endif // MGB_CUDA
#endif // MGB_CUDA
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册