Commit 3eb529d6, authored by: Megvii Engine Team

refactor(interpreter): recognize recomp on main thread rather than worker

GitOrigin-RevId: 4ba3942ce475284b2adb14832e3c136aec602016
Parent: df976782
......@@ -35,7 +35,6 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
info->desc.comp_node = value.comp_node();
info->desc.value = value.proxy_to_default_cpu();
info->h_value = value;
m_valid_handle.insert(info);
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
sync();
......@@ -49,20 +48,25 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data);
m_valid_handle.insert(info);
return info;
}
void ChannelImpl::del(Handle handle) {
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
detach_users(info);
info->detach_producer();
m_valid_handle.erase(handle);
m_buffer.enqueue(Del{info});
}
void ChannelImpl::swap_in(Handle handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_buffer.enqueue(SwapIn{info});
info->evict_type = NONE;
}
}
......@@ -70,7 +74,9 @@ void ChannelImpl::swap_out(Handle handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_buffer.enqueue(SwapOut{info});
info->evict_type = SWAP;
}
}
......@@ -78,7 +84,13 @@ void ChannelImpl::drop(Handle handle) {
if (m_enable_evict & DROP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
auto* info = reinterpret_cast<TensorInfo*>(handle);
if (!info->producer) {
mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info);
return;
}
info->evict_type = DROP;
m_buffer.enqueue(Drop{info});
}
}
......@@ -134,18 +146,8 @@ void ChannelImpl::dispatch_default_cpu(
output_infos.push_back(info);
outputs->push_back(info);
}
if (m_enable_evict & DROP) {
for (auto out : output_infos) {
out->path.op = op;
for (auto out_ : output_infos) {
out->path.outputs.push_back(m_st.at(out_));
}
for (auto inp : input_infos) {
out->path.inputs.push_back(m_st.at(inp));
inp->path.dep_outputs.push_back(m_st.at(out));
}
}
TensorInfo::ComputePath::make(op, input_infos, output_infos);
}
}
......@@ -168,21 +170,11 @@ void ChannelImpl::dispatch_kernel(
info->h_value = HostTensorND::make_proxy(desc.value)
.proxy_to_comp_node(desc.comp_node);
}
m_valid_handle.insert(info);
cmd.outputs.push_back(info);
outputs->push_back(info);
}
if (m_enable_evict & DROP) {
for (auto out : cmd.outputs) {
out->path.op = cmd.op;
for (auto out_ : cmd.outputs) {
out->path.outputs.push_back(m_st.at(out_));
}
for (auto inp : cmd.inputs) {
out->path.inputs.push_back(m_st.at(inp));
inp->path.dep_outputs.push_back(m_st.at(out));
}
}
TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
}
m_buffer.enqueue(std::move(cmd));
if (!validated && m_async_level == 1) {
......@@ -215,6 +207,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
regenerate(info);
}
}
......@@ -233,23 +226,31 @@ SmallVector<Handle> ChannelImpl::apply_op(
}
HostTensorND ChannelImpl::get_value(Handle handle) {
// TODO: maybe get_value should be done on host. i.e. delete GetValue
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
if (!info->value_fetched) {
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
// donnot use info->value_fetched, it's unsafe
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
TensorPtr tensor_ptr = info->ptr;
auto value_fetched = [&]() {
return tensor_ptr && tensor_ptr->value_fetched();
};
if (!value_fetched()) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
m_waitee = info;
regenerate(info);
m_buffer.enqueue(GetValue{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return info->value_fetched;
// get tensor ptr in lock to ensure safety
tensor_ptr = info->ptr;
return value_fetched();
});
m_waitee = nullptr;
}
mgb_assert(info->ptr->value_fetched());
return info->ptr->get_value();
return tensor_ptr->get_value();
}
TensorShape ChannelImpl::get_shape(Handle handle) {
......@@ -298,6 +299,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
m_waitee = info;
regenerate(info);
m_buffer.enqueue(Flush{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
......@@ -332,17 +334,12 @@ int ChannelImpl::get_async_level() {
TensorInfo* ChannelImpl::alloc() {
MGB_LOCK_GUARD(m_mutex);
auto info = m_pool.alloc();
m_st.insert(info);
m_valid_handle.insert(info);
return info;
}
void ChannelImpl::free(TensorInfo* ptr) {
MGB_LOCK_GUARD(m_mutex);
if (ptr->path.dep_outputs.size() > 0) {
remove_dep(ptr);
}
m_st.erase(ptr);
mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0");
m_pool.free(ptr);
}
......@@ -350,77 +347,64 @@ ChannelImpl::~ChannelImpl() {
close();
}
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
: std::unique_lock<std::mutex>();
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (notice && m_waitee == dest) {
if (m_waitee == dest) {
m_cv.notify_all();
}
}
void ChannelImpl::do_swap_out(TensorInfo* dest) {
void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == DROP) {
mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest);
return;
}
if (!dest->ptr) {
return;
recompute(dest->producer);
} else if (dest->evict_type == SWAP) {
swap_in(dest);
}
dest->evict_type = SWAP;
dest->value_fetched = false;
// TODO: swap in parallel
dest->h_value = dest->ptr->get_value();
dest->ptr.reset();
mgb_assert(dest->evict_type == NONE);
}
void ChannelImpl::do_swap_in(TensorInfo* dest) {
if (dest->ptr) {
return;
}
if (dest->h_value.empty()) {
mgb_log_error("backup of the tensor %p not found", dest);
return;
void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
SmallVector<TensorInfo*> workspaces(path->outputs.size(), nullptr);
for (auto&& input: path->inputs) {
regenerate(input);
}
produce_tensor(dest, Tensor::make(dest->h_value), false);
dest->evict_type = NONE;
}
void ChannelImpl::remove_dep(TensorInfo* dest) {
for (auto i : dest->path.dep_outputs) {
auto out_ptr = i.lock();
if (out_ptr) {
regenerate(out_ptr.get(), true);
for (auto&& output: path->outputs) {
if(output == nullptr) {
continue;
}
output->evict_type = NONE;
}
m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs});
}
void ChannelImpl::do_drop(TensorInfo* dest) {
if (dest->evict_type == SWAP) {
mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest);
return;
}
if (!dest->path.op) {
mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest);
return;
}
if (dest->recompute_times >= m_max_recompute_time) {
mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest);
return;
}
if (!dest->ptr) {
return;
void ChannelImpl::detach_users(TensorInfo* dest) {
SmallVector<TensorInfo::ComputePath*> users = dest->users;
for (auto* user: users) {
for (auto* output: user->outputs) {
if (output == nullptr) {
continue;
}
regenerate(output);
output->detach_producer();
}
}
dest->evict_type = DROP;
dest->value_fetched = false;
dest->ptr.reset();
dest->users.clear();
}
void ChannelImpl::set_swap_flag(bool flag) {
if ((!flag) && (m_enable_evict & SWAP)) {
for (auto handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle);
if (info->evict_type == SWAP) {
swap_in(info);
}
}
}
if (flag) {
m_enable_evict |= SWAP;
} else {
......@@ -429,6 +413,14 @@ void ChannelImpl::set_swap_flag(bool flag) {
}
void ChannelImpl::set_drop_flag(bool flag) {
if ((!flag) && (m_enable_evict & DROP)) {
for (auto handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle);
if (info->evict_type == DROP) {
recompute(info->producer);
}
}
}
if (flag) {
m_enable_evict |= DROP;
} else {
......@@ -440,46 +432,6 @@ void ChannelImpl::set_buffer_length(int length) {
m_buffer.set_capacity(length);
}
void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
if (!info->ptr && info->evict_type != NONE) {
if (info->evict_type == SWAP) {
do_swap_in(info);
} else {
mgb_assert(info->evict_type == DROP);
mgb_assert(info->path.op, "recomputation path not found");
auto path = info->path;
SmallVector<TensorPtr> inputs;
inputs.reserve(path.inputs.size());
for (auto i : path.inputs) {
mgb_assert(i, "invalid history input");
if (!i->ptr) {
regenerate(i.get(), must_drop);
}
inputs.push_back(i->ptr);
}
auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
for (size_t i = 0; i < outputs.size(); i ++) {
auto out_ptr = path.outputs[i].lock();
if (out_ptr) {
out_ptr->recompute_times ++;
if (!out_ptr->ptr && out_ptr->evict_type == DROP) {
produce_tensor(out_ptr.get(), std::move(outputs[i]), false);
}
}
}
}
}
if (must_drop) {
if (info->path.op) {
info->path.op.reset();
info->path.inputs.clear();
if (info->evict_type == DROP) {
info->evict_type = NONE;
}
}
}
}
void ChannelImpl::process_one_task(Command& cmd) {
//TODO: remove std::visit for support osx 10.12
std::visit([this](auto& cmd) {
......@@ -493,11 +445,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
tensor_inputs.reserve(cmd.inputs.size());
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
if (m_enable_evict && i->evict_type != NONE) {
if (!i->ptr) {
regenerate(i);
}
}
mgb_assert(i->ptr, "Invalid input tensor ptr!");
// refcnt ++, owners: [i->ptr, tensor_inputs]
tensor_inputs.push_back(i->ptr);
......@@ -515,16 +462,14 @@ void ChannelImpl::process_one_task(Command& cmd) {
*cmd.op, std::move(tensor_inputs));
mgb_assert(tensor_outputs.size() == cmd.outputs.size());
for (size_t i = 0; i < tensor_outputs.size(); ++i) {
if (cmd.outputs[i] == nullptr) {
continue;
}
produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
}
} else if constexpr (std::is_same_v<T, Del>) {
free(cmd.dest);
} else if constexpr (std::is_same_v<T, GetValue>) {
if (m_enable_evict && cmd.dest->evict_type != NONE) {
if (!cmd.dest->ptr) {
regenerate(cmd.dest);
}
}
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
......@@ -533,11 +478,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
m_cv.notify_all();
}
} else if constexpr (std::is_same_v<T, SwapIn>) {
do_swap_in(cmd.dest);
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
} else if constexpr (std::is_same_v<T, SwapOut>) {
do_swap_out(cmd.dest);
cmd.dest->h_value = cmd.dest->ptr->get_value();
cmd.dest->ptr.reset();
} else if constexpr (std::is_same_v<T, Drop>) {
do_drop(cmd.dest);
cmd.dest->ptr.reset();
} else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src);
......
......@@ -38,22 +38,77 @@ using TensorInfoPtr = std::shared_ptr<TensorInfo>;
struct TensorInfo {
TensorPtr ptr;
LogicalTensorDesc desc;
// FIXME: broken by drop
bool value_fetched = false;
bool invalid = false;
bool allow_delete = false;
EvictType evict_type = NONE;
HostTensorND h_value;
size_t locked = 0;
// reserved for auto drop
size_t pinned = 0;
size_t recompute_times = 0;
struct ComputePath {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfoPtr> inputs;
SmallVector<std::weak_ptr<TensorInfo>> outputs;
SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
} path;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> unique_inputs;
SmallVector<TensorInfo*> outputs;
size_t ref_cnt() {
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
}
static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
auto* path = new TensorInfo::ComputePath();
path->op = op;
path->inputs = inputs;
path->outputs = outputs;
// dedup
SmallVector<TensorInfo*> unique_inputs = inputs;
std::sort(unique_inputs.begin(), unique_inputs.end());
unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
path->unique_inputs = unique_inputs;
// attach users
for (auto input: unique_inputs) {
input->users.push_back(path);
}
// attach producer
for (auto output: outputs) {
output->producer = path;
}
return path;
}
}* producer = nullptr;
void pin() {
++pinned;
}
void unpin() {
--pinned;
}
void detach_producer() {
if (!producer) {
return;
}
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
mgb_assert(output != producer->outputs.end());
*output = nullptr;
if (producer->ref_cnt() == 0) {
for (auto* input: producer->unique_inputs) {
input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
}
delete producer;
}
producer = nullptr;
}
SmallVector<ComputePath*> users;
};
struct Put {
......@@ -186,17 +241,16 @@ struct ChannelImpl : Interpreter::Channel {
private:
TensorInfo* alloc();
void free(TensorInfo*);
void remove_dep(TensorInfo*);
void detach_users(TensorInfo*);
void process_one_task(Command&);
void check_worker_exc_unsafe();
void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
void do_swap_out(TensorInfo* dest);
void do_swap_in(TensorInfo* dest);
void do_drop(TensorInfo* dest);
void regenerate(TensorInfo* dest, bool must_drop);
void produce_tensor(TensorInfo* dest, TensorPtr ptr);
void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path);
void dispatch_default_cpu(
std::shared_ptr<OpDef> op,
......@@ -235,24 +289,6 @@ private:
ChannelImpl* m_owner;
} m_worker;
struct SharedTensorInfoMap {
void insert(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}});
}
void erase(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
tmap.erase(info);
}
TensorInfoPtr at(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
return tmap.at(info);
}
private:
std::mutex mtx;
std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
}m_st;
/**
* Buffer a command window for subsequent fusing
* example:
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment