提交 d7c546c9 编写于 作者: M Megvii Engine Team

fix(mge/interpreter): regenerates tensor when its dev value is needed

GitOrigin-RevId: ed26d52ee4382d3e4d7c02e1d4a612b62393cb5f
上级 1f7bf1ad
......@@ -90,6 +90,18 @@ class ResNet(M.Module):
return out
def run_dtr_drop_copy_dev_tensor():
mge.dtr.evictee_minimum_size = 128
mge.dtr.enable()
x = F.ones((10, 100))
x._drop()
x[...] = mge.tensor(x, no_cache=True)
x.numpy()
mge.dtr.evictee_minimum_size = 1024 ** 2
mge.dtr.disable()
mge._exit(0)
def run_dtr_resnet1202():
batch_size = 6
resnet1202 = ResNet(BasicBlock, [200, 200, 200])
......@@ -135,3 +147,12 @@ def test_dtr_resnet1202():
p.start()
p.join()
assert p.exitcode == 0
@pytest.mark.require_ngpu(1)
@pytest.mark.isolated_distributed
def test_dtr_drop_copy_dev_tensor():
p = mp.Process(target=run_dtr_drop_copy_dev_tensor)
p.start()
p.join()
assert p.exitcode == 0
......@@ -136,9 +136,31 @@ struct PopScope {
const char* get_name() const { return "PopScope"; }
};
struct StartRegen {
TensorInfo* dest;
template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}
const char* get_name() const { return "StartRegen"; }
};
struct StopRegen {
TensorInfo* dest;
template <typename TFunctor>
void get_props(TFunctor&& functor) const {
functor("dest", dest);
}
const char* get_name() const { return "StopRegen"; }
};
using CommandData = std::variant<
Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile,
PushScope, PopScope>;
PushScope, PopScope, StartRegen, StopRegen>;
struct Command {
uint64_t id;
......
......@@ -1002,8 +1002,11 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
m_waitee_id = Profiler::next_id();
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
bool require_host = prop == TensorProp::HostValue;
bool require_dev = prop == TensorProp::DevValue;
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
auto dev_available = [&] { return info->ptr; };
bool wait_host = false;
bool wait_regen = false;
if (require_host && !host_available()) {
// avoid dead lock
lock.unlock();
......@@ -1020,16 +1023,52 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
lock.lock();
wait_host = true;
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return require_host ? host_available() : static_cast<bool>(info->ptr);
});
if (require_dev && !dev_available()) {
lock.unlock();
if (Profiler::is_profiling()) {
m_worker.add_task(
{Profiler::next_id(), StartRegen{info},
get_channel_state().stack_manager.dump()});
} else {
m_worker.add_task({
Profiler::next_id(),
StartRegen{info},
});
}
lock.lock();
wait_regen = true;
}
if (require_dev) {
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return dev_available();
});
} else {
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return require_host ? host_available() : static_cast<bool>(info->ptr);
});
}
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr;
if (wait_host) {
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
}
if (wait_regen) {
lock.unlock();
if (Profiler::is_profiling()) {
m_worker.add_task(
{Profiler::next_id(), StopRegen{info},
get_channel_state().stack_manager.dump()});
} else {
m_worker.add_task({
Profiler::next_id(),
StopRegen{info},
});
}
lock.lock();
}
return info->ptr;
}
......@@ -1254,6 +1293,17 @@ void ChannelImpl::process_one_task(Command& icmd) {
MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
} else if constexpr (std::is_same_v<T, PopScope>) {
MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
} else if constexpr (std::is_same_v<T, StartRegen>) {
if (cmd.dest->invalid)
return;
cmd.dest->pin();
if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
regenerate(cmd.dest);
}
MGB_LOCK_GUARD(m_mutex);
notify_tensor_unsafe(cmd.dest);
} else if constexpr (std::is_same_v<T, StopRegen>) {
cmd.dest->unpin();
} else {
static_assert(!std::is_same_v<T, T>);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册