提交 88b402ef 编写于 作者: M Megvii Engine Team

feat(mge/trace): tracing uses id to set priority

GitOrigin-RevId: 6e1f1ece0e0d9a557c6e561f7e30cc6a33fe5152
上级 d2910f7e
......@@ -16,6 +16,7 @@ from typing import Dict, List, Union
import numpy as np
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.ops import BackwardGraph
......@@ -44,6 +45,9 @@ class Graph(_imperative_rt.ComputingGraph):
cache[obj] = wrapper(obj)
return cache[obj]
def set_priority_to_id(self, dest_vars):
    """Set each operator's execution priority from its id.

    ``dest_vars`` are wrapped output vars; they are unwrapped to raw
    graph VarNodes before being handed to the comp-graph-tools helper.
    """
    raw_vars = _unwrap(dest_vars)
    _set_priority_to_id(raw_vars)
def compile(self, *args):
    """Compile the computing graph for the given output vars.

    Unwraps the wrapped vars, delegates to the underlying
    ``ComputingGraph.compile``, caches the resulting callable on
    ``self._function`` and returns ``self`` so calls can be chained.
    """
    raw_args = _unwrap(args)
    self._function = super().compile(raw_args)
    return self
......
......@@ -350,6 +350,7 @@ class trace:
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
else:
lazy_eval_graph.options.graph_opt_level = 2
lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
......@@ -484,7 +485,8 @@ class trace:
# graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
links = ()
in_out_links = ()
io_links = ()
readers = []
if self._capture_as_const:
......@@ -499,7 +501,7 @@ class trace:
)
need_reset_nodes.append(opnode)
info.varnode = opnode.outputs[0]
links += opnode.outputs[1:]
in_out_links += opnode.outputs[1:]
for op, ihandles, ohandles in self._seq:
if isinstance(op, Const):
......@@ -536,7 +538,7 @@ class trace:
)
else:
opnode = info.data_setter = G.InputNode(
*links,
*in_out_links,
device=info.device,
dtype=info.dtype,
shape=info.shape or (1,),
......@@ -544,45 +546,48 @@ class trace:
use_static_shape=_input_node_use_static_shape(),
)
need_reset_nodes.append(opnode)
info.varnode, *links = opnode.outputs
if require_links and i == 0 and len(links) > 0:
info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
links = (info.varnode,)
info.varnode, *in_out_links = opnode.outputs
if require_links and i == 0 and len(io_links) > 0:
info.varnode = apply(
VirtualDep(str(io_links[0].device)), info.varnode, *io_links
)[0]
io_links = (info.varnode,)
ivars.append(info.varnode)
ovars = apply(op, *ivars)
if require_links and len(ovars) > 0:
links = (ovars[0],)
io_links = (ovars[0],)
assert len(ovars) == len(ohandles)
for h, v in zip(ohandles, ovars):
info = self._tinfo[h]
info.varnode = v
def add_reader(opnode):
nonlocal links
nonlocal in_out_links
need_reset_nodes.append(opnode)
readers.append(opnode.outputs[0])
links = opnode.outputs
in_out_links = opnode.outputs
if info.data_read:
# Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *links)
opnode = info.data_reader = G.OutputNode(v, *in_out_links)
add_reader(opnode)
if info.value_read:
opnode = info.value_reader = G.ValueOutputNode(v, *links)
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
add_reader(opnode)
if info.shape_read:
opnode = info.shape_reader = G.AttrOutputNode(v, *links)
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode)
# FIXME
if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level
else:
graph.options.graph_opt_level = 2
graph.compile(*readers, *links)
graph.set_priority_to_id([*readers, *in_out_links, *io_links])
graph.compile(*readers, *in_out_links, *io_links)
def _reset_exec_env(self):
for opnode in self._need_reset_nodes:
......@@ -1107,7 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
if require_links and active_trace._lazy_eval_links:
assert len(ivars) > 0, "op should has at least one input"
ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
ivars[0] = apply(
VirtualDep(str(active_trace._lazy_eval_links[0].device)),
ivars[0],
*active_trace._lazy_eval_links,
)[0]
active_trace._lazy_eval_links = (ivars[0],)
ovars = apply(op, *ivars)
......
......@@ -246,7 +246,7 @@ class Profiler:
value_type = type(value)
if value_type in cls._type_map:
value = cls._type_map[value_type](value)
results[attr] = value
results[attr] = str(value)
return results
def dump(self, path: Optional[str] = None):
......
......@@ -10,6 +10,7 @@
*/
#include "./ops.h"
#include <string>
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
......@@ -45,7 +46,8 @@ void init_ops(py::module m) {
});
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
.def(py::init<>());
.def(py::init<>())
.def(py::init<std::string>());
#include "opdef.py.inl"
}
......@@ -10,6 +10,8 @@
*/
#include "megbrain/imperative/ops/utility.h"
#include <string>
#include "megbrain/comp_node.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
......@@ -20,9 +22,12 @@ namespace {
cg::OperatorNodeBase* virtual_dep_apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& graph = inputs[0]->owner_graph();
auto&& op = def.cast_final_safe<VirtualDep>();
VarNodeArray inps(inputs.begin(), inputs.end());
cg::OperatorNodeConfig config;
if (op.device.length() > 0) {
config.comp_node(CompNode::load(op.device));
}
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(
inps, config));
......
......@@ -11,6 +11,8 @@
#pragma once
#include <string>
#include "megbrain/graph/operator_node.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
......@@ -22,6 +24,9 @@ class VirtualDep : public OpDefImplBase<VirtualDep> {
public:
VirtualDep() = default;
VirtualDep(std::string dev) : device(dev) {}
std::string device;
size_t hash() const override {
return reinterpret_cast<size_t>(dyn_typeinfo());
......
......@@ -206,15 +206,17 @@ SymbolVar Timestamp::make(SymbolVar node, std::shared_ptr<HostTensorND> dest,
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);
VirtualDep::VirtualDep(const VarNodeArray& inputs,
const OperatorNodeConfig& config)
const OperatorNodeConfig& cfg)
: Super(inputs[0]->owner_graph(),
setup_config_cn(config, inputs[0]->comp_node()), "virtual_dep",
inputs) {
cfg.has_comp_node_set() ? cfg : setup_config_cn(cfg, inputs[0]->comp_node()),
"virtual_dep", inputs) {
for (auto inp : inputs) {
add_input({inp});
}
mgb_assert(inputs[0]->dtype().valid());
add_output(None)->dtype(inputs[0]->dtype());
add_output(None)
->dtype(inputs[0]->dtype())
.comp_node(config().get_single_comp_node());
}
cg::OperatorNodeBase::NodeProp* VirtualDep::do_make_node_prop() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册