Commit 88b402ef authored by Megvii Engine Team

feat(mge/trace): tracing use id to set priority

GitOrigin-RevId: 6e1f1ece0e0d9a557c6e561f7e30cc6a33fe5152
Parent d2910f7e
@@ -16,6 +16,7 @@ from typing import Dict, List, Union

 import numpy as np

+from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.ops import BackwardGraph
@@ -44,6 +45,9 @@ class Graph(_imperative_rt.ComputingGraph):
         cache[obj] = wrapper(obj)
         return cache[obj]

+    def set_priority_to_id(self, dest_vars):
+        _set_priority_to_id(_unwrap(dest_vars))
+
     def compile(self, *args):
         self._function = super().compile(_unwrap(args))
         return self
...
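Note: the new `Graph.set_priority_to_id` simply unwraps the destination vars and forwards to `megengine.utils.comp_graph_tools.set_priority_to_id`, which gives every operator in the subgraph reaching those vars a priority equal to its node id, so the compiled execution order follows trace (insertion) order. A minimal usage sketch, assuming a `Graph` instance `graph` and its output VarNodes `dest_vars` already exist (their construction is elided here):

    graph.set_priority_to_id(dest_vars)  # priority of each opr := its node id
    graph.compile(*dest_vars)            # compile() returns the graph itself
    graph()                              # execute with a deterministic schedule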
@@ -350,6 +350,7 @@ class trace:
             lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
         else:
             lazy_eval_graph.options.graph_opt_level = 2
+        lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
         lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
         for r, x in zip(readers, lazy_eval_tensors):
@@ -484,7 +485,8 @@ class trace:
         # graph.options.graph_opt_level = 0
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
-        links = ()
+        in_out_links = ()
+        io_links = ()
         readers = []

         if self._capture_as_const:
@@ -499,7 +501,7 @@ class trace:
                 )
                 need_reset_nodes.append(opnode)
                 info.varnode = opnode.outputs[0]
-                links += opnode.outputs[1:]
+                in_out_links += opnode.outputs[1:]

         for op, ihandles, ohandles in self._seq:
             if isinstance(op, Const):
@@ -536,7 +538,7 @@ class trace:
                         )
                     else:
                         opnode = info.data_setter = G.InputNode(
-                            *links,
+                            *in_out_links,
                             device=info.device,
                             dtype=info.dtype,
                             shape=info.shape or (1,),
@@ -544,45 +546,48 @@ class trace:
                             use_static_shape=_input_node_use_static_shape(),
                         )
                         need_reset_nodes.append(opnode)
-                    info.varnode, *links = opnode.outputs
-                if require_links and i == 0 and len(links) > 0:
-                    info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
-                    links = (info.varnode,)
+                    info.varnode, *in_out_links = opnode.outputs
+                if require_links and i == 0 and len(io_links) > 0:
+                    info.varnode = apply(
+                        VirtualDep(str(io_links[0].device)), info.varnode, *io_links
+                    )[0]
+                    io_links = (info.varnode,)
                 ivars.append(info.varnode)
             ovars = apply(op, *ivars)
             if require_links and len(ovars) > 0:
-                links = (ovars[0],)
+                io_links = (ovars[0],)
             assert len(ovars) == len(ohandles)
             for h, v in zip(ohandles, ovars):
                 info = self._tinfo[h]
                 info.varnode = v

                 def add_reader(opnode):
-                    nonlocal links
+                    nonlocal in_out_links
                     need_reset_nodes.append(opnode)
                     readers.append(opnode.outputs[0])
-                    links = opnode.outputs
+                    in_out_links = opnode.outputs

                 if info.data_read:
                     # Shape can be obtained from data so doesn't need its own
                     # output node. On the other hand, value is read separately
                     # to leverage eager h2d copy
                     info.shape_read = False
-                    opnode = info.data_reader = G.OutputNode(v, *links)
+                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.value_read:
-                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
+                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.shape_read:
-                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
+                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                     add_reader(opnode)
         # FIXME
         if self._graph_opt_level is not None:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers, *links)
+        graph.set_priority_to_id([*readers, *in_out_links, *io_links])
+        graph.compile(*readers, *in_out_links, *io_links)

     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
@@ -1107,7 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
     if require_links and active_trace._lazy_eval_links:
         assert len(ivars) > 0, "op should has at least one input"
-        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        ivars[0] = apply(
+            VirtualDep(str(active_trace._lazy_eval_links[0].device)),
+            ivars[0],
+            *active_trace._lazy_eval_links,
+        )[0]
         active_trace._lazy_eval_links = (ivars[0],)

     ovars = apply(op, *ivars)
...
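Note: `VirtualDep` used to infer its comp node from its first input; the trace now passes the device of the first pending link explicitly, so the dependency op stays on the link's device even when `ivars[0]` lives elsewhere. A hedged sketch of the linking pattern used above, with illustrative `var` and `io_links` values:

    dep_device = str(io_links[0].device)                   # e.g. "gpu0"
    var = apply(VirtualDep(dep_device), var, *io_links)[0]
    io_links = (var,)                                      # thread the new link through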
@@ -246,7 +246,7 @@ class Profiler:
             value_type = type(value)
             if value_type in cls._type_map:
                 value = cls._type_map[value_type](value)
-            results[attr] = value
+            results[attr] = str(value)
         return results

     def dump(self, path: Optional[str] = None):
...
@@ -10,6 +10,7 @@
  */
 #include "./ops.h"
+#include <string>
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
@@ -45,7 +46,8 @@ void init_ops(py::module m) {
     });

     py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
-            .def(py::init<>());
+            .def(py::init<>())
+            .def(py::init<std::string>());

 #include "opdef.py.inl"
 }
...
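Note: with the added `py::init<std::string>()` overload, both constructors become reachable from Python. A sketch, assuming the binding lands in the usual `_imperative_rt.ops` module path:

    from megengine.core._imperative_rt.ops import VirtualDep
    VirtualDep()        # comp node inferred from the first input, as before
    VirtualDep("xpu0")  # comp node loaded from the given device string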
@@ -10,6 +10,8 @@
  */
 #include "megbrain/imperative/ops/utility.h"
+#include <string>
+#include "megbrain/comp_node.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
 #include "../op_trait.h"
@@ -20,9 +22,12 @@ namespace {
 cg::OperatorNodeBase* virtual_dep_apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& graph = inputs[0]->owner_graph();
+    auto&& op = def.cast_final_safe<VirtualDep>();
     VarNodeArray inps(inputs.begin(), inputs.end());
     cg::OperatorNodeConfig config;
+    if (op.device.length() > 0) {
+        config.comp_node(CompNode::load(op.device));
+    }
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(
                     inps, config));
...
@@ -11,6 +11,8 @@
 #pragma once
+#include <string>
+#include "megbrain/graph/operator_node.h"
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/utils/hash.h"
@@ -22,6 +24,9 @@ class VirtualDep : public OpDefImplBase<VirtualDep> {
 public:
     VirtualDep() = default;
+    VirtualDep(std::string dev) : device(dev) {}
+
+    std::string device;

     size_t hash() const override {
         return reinterpret_cast<size_t>(dyn_typeinfo());
...
@@ -206,15 +206,17 @@ SymbolVar Timestamp::make(SymbolVar node, std::shared_ptr<HostTensorND> dest,
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);

 VirtualDep::VirtualDep(const VarNodeArray& inputs,
-                       const OperatorNodeConfig& config)
+                       const OperatorNodeConfig& cfg)
         : Super(inputs[0]->owner_graph(),
-                setup_config_cn(config, inputs[0]->comp_node()), "virtual_dep",
-                inputs) {
+                cfg.has_comp_node_set() ? cfg : setup_config_cn(cfg, inputs[0]->comp_node()),
+                "virtual_dep", inputs) {
     for (auto inp : inputs) {
         add_input({inp});
     }
     mgb_assert(inputs[0]->dtype().valid());
-    add_output(None)->dtype(inputs[0]->dtype());
+    add_output(None)
+            ->dtype(inputs[0]->dtype())
+            .comp_node(config().get_single_comp_node());
 }

 cg::OperatorNodeBase::NodeProp* VirtualDep::do_make_node_prop() const {
...
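Note: the `VirtualDep` operator constructor now keeps a config whose comp node was set explicitly, falling back to the first input's comp node otherwise, and copies that comp node onto the output var. A hedged Python-level check of the intended behavior; the device string and var names are illustrative:

    out = apply(VirtualDep("cpu0"), var, *links)[0]
    assert "cpu0" in str(out.device)  # the pinned device survives into the output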