提交 bd62a0a6 · 作者: Megvii Engine Team · 提交者: huangxinda

fix(interpreter): remove notice flag of produce_tensor

GitOrigin-RevId: ed65d0107f4c95d4874ffb5b2e1991e9f0307d79
上级 c78a7848
......@@ -196,6 +196,10 @@ void ChannelImpl::dispatch_default_cpu(
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) {
auto& state = get_channel_state();
auto name = op->trait()->make_name(*op);
state.scopes.push(name);
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
RECORD_EVENT(ShapeInferEvent, validated);
......@@ -256,6 +260,8 @@ void ChannelImpl::dispatch_default_cpu(
return op_info;
};
RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
state.scopes.pop(name);
}
void ChannelImpl::dispatch_kernel(
......@@ -353,7 +359,6 @@ SmallVector<Handle> ChannelImpl::apply_op(
HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -364,7 +369,6 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -378,7 +382,6 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -390,7 +393,6 @@ DType ChannelImpl::get_dtype(Handle handle) {
CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -402,7 +404,6 @@ CompNode ChannelImpl::get_device(Handle handle) {
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -411,7 +412,6 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
m_buffer.flush();
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
......@@ -519,7 +519,6 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
void ChannelImpl::real_free(TensorInfo* ptr) {
auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex);
if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(ptr);
}
......@@ -531,6 +530,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
}
RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
ptr->status = TensorInfo::Deleted;
MGB_LOCK_GUARD(m_mutex);
m_pool.free(ptr);
}
......@@ -540,12 +540,9 @@ ChannelImpl::~ChannelImpl() {
close();
}
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state();
std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock};
if (notice) {
lock.lock();
}
MGB_LOCK_GUARD(m_mutex);
m_dtr.update_used_time(dest);
RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
// update tensor desc for static infer
......@@ -555,12 +552,10 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
dest->ptr = std::move(ptr);
dest->evict_type = EvictType::NONE;
dest->status = TensorInfo::Produced;
if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.insert_candidate(dest);
}
if (notice) {
notify_tensor_unsafe(dest);
}
notify_tensor_unsafe(dest);
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
......@@ -781,6 +776,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}
if (!value_fetching) {
m_buffer.enqueue(GetValue{info});
m_buffer.flush();
value_fetching = true;
}
return false;
......@@ -789,16 +785,12 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}
});
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
if (m_waitee != nullptr) {
mgb_assert(m_waitee == info, "waitee mismatch");
m_waitee = nullptr;
}
m_waitee = nullptr;
return info->ptr;
}
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) {
m_waitee = nullptr;
RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all();
}
......@@ -809,7 +801,6 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
for (auto* handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle);
valid_tensors.insert(info);
//TODO: valid_tensors.insert({info, info->status});
}
return valid_tensors;
}
......@@ -1005,7 +996,6 @@ void ChannelImpl::CommandBuffer::flush() {
}
void ChannelImpl::CommandBuffer::flush(Handle pos) {
auto& state = m_owner->get_channel_state();
for (auto iter = m_commands.begin(); iter != pos; ++iter) {
if (Profiler::is_profiling()) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str());
......
......@@ -91,7 +91,7 @@ private:
void check_worker_exc_unsafe();
void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
void produce_tensor(TensorInfo* dest, TensorPtr ptr);
void release_tensor(TensorInfo* dest);
......
......@@ -103,6 +103,7 @@ struct MemoryFlow {
auto addr_begin = std::numeric_limits<uintptr_t>::max();
auto addr_end = std::numeric_limits<uintptr_t>::min();
for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue;
addr_begin = std::min(addr_begin, chunk.address[0]);
addr_end = std::max(addr_end, chunk.address[1]);
......@@ -114,6 +115,7 @@ struct MemoryFlow {
auto time_begin = std::numeric_limits<uint64_t>::max();
auto time_end = std::numeric_limits<uint64_t>::min();
for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue;
time_begin = std::min(time_begin, chunk.time[0]);
time_end = std::max(time_end, chunk.time[1]);
......@@ -124,6 +126,7 @@ struct MemoryFlow {
std::shared_ptr<json::Array> to_json() const {
auto results = json::Array::make();
for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue;
auto address = json::Array::make();
auto time = json::Array::make();
......@@ -213,6 +216,7 @@ struct MemoryFlow {
return builder;
};
for (auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue;
double left = (chunk.time[0]-time_begin)/time_scale;
double right = (chunk.time[1]-time_begin)/time_scale;
......
......@@ -131,6 +131,7 @@ public:
MGB_LOCK_GUARD(sm_mutex);
if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Running;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting));
}
......@@ -149,6 +150,7 @@ public:
});
if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Collecting;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册