提交 9ffc2c0a 编写于 作者: M Megvii Engine Team

fix(mge): fix host performance loss caused by dtr

GitOrigin-RevId: ee8b729e8087cb42e904fb33f59043b73b5d2262
上级 69673f14
...@@ -49,7 +49,6 @@ struct ApplyOp { ...@@ -49,7 +49,6 @@ struct ApplyOp {
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs; SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs; SmallVector<TensorInfo*> outputs;
SmallVector<LogicalTensorDesc> outputs_descs;
bool validated = false; bool validated = false;
template <typename TFunctor> template <typename TFunctor>
......
...@@ -355,7 +355,7 @@ void ChannelImpl::dispatch_kernel( ...@@ -355,7 +355,7 @@ void ChannelImpl::dispatch_kernel(
for (int i = 0; i < output_descs.size(); ++i) { for (int i = 0; i < output_descs.size(); ++i) {
auto&& desc = output_descs[i]; auto&& desc = output_descs[i];
auto info = alloc(); auto info = alloc();
init(info, desc); init(info, std::move(desc));
// make sure desc's value is consistent with h_value // make sure desc's value is consistent with h_value
if (!info->desc.value.empty()) { if (!info->desc.value.empty()) {
info->h_value = HostTensorND::make_proxy(desc.value) info->h_value = HostTensorND::make_proxy(desc.value)
...@@ -364,9 +364,9 @@ void ChannelImpl::dispatch_kernel( ...@@ -364,9 +364,9 @@ void ChannelImpl::dispatch_kernel(
output_infos.push_back(info); output_infos.push_back(info);
outputs->push_back(reinterpret_cast<Handle>(info)); outputs->push_back(reinterpret_cast<Handle>(info));
} }
ApplyOp cmd{Profiler::next_id(), std::move(op), ApplyOp cmd{
std::move(input_infos), std::move(output_infos), Profiler::next_id(), std::move(op), std::move(input_infos),
std::move(output_descs), validated}; std::move(output_infos), validated};
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
auto op_info_getter = [op = cmd.op] { auto op_info_getter = [op = cmd.op] {
std::unordered_map<std::string, std::string> op_info; std::unordered_map<std::string, std::string> op_info;
...@@ -594,7 +594,7 @@ TensorInfo* ChannelImpl::alloc() { ...@@ -594,7 +594,7 @@ TensorInfo* ChannelImpl::alloc() {
return info; return info;
} }
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) { void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
m_valid_handle.insert(reinterpret_cast<Handle>(info)); m_valid_handle.insert(reinterpret_cast<Handle>(info));
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name); MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated; info->status = TensorInfo::Allocated;
...@@ -724,9 +724,8 @@ void ChannelImpl::regenerate(TensorInfo* dest) { ...@@ -724,9 +724,8 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == EvictType::DROP) { if (dest->evict_type == EvictType::DROP) {
auto&& path = dest->producer; auto&& path = dest->producer;
m_apply_stack.push( m_apply_stack.push(
{ApplyOp{path->id, path->op, path->inputs, path->outputs, {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
path->outputs_descs}, "dtr"});
0, dest, "dtr"});
if (!m_applying) if (!m_applying)
flush_apply_stack(); flush_apply_stack();
} }
...@@ -819,13 +818,18 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { ...@@ -819,13 +818,18 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
} }
// Apply op // Apply op
SmallVector<LogicalTensorDesc> output_descs; SmallVector<LogicalTensorDesc> output_descs;
for (auto i : cmd.outputs_descs) { bool validated = cmd.validated;
output_descs.push_back(i); if (!state.options.enable_dtr_auto_drop) {
for (auto i : cmd.outputs) {
output_descs.push_back(i->desc);
}
} else {
validated = false;
} }
// Here std::move is REQUIRED for removing duplicated references. // Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor( auto outputs = apply_on_physical_tensor(
apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs, apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
cmd.validated); validated);
// After execute // After execute
for (auto&& [device, kernel_id] : kernels) { for (auto&& [device, kernel_id] : kernels) {
MGB_RECORD_EVENT_IF( MGB_RECORD_EVENT_IF(
...@@ -1154,7 +1158,7 @@ void ChannelImpl::process_one_task(Command& icmd) { ...@@ -1154,7 +1158,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
TensorInfo::ComputePath::make( TensorInfo::ComputePath::make(
cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs); cmd.id, cmd.op, cmd.inputs, cmd.outputs);
size_t detach_cnt = 0; size_t detach_cnt = 0;
if (!strcmp(get_name(*cmd.op), "BatchNorm") && if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
cmd.outputs.size() == 6) { cmd.outputs.size() == 6) {
......
...@@ -77,7 +77,7 @@ private: ...@@ -77,7 +77,7 @@ private:
struct State; struct State;
TensorInfo* alloc(); TensorInfo* alloc();
void init(TensorInfo*, LogicalTensorDesc desc); void init(TensorInfo*, LogicalTensorDesc&& desc);
void free(TensorInfo*); void free(TensorInfo*);
void real_free(TensorInfo*); void real_free(TensorInfo*);
void recursive_free(TensorInfo*); void recursive_free(TensorInfo*);
......
...@@ -99,14 +99,12 @@ struct TensorInfo { ...@@ -99,14 +99,12 @@ struct TensorInfo {
static ComputePath* make( static ComputePath* make(
uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs,
SmallVector<TensorInfo*> outputs, SmallVector<TensorInfo*> outputs) {
SmallVector<LogicalTensorDesc> outputs_descs) {
auto* path = new TensorInfo::ComputePath(); auto* path = new TensorInfo::ComputePath();
path->id = id; path->id = id;
path->op = op; path->op = op;
path->inputs = inputs; path->inputs = inputs;
path->outputs = outputs; path->outputs = outputs;
path->outputs_descs = outputs_descs;
// dedup // dedup
SmallVector<TensorInfo*> unique_inputs = inputs; SmallVector<TensorInfo*> unique_inputs = inputs;
std::sort(unique_inputs.begin(), unique_inputs.end()); std::sort(unique_inputs.begin(), unique_inputs.end());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册