Commit 989fdde2 authored by Megvii Engine Team

refactor(subgraph): use graph queue to cache compiled op graphs

GitOrigin-RevId: cba8574c73679bb7928e3401446980812b33461d
Parent a7a3bf2d
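The change replaces a single cached `ComputingGraphHolder` per key with a queue of holders per key, so several executions of the same compiled op can be in flight at once. Below is a minimal sketch of the pattern, with simplified stand-in types (`Key`, `Graph`, and `is_ready` are illustrative, not MegEngine APIs):

```cpp
#include <map>
#include <memory>
#include <queue>

struct Graph {
    bool is_ready() const { return true; }  // stands in for polling recorded events
};

using Key = int;  // the real key combines the op definition and input descriptors

Graph& acquire(Key key) {
    // one queue of compiled graphs per key, so repeated executions of the
    // same op rotate through several graph instances
    static thread_local std::map<Key, std::queue<std::unique_ptr<Graph>>> cache;
    auto& q = cache[key];
    std::unique_ptr<Graph> g;
    if (!q.empty() && q.front()->is_ready()) {
        g = std::move(q.front());  // reuse a graph whose last run has finished
        q.pop();
    } else {
        g = std::make_unique<Graph>();  // otherwise compile a fresh one
    }
    q.push(std::move(g));  // the queue keeps ownership; caller uses q.back()
    return *q.back();
}
```

The diff below implements this with `OpMethResultCache` and bounds the growth of each queue with device events.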
@@ -9,6 +9,8 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include <queue>
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
@@ -277,23 +279,50 @@ struct ComputingGraphHolder {
     SmallVector<std::shared_ptr<DeviceTensorND>> inputs;
     SmallVector<std::shared_ptr<DeviceTensorND>> outputs;
     std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
+    SmallVector<std::unique_ptr<CompNode::Event>> events;
 };
-thread_local OpMethResultCache<ComputingGraphHolder> cg_cache;
 ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
-    OpMethArgs<> key = {compiled_op, descs};
-    auto& cg_holder = cg_cache[key];
-    if (!cg_holder.graph) {
+    using ComputingGraphHolderCache = OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder>>>;
+    thread_local ComputingGraphHolderCache cache;
+    thread_local size_t nr_cg_holders = 0;
+    ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
+    auto& cg_holder_queue = cache[cache_key];
+    std::unique_ptr<ComputingGraphHolder> holder;
+    if (!cg_holder_queue.empty()) {
+        // pick one
+        std::swap(cg_holder_queue.front(), holder);
+        // check all events finished
+        for (auto&& event: holder->events) {
+            if (!event->finished()) {
+                bool queue_limited = event->comp_node().contain_flag(CompNode::Flag::QUEUE_LIMITED);
+                bool many_graph = cg_holder_queue.size() > 10;
+                if (queue_limited || !many_graph) {
+                    std::swap(cg_holder_queue.front(), holder);
+                    break;
+                } else {
+                    // graph limit
+                    mgb_log_debug("computing graph limit for compiled op exceeded, waiting for prev graph");
+                    event->host_wait();
+                }
+            }
+        }
+        if (holder) {
+            cg_holder_queue.pop();
+        }
+    }
+    if (!holder) {
+        // create new computing graph
+        holder = std::make_unique<ComputingGraphHolder>();
+        auto& cg_holder = *holder;
         cg_holder.allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
         cg_holder.graph = ComputingGraph::make();
         cg_holder.graph->options().force_dynamic_alloc = true;
         cg_holder.graph->options().async_exec_level = 0;
         cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level;
         cg_holder.graph->options().enable_var_mem_defragment = false;
         cg_holder.graph->options().comp_seq_sync_device = false;
         // set allocator for DTR support
         cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
         // cg_holder.graph->options().graph_opt.jit = 2;
         VarNodeArray input_vars;
         for (auto&& desc: descs) {
             auto input_device_nd = std::make_shared<DeviceTensorND>();
@@ -321,8 +350,19 @@ ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs)
             cg_holder.outputs.push_back(output_ptr);
         }
         cg_holder.executable = cg_holder.graph->compile(output_spec);
+        CompNode::UnorderedSet comp_nodes;
+        for (auto&& output_var: output_vars) {
+            comp_nodes.insert(output_var->comp_node());
+        }
+        for (auto&& comp_node: comp_nodes) {
+            cg_holder.events.push_back(comp_node.create_event());
+            cg_holder.events.back()->record();
+        }
+        nr_cg_holders++;
+        mgb_log_debug("add new computing graph for compiled op, now %zu graphs", nr_cg_holders);
     }
-    return cg_holder;
+    cg_holder_queue.push(std::move(holder));
+    return *cg_holder_queue.back();
 }
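A note on the reuse policy above: a holder popped from the queue is reused only if all of its events have fired. If its last run is still pending, a fresh graph is compiled instead, unless more than 10 graphs already exist for this key on a comp node without the `QUEUE_LIMITED` flag, in which case the host blocks on the event and then reuses the holder. A hedged sketch of that branch, with a toy `Event` type standing in for `CompNode::Event`:

```cpp
#include <cstdio>

// Toy stand-in for CompNode::Event: the method names mirror the real ones,
// the bodies are placeholders.
struct Event {
    bool done = false;
    bool finished() const { return done; }
    void host_wait() { /* block the host until the device signals */ }
};

// Mirrors the branch in the diff; many_graph corresponds to queue.size() > 10.
bool should_reuse(Event& last_run, bool queue_limited, bool many_graph) {
    if (last_run.finished()) {
        return true;   // last execution is done: safe to reuse immediately
    }
    if (queue_limited || !many_graph) {
        return false;  // cheap path: let the caller compile one more graph
    }
    last_run.host_wait();  // graph limit exceeded: wait, then reuse
    return true;
}

int main() {
    Event e;
    e.done = true;
    std::printf("%d\n", should_reuse(e, false, true));  // prints 1
}
```

The explicit cap keeps each per-key queue at roughly 10 compiled graphs; the `QUEUE_LIMITED` exception presumably relies on the device's own bounded dispatch queue to throttle pending runs.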
 auto apply_on_physical_tensor(
@@ -335,13 +375,17 @@ auto apply_on_physical_tensor(
     size_t nr_inputs = inputs.size();
     auto shared_def = const_cast<OpDef&>(def).shared_from_this();
     auto& cg_holder = get_computing_graph(shared_def, input_descs);
+    // wait for last execution
+    cg_holder.executable->wait();
     for (size_t i = 0; i < nr_inputs; ++i) {
         auto input_dev_tensor = inputs[i]->dev_tensor();
         cg_holder.inputs[i]->reset(input_dev_tensor.storage(), input_dev_tensor.layout());
     }
     cg_holder.allocator->current_op = shared_def;
     cg_holder.executable->execute();
-    cg_holder.executable->wait();
+    for (auto&& event: cg_holder.events) {
+        event->record();
+    }
     SmallVector<TensorPtr> outputs;
     for (auto input_nd: cg_holder.inputs) {
         *input_nd = {};
......
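Note what moved in this last hunk: the blocking `wait()` right after `execute()` is gone, events are recorded after the launch instead, and the "wait for last execution" now happens when a holder is acquired again. A sketch of that deferred-wait idea, with `std::future` standing in for the recorded events (illustrative driver code, not the MegEngine mechanism):

```cpp
#include <future>

struct Holder {
    std::future<void> last_run;  // plays the role of the recorded events

    void execute() {
        // deferred wait: block on the previous run only when this holder is
        // about to be reused, not right after launching it
        if (last_run.valid()) {
            last_run.wait();
        }
        // launch asynchronously; assigning the future mirrors event->record()
        last_run = std::async(std::launch::async, [] { /* run kernels */ });
    }
};

int main() {
    Holder h;
    h.execute();        // first use: nothing to wait for
    h.execute();        // blocks only until run #1 finishes, overlapping host work
    h.last_run.wait();  // drain before exit
}
```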
@@ -18,6 +18,7 @@
 namespace mgb {
 namespace imperative {
+// NOTE: only input dtype and comp_node used for hashing, shapes are ignored
 template <typename... TExtraArgs>
 struct OpMethArgs {
     std::shared_ptr<OpDef> op;
......
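The NOTE added here is worth unpacking: the cache key hashes only each input's dtype and comp_node and deliberately ignores shapes, presumably because the cached graphs are built with `force_dynamic_alloc` and have their input layouts reset on every execution, so one compiled graph can serve any shape. A sketch of such a shape-agnostic key (hypothetical names, not the actual `OpMethArgs` implementation):

```cpp
#include <cstddef>
#include <functional>
#include <vector>

struct InputKey {
    int dtype_id;   // stands in for DType
    int device_id;  // stands in for CompNode
};

size_t hash_key(const void* op, const std::vector<InputKey>& inputs) {
    size_t seed = std::hash<const void*>{}(op);
    auto combine = [&seed](size_t h) {
        seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    };
    for (auto&& in : inputs) {
        combine(std::hash<int>{}(in.dtype_id));
        combine(std::hash<int>{}(in.device_id));
        // shape is deliberately not hashed: the graph is shape-polymorphic
    }
    return seed;
}
```

Keying on shape would create one compiled graph per shape and defeat the queue cache under dynamic shapes.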