diff --git a/imperative/src/impl/interpreter/commands.h b/imperative/src/impl/interpreter/commands.h index d4f5e48b1387934e552b6c769aed94b73c15108d..54763b16c3a7820951c95eb4c332fa51ffc9b737 100644 --- a/imperative/src/impl/interpreter/commands.h +++ b/imperative/src/impl/interpreter/commands.h @@ -49,7 +49,6 @@ struct ApplyOp { std::shared_ptr<OpDef> op; SmallVector<TensorInfo*> inputs; SmallVector<TensorInfo*> outputs; - SmallVector<LogicalTensorDesc> outputs_descs; bool validated = false; template diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 99f69a58c77f26de4ef0987c565a1b73f1a07f9d..5bf55b63ee75cfb829505043d1a94779ec7cb40f 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -355,7 +355,7 @@ void ChannelImpl::dispatch_kernel( for (int i = 0; i < output_descs.size(); ++i) { auto&& desc = output_descs[i]; auto info = alloc(); - init(info, desc); + init(info, std::move(desc)); // make sure desc's value is consistent with h_value if (!info->desc.value.empty()) { info->h_value = HostTensorND::make_proxy(desc.value) @@ -364,9 +364,9 @@ output_infos.push_back(info); outputs->push_back(reinterpret_cast<Handle>(info)); } - ApplyOp cmd{Profiler::next_id(), std::move(op), - std::move(input_infos), std::move(output_infos), - std::move(output_descs), validated}; + ApplyOp cmd{ + Profiler::next_id(), std::move(op), std::move(input_infos), + std::move(output_infos), validated}; if (Profiler::is_profiling()) { auto op_info_getter = [op = cmd.op] { std::unordered_map<std::string, std::string> op_info; @@ -594,7 +594,7 @@ TensorInfo* ChannelImpl::alloc() { return info; } -void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) { +void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) { m_valid_handle.insert(reinterpret_cast<Handle>(info)); MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name); info->status = TensorInfo::Allocated; @@ -724,9 +724,8 @@ void 
ChannelImpl::regenerate(TensorInfo* dest) { if (dest->evict_type == EvictType::DROP) { auto&& path = dest->producer; m_apply_stack.push( - {ApplyOp{path->id, path->op, path->inputs, path->outputs, - path->outputs_descs}, - 0, dest, "dtr"}); + {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest, + "dtr"}); if (!m_applying) flush_apply_stack(); } @@ -819,13 +818,18 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { } // Apply op SmallVector<LogicalTensorDesc> output_descs; - for (auto i : cmd.outputs_descs) { - output_descs.push_back(i); + bool validated = cmd.validated; + if (!state.options.enable_dtr_auto_drop) { + for (auto i : cmd.outputs) { + output_descs.push_back(i->desc); + } + } else { + validated = false; } // Here std::move is REQUIRED for removing duplicated references. auto outputs = apply_on_physical_tensor( apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs, - cmd.validated); + validated); // After execute for (auto&& [device, kernel_id] : kernels) { MGB_RECORD_EVENT_IF( @@ -1154,7 +1158,7 @@ void ChannelImpl::process_one_task(Command& icmd) { if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) { TensorInfo::ComputePath::make( - cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs); + cmd.id, cmd.op, cmd.inputs, cmd.outputs); size_t detach_cnt = 0; if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 6) { diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 9dc3914448533c207b1ecf3106ee035129361eeb..970c2b2cb186f0b401acf2d6c99410f9c334e5ac 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -77,7 +77,7 @@ private: struct State; TensorInfo* alloc(); - void init(TensorInfo*, LogicalTensorDesc desc); + void init(TensorInfo*, LogicalTensorDesc&& desc); void free(TensorInfo*); void real_free(TensorInfo*); void recursive_free(TensorInfo*); diff 
--git a/imperative/src/impl/interpreter/tensor_info.h b/imperative/src/impl/interpreter/tensor_info.h index 0485689c78fc17e27d68703435e0a81fd8944d36..c0c8223bd0ba3355d461484095f336c2573c7ec3 100644 --- a/imperative/src/impl/interpreter/tensor_info.h +++ b/imperative/src/impl/interpreter/tensor_info.h @@ -99,14 +99,12 @@ struct TensorInfo { static ComputePath* make( uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, - SmallVector<TensorInfo*> outputs, - SmallVector<LogicalTensorDesc> outputs_descs) { + SmallVector<TensorInfo*> outputs) { auto* path = new TensorInfo::ComputePath(); path->id = id; path->op = op; path->inputs = inputs; path->outputs = outputs; - path->outputs_descs = outputs_descs; // dedup SmallVector<TensorInfo*> unique_inputs = inputs; std::sort(unique_inputs.begin(), unique_inputs.end());