Commit 3eb529d6 authored by Megvii Engine Team

refactor(interpreter): recognize recomp on main thread rather than worker

GitOrigin-RevId: 4ba3942ce475284b2adb14832e3c136aec602016
Parent df976782
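For orientation before the diff: this commit replaces the old shared-pointer `path`/`dep_outputs` bookkeeping with a raw-pointer `ComputePath` graph (`producer`/`users`), so the main thread can decide up front, in `regenerate()`, whether a tensor must be swapped in or recomputed, instead of the worker discovering a missing `ptr` mid-task. Below is a minimal standalone sketch of that ownership scheme — not MegEngine source: all types are simplified stand-ins (`std::vector` instead of `SmallVector`, no ops, no locking), and only the linking/unlinking logic mirrors the diff.

```cpp
// Standalone sketch (not MegEngine source): models the producer/users
// bookkeeping this commit introduces. Names mirror the diff (ComputePath,
// detach_producer, ref_cnt); everything else is a simplified stand-in.
#include <algorithm>
#include <cassert>
#include <vector>

struct TensorInfo;

struct ComputePath {
    std::vector<TensorInfo*> inputs;        // may contain duplicates
    std::vector<TensorInfo*> unique_inputs; // deduplicated view of inputs
    std::vector<TensorInfo*> outputs;       // detached slots become nullptr

    // Number of still-attached outputs; the path frees itself at zero.
    size_t ref_cnt() const {
        return outputs.size() -
               std::count(outputs.begin(), outputs.end(), nullptr);
    }

    static ComputePath* make(std::vector<TensorInfo*> inputs,
                             std::vector<TensorInfo*> outputs);
};

struct TensorInfo {
    ComputePath* producer = nullptr; // how to recompute this tensor
    std::vector<ComputePath*> users; // paths that read this tensor

    void detach_producer() {
        if (!producer) return;
        // Null out our output slot; free the path once nothing needs it.
        auto slot = std::find(producer->outputs.begin(),
                              producer->outputs.end(), this);
        assert(slot != producer->outputs.end());
        *slot = nullptr;
        if (producer->ref_cnt() == 0) {
            for (TensorInfo* input : producer->unique_inputs) {
                input->users.erase(std::find(input->users.begin(),
                                             input->users.end(), producer));
            }
            delete producer;
        }
        producer = nullptr;
    }
};

ComputePath* ComputePath::make(std::vector<TensorInfo*> inputs,
                               std::vector<TensorInfo*> outputs) {
    auto* path = new ComputePath{std::move(inputs), {}, std::move(outputs)};
    // Deduplicate inputs so each user list holds the path at most once.
    path->unique_inputs = path->inputs;
    std::sort(path->unique_inputs.begin(), path->unique_inputs.end());
    path->unique_inputs.erase(
            std::unique(path->unique_inputs.begin(),
                        path->unique_inputs.end()),
            path->unique_inputs.end());
    for (TensorInfo* input : path->unique_inputs) input->users.push_back(path);
    for (TensorInfo* output : path->outputs) output->producer = path;
    return path;
}

int main() {
    TensorInfo a, b, out;
    ComputePath::make({&a, &b, &a}, {&out}); // duplicate input is deduped
    assert(a.users.size() == 1 && b.users.size() == 1);
    out.detach_producer();                   // last output detaches ...
    assert(a.users.empty() && b.users.empty()); // ... path unlinks itself
}
```

The design point mirrored here is that a `ComputePath` owns nothing and is owned by nothing: it deletes itself when its last surviving output detaches, which is why `del()` in the diff calls `detach_users()` and `detach_producer()` on the main thread before enqueueing the actual `Del`.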
@@ -35,7 +35,6 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
     info->h_value = value;
-    m_valid_handle.insert(info);
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
         sync();
@@ -49,20 +48,25 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    m_valid_handle.insert(info);
     return info;
 }
 
 void ChannelImpl::del(Handle handle) {
-    mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
-    m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
+    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
+    auto* info = reinterpret_cast<TensorInfo*>(handle);
+    detach_users(info);
+    info->detach_producer();
+    m_valid_handle.erase(handle);
+    m_buffer.enqueue(Del{info});
 }
 
 void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapIn{info});
+        info->evict_type = NONE;
     }
 }
@@ -70,7 +74,9 @@ void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        m_buffer.enqueue(SwapOut{info});
+        info->evict_type = SWAP;
     }
 }
@@ -78,7 +84,13 @@ void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
-        m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
+        auto* info = reinterpret_cast<TensorInfo*>(handle);
+        if (!info->producer) {
+            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", info);
+            return;
+        }
+        info->evict_type = DROP;
+        m_buffer.enqueue(Drop{info});
     }
 }
@@ -134,18 +146,8 @@ void ChannelImpl::dispatch_default_cpu(
         output_infos.push_back(info);
         outputs->push_back(info);
     }
 
     if (m_enable_evict & DROP) {
-        for (auto out : output_infos) {
-            out->path.op = op;
-            for (auto out_ : output_infos) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : input_infos) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(op, input_infos, output_infos);
     }
 }
@@ -168,21 +170,11 @@ void ChannelImpl::dispatch_kernel(
             info->h_value = HostTensorND::make_proxy(desc.value)
                 .proxy_to_comp_node(desc.comp_node);
         }
-        m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
         outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
-        for (auto out : cmd.outputs) {
-            out->path.op = cmd.op;
-            for (auto out_ : cmd.outputs) {
-                out->path.outputs.push_back(m_st.at(out_));
-            }
-            for (auto inp : cmd.inputs) {
-                out->path.inputs.push_back(m_st.at(inp));
-                inp->path.dep_outputs.push_back(m_st.at(out));
-            }
-        }
+        TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
     }
     m_buffer.enqueue(std::move(cmd));
     if (!validated && m_async_level == 1) {
@@ -215,6 +207,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
             mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
             input_infos.push_back(info);
             input_descs.push_back(info->desc);
+            regenerate(info);
         }
     }
@@ -233,23 +226,31 @@ SmallVector<Handle> ChannelImpl::apply_op(
 }
 
 HostTensorND ChannelImpl::get_value(Handle handle) {
+    // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
-    if (!info->value_fetched) {
-        mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    // donnot use info->value_fetched, it's unsafe
+    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
+    TensorPtr tensor_ptr = info->ptr;
+    auto value_fetched = [&]() {
+        return tensor_ptr && tensor_ptr->value_fetched();
+    };
+    if (!value_fetched()) {
+        std::unique_lock<decltype(m_mutex)> lock(m_mutex);
         m_waitee = info;
+        regenerate(info);
         m_buffer.enqueue(GetValue{info});
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
-            return info->value_fetched;
+            // get tensor ptr in lock to ensure safety
+            tensor_ptr = info->ptr;
+            return value_fetched();
         });
         m_waitee = nullptr;
     }
-    mgb_assert(info->ptr->value_fetched());
-    return info->ptr->get_value();
+    return tensor_ptr->get_value();
 }
 
 TensorShape ChannelImpl::get_shape(Handle handle) {
@@ -298,6 +299,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     m_waitee = info;
+    regenerate(info);
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
@@ -332,17 +334,12 @@ int ChannelImpl::get_async_level() {
 TensorInfo* ChannelImpl::alloc() {
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
-    m_st.insert(info);
+    m_valid_handle.insert(info);
     return info;
 }
 
 void ChannelImpl::free(TensorInfo* ptr) {
     MGB_LOCK_GUARD(m_mutex);
-    if (ptr->path.dep_outputs.size() > 0) {
-        remove_dep(ptr);
-    }
-    m_st.erase(ptr);
-    mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0");
     m_pool.free(ptr);
 }
@@ -350,77 +347,64 @@ ChannelImpl::~ChannelImpl() {
     close();
 }
 
-void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
-    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
-                       : std::unique_lock<std::mutex>();
+void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
+    MGB_LOCK_GUARD(m_mutex);
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
     dest->desc.layout = ptr->layout();
     dest->desc.comp_node = ptr->comp_node();
     dest->ptr = std::move(ptr);
-    if (notice && m_waitee == dest) {
+    if (m_waitee == dest) {
         m_cv.notify_all();
     }
 }
 
-void ChannelImpl::do_swap_out(TensorInfo* dest) {
-    if (dest->evict_type == DROP) {
-        mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest);
-        return;
-    }
-    if (!dest->ptr) {
-        return;
-    }
-    dest->evict_type = SWAP;
-    dest->value_fetched = false;
-    // TODO: swap in parallel
-    dest->h_value = dest->ptr->get_value();
-    dest->ptr.reset();
-}
-
-void ChannelImpl::do_swap_in(TensorInfo* dest) {
-    if (dest->ptr) {
-        return;
-    }
-    if (dest->h_value.empty()) {
-        mgb_log_error("backup of the tensor %p not found", dest);
-        return;
-    }
-    produce_tensor(dest, Tensor::make(dest->h_value), false);
-    dest->evict_type = NONE;
-}
-
-void ChannelImpl::remove_dep(TensorInfo* dest) {
-    for (auto i : dest->path.dep_outputs) {
-        auto out_ptr = i.lock();
-        if (out_ptr) {
-            regenerate(out_ptr.get(), true);
-        }
-    }
-}
-
-void ChannelImpl::do_drop(TensorInfo* dest) {
-    if (dest->evict_type == SWAP) {
-        mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest);
-        return;
-    }
-    if (!dest->path.op) {
-        mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest);
-        return;
-    }
-    if (dest->recompute_times >= m_max_recompute_time) {
-        mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest);
-        return;
-    }
-    if (!dest->ptr) {
-        return;
-    }
-    dest->evict_type = DROP;
-    dest->value_fetched = false;
-    dest->ptr.reset();
+void ChannelImpl::regenerate(TensorInfo* dest) {
+    if (dest->evict_type == DROP) {
+        recompute(dest->producer);
+    } else if (dest->evict_type == SWAP) {
+        swap_in(dest);
+    }
+    mgb_assert(dest->evict_type == NONE);
+}
+
+void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+    SmallVector<TensorInfo*> workspaces(path->outputs.size(), nullptr);
+    for (auto&& input: path->inputs) {
+        regenerate(input);
+    }
+    for (auto&& output: path->outputs) {
+        if(output == nullptr) {
+            continue;
+        }
+        output->evict_type = NONE;
+    }
+    m_buffer.enqueue(ApplyOp{path->op, path->inputs, path->outputs});
+}
+
+void ChannelImpl::detach_users(TensorInfo* dest) {
+    SmallVector<TensorInfo::ComputePath*> users = dest->users;
+    for (auto* user: users) {
+        for (auto* output: user->outputs) {
+            if (output == nullptr) {
+                continue;
+            }
+            regenerate(output);
+            output->detach_producer();
+        }
+    }
+    dest->users.clear();
 }
 
 void ChannelImpl::set_swap_flag(bool flag) {
+    if ((!flag) && (m_enable_evict & SWAP)) {
+        for (auto handle: m_valid_handle) {
+            auto* info = reinterpret_cast<TensorInfo*>(handle);
+            if (info->evict_type == SWAP) {
+                swap_in(info);
+            }
+        }
+    }
     if (flag) {
         m_enable_evict |= SWAP;
     } else {
@@ -429,6 +413,14 @@ void ChannelImpl::set_swap_flag(bool flag) {
 }
 
 void ChannelImpl::set_drop_flag(bool flag) {
+    if ((!flag) && (m_enable_evict & DROP)) {
+        for (auto handle: m_valid_handle) {
+            auto* info = reinterpret_cast<TensorInfo*>(handle);
+            if (info->evict_type == DROP) {
+                recompute(info->producer);
+            }
+        }
+    }
     if (flag) {
         m_enable_evict |= DROP;
     } else {
@@ -440,46 +432,6 @@ void ChannelImpl::set_buffer_length(int length) {
     m_buffer.set_capacity(length);
 }
 
-void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
-    if (!info->ptr && info->evict_type != NONE) {
-        if (info->evict_type == SWAP) {
-            do_swap_in(info);
-        } else {
-            mgb_assert(info->evict_type == DROP);
-            mgb_assert(info->path.op, "recomputation path not found");
-            auto path = info->path;
-            SmallVector<TensorPtr> inputs;
-            inputs.reserve(path.inputs.size());
-            for (auto i : path.inputs) {
-                mgb_assert(i, "invalid history input");
-                if (!i->ptr) {
-                    regenerate(i.get(), must_drop);
-                }
-                inputs.push_back(i->ptr);
-            }
-            auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
-            for (size_t i = 0; i < outputs.size(); i ++) {
-                auto out_ptr = path.outputs[i].lock();
-                if (out_ptr) {
-                    out_ptr->recompute_times ++;
-                    if (!out_ptr->ptr && out_ptr->evict_type == DROP) {
-                        produce_tensor(out_ptr.get(), std::move(outputs[i]), false);
-                    }
-                }
-            }
-        }
-    }
-    if (must_drop) {
-        if (info->path.op) {
-            info->path.op.reset();
-            info->path.inputs.clear();
-            if (info->evict_type == DROP) {
-                info->evict_type = NONE;
-            }
-        }
-    }
-}
-
 void ChannelImpl::process_one_task(Command& cmd) {
     //TODO: remove std::visit for support osx 10.12
     std::visit([this](auto& cmd) {
@@ -493,11 +445,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
             tensor_inputs.reserve(cmd.inputs.size());
             // refcnt == 1, owners: [TensorInfo::ptr]
             for (auto i : cmd.inputs) {
-                if (m_enable_evict && i->evict_type != NONE) {
-                    if (!i->ptr) {
-                        regenerate(i);
-                    }
-                }
                 mgb_assert(i->ptr, "Invalid input tensor ptr!");
                 // refcnt ++, owners: [i->ptr, tensor_inputs]
                 tensor_inputs.push_back(i->ptr);
@@ -515,16 +462,14 @@ void ChannelImpl::process_one_task(Command& cmd) {
                     *cmd.op, std::move(tensor_inputs));
             mgb_assert(tensor_outputs.size() == cmd.outputs.size());
             for (size_t i = 0; i < tensor_outputs.size(); ++i) {
+                if (cmd.outputs[i] == nullptr) {
+                    continue;
+                }
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
             }
         } else if constexpr (std::is_same_v<T, Del>) {
             free(cmd.dest);
         } else if constexpr (std::is_same_v<T, GetValue>) {
-            if (m_enable_evict && cmd.dest->evict_type != NONE) {
-                if (!cmd.dest->ptr) {
-                    regenerate(cmd.dest);
-                }
-            }
             mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
             cmd.dest->ptr->fetch_value();
             MGB_LOCK_GUARD(m_mutex);
@@ -533,11 +478,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
                 m_cv.notify_all();
             }
         } else if constexpr (std::is_same_v<T, SwapIn>) {
-            do_swap_in(cmd.dest);
+            produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
         } else if constexpr (std::is_same_v<T, SwapOut>) {
-            do_swap_out(cmd.dest);
+            cmd.dest->h_value = cmd.dest->ptr->get_value();
+            cmd.dest->ptr.reset();
         } else if constexpr (std::is_same_v<T, Drop>) {
-            do_drop(cmd.dest);
+            cmd.dest->ptr.reset();
         } else if constexpr (std::is_same_v<T, Move>) {
            produce_tensor(cmd.dest, cmd.src->ptr);
            free(cmd.src);
......
@@ -38,22 +38,77 @@ using TensorInfoPtr = std::shared_ptr<TensorInfo>;
 struct TensorInfo {
     TensorPtr ptr;
     LogicalTensorDesc desc;
+    // FIXME: broken by drop
     bool value_fetched = false;
     bool invalid = false;
-    bool allow_delete = false;
 
     EvictType evict_type = NONE;
 
     HostTensorND h_value;
-    size_t locked = 0;
+
+    // reserved for auto drop
+    size_t pinned = 0;
     size_t recompute_times = 0;
 
     struct ComputePath {
         std::shared_ptr<OpDef> op;
-        SmallVector<TensorInfoPtr> inputs;
-        SmallVector<std::weak_ptr<TensorInfo>> outputs;
-        SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
-    } path;
+        SmallVector<TensorInfo*> inputs;
+        SmallVector<TensorInfo*> unique_inputs;
+        SmallVector<TensorInfo*> outputs;
+
+        size_t ref_cnt() {
+            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
+        }
+
+        static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
+            auto* path = new TensorInfo::ComputePath();
+            path->op = op;
+            path->inputs = inputs;
+            path->outputs = outputs;
+            // dedup
+            SmallVector<TensorInfo*> unique_inputs = inputs;
+            std::sort(unique_inputs.begin(), unique_inputs.end());
+            unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
+            path->unique_inputs = unique_inputs;
+            // attach users
+            for (auto input: unique_inputs) {
+                input->users.push_back(path);
+            }
+            // attach producer
+            for (auto output: outputs) {
+                output->producer = path;
+            }
+            return path;
+        }
+    }* producer = nullptr;
+
+    void pin() {
+        ++pinned;
+    }
+
+    void unpin() {
+        --pinned;
+    }
+
+    void detach_producer() {
+        if (!producer) {
+            return;
+        }
+        auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
+        mgb_assert(output != producer->outputs.end());
+        *output = nullptr;
+        if (producer->ref_cnt() == 0) {
+            for (auto* input: producer->unique_inputs) {
+                input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
+            }
+            delete producer;
+        }
+        producer = nullptr;
+    }
+
+    SmallVector<ComputePath*> users;
 };
 
 struct Put {
@@ -186,17 +241,16 @@ struct ChannelImpl : Interpreter::Channel {
 private:
     TensorInfo* alloc();
     void free(TensorInfo*);
-    void remove_dep(TensorInfo*);
+    void detach_users(TensorInfo*);
     void process_one_task(Command&);
     void check_worker_exc_unsafe();
-    void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
-    void do_swap_out(TensorInfo* dest);
-    void do_swap_in(TensorInfo* dest);
-    void do_drop(TensorInfo* dest);
-    void regenerate(TensorInfo* dest, bool must_drop);
+    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
+
+    void regenerate(TensorInfo* dest);
+    void recompute(TensorInfo::ComputePath* path);
 
     void dispatch_default_cpu(
             std::shared_ptr<OpDef> op,
@@ -235,24 +289,6 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
 
-    struct SharedTensorInfoMap {
-        void insert(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}});
-        }
-        void erase(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            tmap.erase(info);
-        }
-        TensorInfoPtr at(TensorInfo* info) {
-            MGB_LOCK_GUARD(mtx);
-            return tmap.at(info);
-        }
-    private:
-        std::mutex mtx;
-        std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
-    } m_st;
-
     /**
      * Buf a command window for following fuse
      * example:
......
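As a closing illustration of the dispatch order the commit title refers to ("recognize recomp on main thread rather than worker"): the main thread now resolves evicted inputs eagerly via `regenerate()`/`recompute()` and re-enqueues the producing ops, so the worker only ever executes commands whose inputs are already materialized. The sketch below is a standalone toy model of that control flow, not MegEngine source; the queue, `Info`, and `Path` types are simplified stand-ins.

```cpp
// Standalone sketch (not MegEngine source) of main-thread recompute dispatch:
// regenerate() inspects the evict state and recursively re-enqueues producer
// ops in dependency order before the consuming command is queued.
#include <cassert>
#include <vector>

enum EvictType { NONE, SWAP, DROP };

struct Path;

struct Info {
    EvictType evict_type = NONE;
    Path* producer = nullptr;
};

struct Path {
    std::vector<Info*> inputs;
    std::vector<Info*> outputs;
};

std::vector<Path*> g_worker_queue;                     // stand-in for m_buffer

void swap_in(Info* info) { info->evict_type = NONE; }  // stand-in for SwapIn

void recompute(Path* path);

void regenerate(Info* dest) {
    if (dest->evict_type == DROP) {
        recompute(dest->producer);  // may recurse through dropped inputs
    } else if (dest->evict_type == SWAP) {
        swap_in(dest);
    }
    assert(dest->evict_type == NONE);  // main thread guarantees this
}

void recompute(Path* path) {
    for (Info* input : path->inputs) regenerate(input);  // inputs first
    for (Info* output : path->outputs) output->evict_type = NONE;
    g_worker_queue.push_back(path);  // worker just replays the op
}

int main() {
    Info a, b, c;
    Path p1{{&a}, {&b}}, p2{{&b}, {&c}};
    b.producer = &p1;
    c.producer = &p2;
    b.evict_type = DROP;
    c.evict_type = DROP;
    regenerate(&c);  // re-enqueues p1 then p2, in dependency order
    assert(g_worker_queue.size() == 2 && g_worker_queue[0] == &p1);
}
```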