提交 bd62a0a6 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(interpreter): remove notice flag of produce_tensor

GitOrigin-RevId: ed65d0107f4c95d4874ffb5b2e1991e9f0307d79
上级 c78a7848
...@@ -196,6 +196,10 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -196,6 +196,10 @@ void ChannelImpl::dispatch_default_cpu(
const SmallVector<LogicalTensorDesc>& input_descs, const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) { SmallVector<Handle>* outputs) {
auto& state = get_channel_state(); auto& state = get_channel_state();
auto name = op->trait()->make_name(*op);
state.scopes.push(name);
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
RECORD_EVENT(ShapeInferEvent, validated); RECORD_EVENT(ShapeInferEvent, validated);
...@@ -256,6 +260,8 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -256,6 +260,8 @@ void ChannelImpl::dispatch_default_cpu(
return op_info; return op_info;
}; };
RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos)); RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
state.scopes.pop(name);
} }
void ChannelImpl::dispatch_kernel( void ChannelImpl::dispatch_kernel(
...@@ -353,7 +359,6 @@ SmallVector<Handle> ChannelImpl::apply_op( ...@@ -353,7 +359,6 @@ SmallVector<Handle> ChannelImpl::apply_op(
HostTensorND ChannelImpl::get_value(Handle handle) { HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
...@@ -364,7 +369,6 @@ HostTensorND ChannelImpl::get_value(Handle handle) { ...@@ -364,7 +369,6 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
TensorShape ChannelImpl::get_shape(Handle handle) { TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
...@@ -378,7 +382,6 @@ TensorShape ChannelImpl::get_shape(Handle handle) { ...@@ -378,7 +382,6 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
DType ChannelImpl::get_dtype(Handle handle) { DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
...@@ -390,7 +393,6 @@ DType ChannelImpl::get_dtype(Handle handle) { ...@@ -390,7 +393,6 @@ DType ChannelImpl::get_dtype(Handle handle) {
CompNode ChannelImpl::get_device(Handle handle) { CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
...@@ -402,7 +404,6 @@ CompNode ChannelImpl::get_device(Handle handle) { ...@@ -402,7 +404,6 @@ CompNode ChannelImpl::get_device(Handle handle) {
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
...@@ -411,7 +412,6 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { ...@@ -411,7 +412,6 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
void ChannelImpl::sync() { void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
m_buffer.flush(); m_buffer.flush();
m_worker.wait_all_task_finish(); m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
...@@ -519,7 +519,6 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) { ...@@ -519,7 +519,6 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
void ChannelImpl::real_free(TensorInfo* ptr) { void ChannelImpl::real_free(TensorInfo* ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex);
if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(ptr); m_dtr.erase_candidate(ptr);
} }
...@@ -531,6 +530,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) { ...@@ -531,6 +530,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
} }
RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count); RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
ptr->status = TensorInfo::Deleted; ptr->status = TensorInfo::Deleted;
MGB_LOCK_GUARD(m_mutex);
m_pool.free(ptr); m_pool.free(ptr);
} }
...@@ -540,12 +540,9 @@ ChannelImpl::~ChannelImpl() { ...@@ -540,12 +540,9 @@ ChannelImpl::~ChannelImpl() {
close(); close();
} }
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) { void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock}; MGB_LOCK_GUARD(m_mutex);
if (notice) {
lock.lock();
}
m_dtr.update_used_time(dest); m_dtr.update_used_time(dest);
RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr()); RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
// update tensor desc for static infer // update tensor desc for static infer
...@@ -555,12 +552,10 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr ...@@ -555,12 +552,10 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
dest->ptr = std::move(ptr); dest->ptr = std::move(ptr);
dest->evict_type = EvictType::NONE; dest->evict_type = EvictType::NONE;
dest->status = TensorInfo::Produced; dest->status = TensorInfo::Produced;
if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.insert_candidate(dest); m_dtr.insert_candidate(dest);
} }
if (notice) {
notify_tensor_unsafe(dest); notify_tensor_unsafe(dest);
}
} }
void ChannelImpl::release_tensor(TensorInfo* dest) { void ChannelImpl::release_tensor(TensorInfo* dest) {
...@@ -781,6 +776,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -781,6 +776,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
} }
if (!value_fetching) { if (!value_fetching) {
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_buffer.flush();
value_fetching = true; value_fetching = true;
} }
return false; return false;
...@@ -789,16 +785,12 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -789,16 +785,12 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
} }
}); });
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr); RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
if (m_waitee != nullptr) {
mgb_assert(m_waitee == info, "waitee mismatch");
m_waitee = nullptr; m_waitee = nullptr;
}
return info->ptr; return info->ptr;
} }
void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) { void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
if (info == m_waitee) { if (info == m_waitee) {
m_waitee = nullptr;
RECORD_EVENT(TensorNotifyPropEvent, info->id); RECORD_EVENT(TensorNotifyPropEvent, info->id);
m_cv.notify_all(); m_cv.notify_all();
} }
...@@ -809,7 +801,6 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() { ...@@ -809,7 +801,6 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
for (auto* handle: m_valid_handle) { for (auto* handle: m_valid_handle) {
auto* info = reinterpret_cast<TensorInfo*>(handle); auto* info = reinterpret_cast<TensorInfo*>(handle);
valid_tensors.insert(info); valid_tensors.insert(info);
//TODO: valid_tensors.insert({info, info->status});
} }
return valid_tensors; return valid_tensors;
} }
...@@ -1005,7 +996,6 @@ void ChannelImpl::CommandBuffer::flush() { ...@@ -1005,7 +996,6 @@ void ChannelImpl::CommandBuffer::flush() {
} }
void ChannelImpl::CommandBuffer::flush(Handle pos) { void ChannelImpl::CommandBuffer::flush(Handle pos) {
auto& state = m_owner->get_channel_state();
for (auto iter = m_commands.begin(); iter != pos; ++iter) { for (auto iter = m_commands.begin(); iter != pos; ++iter) {
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str()); mgb_log_debug("%s Flushed", to_string(*iter).c_str());
......
...@@ -91,7 +91,7 @@ private: ...@@ -91,7 +91,7 @@ private:
void check_worker_exc_unsafe(); void check_worker_exc_unsafe();
void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice); void produce_tensor(TensorInfo* dest, TensorPtr ptr);
void release_tensor(TensorInfo* dest); void release_tensor(TensorInfo* dest);
......
...@@ -103,6 +103,7 @@ struct MemoryFlow { ...@@ -103,6 +103,7 @@ struct MemoryFlow {
auto addr_begin = std::numeric_limits<uintptr_t>::max(); auto addr_begin = std::numeric_limits<uintptr_t>::max();
auto addr_end = std::numeric_limits<uintptr_t>::min(); auto addr_end = std::numeric_limits<uintptr_t>::min();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
addr_begin = std::min(addr_begin, chunk.address[0]); addr_begin = std::min(addr_begin, chunk.address[0]);
addr_end = std::max(addr_end, chunk.address[1]); addr_end = std::max(addr_end, chunk.address[1]);
...@@ -114,6 +115,7 @@ struct MemoryFlow { ...@@ -114,6 +115,7 @@ struct MemoryFlow {
auto time_begin = std::numeric_limits<uint64_t>::max(); auto time_begin = std::numeric_limits<uint64_t>::max();
auto time_end = std::numeric_limits<uint64_t>::min(); auto time_end = std::numeric_limits<uint64_t>::min();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
time_begin = std::min(time_begin, chunk.time[0]); time_begin = std::min(time_begin, chunk.time[0]);
time_end = std::max(time_end, chunk.time[1]); time_end = std::max(time_end, chunk.time[1]);
...@@ -124,6 +126,7 @@ struct MemoryFlow { ...@@ -124,6 +126,7 @@ struct MemoryFlow {
std::shared_ptr<json::Array> to_json() const { std::shared_ptr<json::Array> to_json() const {
auto results = json::Array::make(); auto results = json::Array::make();
for(auto&& [id, chunk]: chunks) { for(auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
auto address = json::Array::make(); auto address = json::Array::make();
auto time = json::Array::make(); auto time = json::Array::make();
...@@ -213,6 +216,7 @@ struct MemoryFlow { ...@@ -213,6 +216,7 @@ struct MemoryFlow {
return builder; return builder;
}; };
for (auto&& [id, chunk]: chunks) { for (auto&& [id, chunk]: chunks) {
MGB_MARK_USED_VAR(id);
if (chunk.empty()) continue; if (chunk.empty()) continue;
double left = (chunk.time[0]-time_begin)/time_scale; double left = (chunk.time[0]-time_begin)/time_scale;
double right = (chunk.time[1]-time_begin)/time_scale; double right = (chunk.time[1]-time_begin)/time_scale;
......
...@@ -131,6 +131,7 @@ public: ...@@ -131,6 +131,7 @@ public:
MGB_LOCK_GUARD(sm_mutex); MGB_LOCK_GUARD(sm_mutex);
if constexpr (sm_debug) { if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) { for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Running; Status expected = Running;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting)); mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting));
} }
...@@ -149,6 +150,7 @@ public: ...@@ -149,6 +150,7 @@ public:
}); });
if constexpr (sm_debug) { if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) { for (auto&& [tid, profiler]: sm_profilers) {
MGB_MARK_USED_VAR(tid);
Status expected = Collecting; Status expected = Collecting;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running)); mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册