diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index abe5e98c2abe83fa4d82c4fe68daa6fc9ca394e0..485407eed474864aa306372193c141a448791ed2 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -16,6 +16,7 @@ from typing import Dict, List, Union
 
 import numpy as np
 
+from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.ops import BackwardGraph
@@ -44,6 +45,9 @@ class Graph(_imperative_rt.ComputingGraph):
             cache[obj] = wrapper(obj)
         return cache[obj]
 
+    def set_priority_to_id(self, dest_vars):
+        _set_priority_to_id(_unwrap(dest_vars))
+
     def compile(self, *args):
         self._function = super().compile(_unwrap(args))
         return self
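[Note] Graph.set_priority_to_id above simply unwraps the given VarNodes and forwards them to comp_graph_tools.set_priority_to_id, which pins the scheduling priority of every operator reachable from those vars to that operator's insertion id, so the compiled graph executes ops in the order the trace recorded them. A minimal Python sketch of the intended semantics, assuming cgtools.get_oprs_seq and id/priority attributes on operator nodes; illustrative only, not the actual implementation:

    from megengine.utils import comp_graph_tools as cgtools

    def set_priority_to_id_sketch(dest_vars):
        # Visit every operator reachable from dest_vars in topological
        # order and pin its priority to its insertion id, so execution
        # replays ops in the order the trace recorded them.
        for opr in cgtools.get_oprs_seq(dest_vars):
            if opr.priority == 0:  # assumed default; keep explicitly tuned oprs
                opr.priority = opr.id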
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 31f2f4720b4b812bb92d9f51b4c595233a24681f..42e6526ccb1274563b88c1afdd31088848857546 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -350,6 +350,7 @@ class trace:
             lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
         else:
             lazy_eval_graph.options.graph_opt_level = 2
+        lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
         lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
         for r, x in zip(readers, lazy_eval_tensors):
@@ -484,7 +485,8 @@ class trace:
         # graph.options.graph_opt_level = 0
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
-        links = ()
+        in_out_links = ()
+        io_links = ()
         readers = []
 
         if self._capture_as_const:
@@ -499,7 +501,7 @@ class trace:
                 )
                 need_reset_nodes.append(opnode)
                 info.varnode = opnode.outputs[0]
-                links += opnode.outputs[1:]
+                in_out_links += opnode.outputs[1:]
 
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, Const):
@@ -536,7 +538,7 @@ class trace:
                         )
                     else:
                         opnode = info.data_setter = G.InputNode(
-                            *links,
+                            *in_out_links,
                             device=info.device,
                             dtype=info.dtype,
                             shape=info.shape or (1,),
@@ -544,45 +546,48 @@ class trace:
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
-                        info.varnode, *links = opnode.outputs
-                        if require_links and i == 0 and len(links) > 0:
-                            info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
-                            links = (info.varnode,)
+                        info.varnode, *in_out_links = opnode.outputs
+                        if require_links and i == 0 and len(io_links) > 0:
+                            info.varnode = apply(
+                                VirtualDep(str(io_links[0].device)), info.varnode, *io_links
+                            )[0]
+                            io_links = (info.varnode,)
 
                 ivars.append(info.varnode)
             ovars = apply(op, *ivars)
             if require_links and len(ovars) > 0:
-                links = (ovars[0],)
+                io_links = (ovars[0],)
 
             assert len(ovars) == len(ohandles)
             for h, v in zip(ohandles, ovars):
                 info = self._tinfo[h]
                 info.varnode = v
 
                 def add_reader(opnode):
-                    nonlocal links
+                    nonlocal in_out_links
                     need_reset_nodes.append(opnode)
                     readers.append(opnode.outputs[0])
-                    links = opnode.outputs
+                    in_out_links = opnode.outputs
 
                 if info.data_read:
                     # Shape can be obtained from data so doesn't need its own
                     # output node. On the other hand, value is read separately
                     # to leverage eager h2d copy
                     info.shape_read = False
-                    opnode = info.data_reader = G.OutputNode(v, *links)
+                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.value_read:
-                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
+                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.shape_read:
-                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
+                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                     add_reader(opnode)
 
         # FIXME
         if self._graph_opt_level is not None:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers, *links)
+        graph.set_priority_to_id([*readers, *in_out_links, *io_links])
+        graph.compile(*readers, *in_out_links, *io_links)
 
     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
@@ -1107,7 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
 
     if require_links and active_trace._lazy_eval_links:
         assert len(ivars) > 0, "op should has at least one input"
-        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        ivars[0] = apply(
+            VirtualDep(str(active_trace._lazy_eval_links[0].device)),
+            ivars[0],
+            *active_trace._lazy_eval_links,
+        )[0]
         active_trace._lazy_eval_links = (ivars[0],)
 
     ovars = apply(op, *ivars)
diff --git a/imperative/python/megengine/utils/profiler.py b/imperative/python/megengine/utils/profiler.py
index 3850630c4a3e74d0f7c6a60513a2b7d06bd44271..2a362e41754c2610508630510f018669e6df932f 100644
--- a/imperative/python/megengine/utils/profiler.py
+++ b/imperative/python/megengine/utils/profiler.py
@@ -246,7 +246,7 @@ class Profiler:
             value_type = type(value)
             if value_type in cls._type_map:
                 value = cls._type_map[value_type](value)
-            results[attr] = value
+            results[attr] = str(value)
         return results
 
     def dump(self, path: Optional[str] = None):
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index bdd29f4c3cbbdac59dffe15c3e7fe115fbf234c7..ce99ed9a8483de3f75831a64d4a0cd15d294154d 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -10,6 +10,7 @@
  */
 
 #include "./ops.h"
+#include <string>
 
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
@@ -45,7 +46,8 @@ void init_ops(py::module m) {
     });
 
     py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
-            .def(py::init<>());
+            .def(py::init<>())
+            .def(py::init<std::string>());
 
 #include "opdef.py.inl"
 }
diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp
index b9a701dadc03f1183b7dbd8e291c69e9ab623fea..65632713b6acb41f54e50532ff8436aebdcb18ba 100644
--- a/imperative/src/impl/ops/utility.cpp
+++ b/imperative/src/impl/ops/utility.cpp
@@ -10,6 +10,8 @@
  */
 
 #include "megbrain/imperative/ops/utility.h"
+#include <string>
+#include "megbrain/comp_node.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
 #include "../op_trait.h"
@@ -20,9 +22,12 @@ namespace {
 cg::OperatorNodeBase* virtual_dep_apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& graph = inputs[0]->owner_graph();
-
+    auto&& op = def.cast_final_safe<VirtualDep>();
     VarNodeArray inps(inputs.begin(), inputs.end());
     cg::OperatorNodeConfig config;
+    if (op.device.length() > 0) {
+        config.comp_node(CompNode::load(op.device));
+    }
     cg::OperatorNodeBase* opr = graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(
             inps, config));
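[Note] The tracing hunks above split the old single links tuple into in_out_links (serializing the H2D/D2H input and output nodes) and io_links (serializing the ops that require links, i.e. the "I/O nodes" the original comment refers to), and VirtualDep now optionally carries a device string that virtual_dep_apply_on_var_node converts into the op's comp node. A hedged sketch of how the device-pinned dependency is chained, mirroring the tracing.py hunks; apply and both import paths are internal APIs that move between versions, so treat the names as assumptions:

    from megengine.core._imperative_rt.ops import VirtualDep
    from megengine.core.tensor.core import apply  # internal; location varies by version

    def chain_io_dep(var, io_links):
        # Attach a no-copy control dependency from var onto the pending
        # I/O ops. Passing str(device) pins the VirtualDep onto the
        # links' comp node instead of inheriting var's.
        if not io_links:
            return var, io_links
        var = apply(VirtualDep(str(io_links[0].device)), var, *io_links)[0]
        return var, (var,)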
diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h
index 817935218989e2ff646c141a0c8b4cb078adbd1f..4caecfce0fbe17245c323c5ea371c90e994cd2e3 100644
--- a/imperative/src/include/megbrain/imperative/ops/utility.h
+++ b/imperative/src/include/megbrain/imperative/ops/utility.h
@@ -11,6 +11,8 @@
 
 #pragma once
 
+#include <string>
+#include "megbrain/graph/operator_node.h"
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/utils/hash.h"
 
@@ -22,6 +24,9 @@
 class VirtualDep : public OpDefImplBase<VirtualDep> {
 public:
     VirtualDep() = default;
 
+    VirtualDep(std::string dev) : device(dev) {}
+
+    std::string device;
     size_t hash() const override {
         return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
diff --git a/src/opr/impl/utility.cpp b/src/opr/impl/utility.cpp
index bfac02e5f35763201cbd5b24e014dc1fe4c7572c..512e97423a5234712ab30116cb62ebbf2ee75271 100644
--- a/src/opr/impl/utility.cpp
+++ b/src/opr/impl/utility.cpp
@@ -206,15 +206,17 @@ SymbolVar Timestamp::make(SymbolVar node, std::shared_ptr<HostTensorND> dest,
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);
 
 VirtualDep::VirtualDep(const VarNodeArray& inputs,
-                       const OperatorNodeConfig& config)
+                       const OperatorNodeConfig& cfg)
         : Super(inputs[0]->owner_graph(),
-                setup_config_cn(config, inputs[0]->comp_node()), "virtual_dep",
-                inputs) {
+                cfg.has_comp_node_set() ? cfg : setup_config_cn(cfg, inputs[0]->comp_node()),
+                "virtual_dep", inputs) {
     for (auto inp : inputs) {
        add_input({inp});
     }
     mgb_assert(inputs[0]->dtype().valid());
-    add_output(None)->dtype(inputs[0]->dtype());
+    add_output(None)
+            ->dtype(inputs[0]->dtype())
+            .comp_node(config().get_single_comp_node());
 }
 
 cg::OperatorNodeBase::NodeProp* VirtualDep::do_make_node_prop() const {
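[Note] Back on the Python side, the profiler.py hunk stringifies each attribute after the _type_map conversion so Profiler.dump() can always JSON-serialize its results; C++-backed values such as comp nodes or dtypes would otherwise break json.dumps. A self-contained illustration of the failure mode and the fix (FakeCompNode is a made-up stand-in, not a real MegEngine class):

    import json

    class FakeCompNode:  # stand-in for a C++-backed attribute value
        def __str__(self):
            return "gpu0:0"

    attrs = {"comp_node": FakeCompNode(), "kernel_time": 0.25}
    # json.dumps(attrs) raises TypeError: FakeCompNode is not JSON serializable
    safe = {k: str(v) for k, v in attrs.items()}
    print(json.dumps(safe))  # {"comp_node": "gpu0:0", "kernel_time": "0.25"}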