diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 485407eed474864aa306372193c141a448791ed2..f7781495fb35ef9e165317ec079966c942f26b4f 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -569,3 +569,9 @@ class AttrOutputNode(OpNode): def reset(self): self._rendezvous.reset() + + +class VirtualDepNode(OpNode): + def __init__(self, vars, device=""): + out = _imperative_rt.virtual_dep(_unwrap(vars), device) + super().__init__(out) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 42e6526ccb1274563b88c1afdd31088848857546..5ef0fc501a0dd2e4279cc12f7e6ead94cea7eb55 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -25,7 +25,6 @@ from ..core._imperative_rt.ops import ( RemoteRecv, RemoteSend, UniformRNG, - VirtualDep, ) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device @@ -548,9 +547,10 @@ class trace: need_reset_nodes.append(opnode) info.varnode, *in_out_links = opnode.outputs if require_links and i == 0 and len(io_links) > 0: - info.varnode = apply( - VirtualDep(str(io_links[0].device)), info.varnode, *io_links - )[0] + opnode = G.VirtualDepNode( + [info.varnode, *io_links], str(io_links[0].device) + ) + info.varnode = opnode.outputs[0] io_links = (info.varnode,) ivars.append(info.varnode) @@ -1112,11 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): if require_links and active_trace._lazy_eval_links: assert len(ivars) > 0, "op should has at least one input" - ivars[0] = apply( - VirtualDep(str(active_trace._lazy_eval_links[0].device)), - ivars[0], - *active_trace._lazy_eval_links, - )[0] + opnode = G.VirtualDepNode( + [ivars[0], *active_trace._lazy_eval_links], + str(active_trace._lazy_eval_links[0].device), + ) + ivars[0] = opnode.outputs[0] active_trace._lazy_eval_links = (ivars[0],) ovars = apply(op, *ivars) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 73f58373bb6ee307b3deb158ddf103754a1791da..170a90697b1b5a1301b3f22e3a5a7dfb482c80ae 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -15,6 +15,7 @@ #include "megbrain/serialization/serializer.h" #include "megbrain/imperative/opr_utility.h" #include "megbrain/opr/io.h" +#include "megbrain/opr/utility.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/imperative.h" #include "./helper.h" @@ -562,4 +563,16 @@ void init_graph_rt(py::module m) { }; return output_callback(std::move(f), std::move(inputs), p, true); }); + + m.def("virtual_dep", [](std::vector inputs, std::string device) { + auto&& graph = inputs[0]->owner_graph(); + VarNodeArray inps(inputs.begin(), inputs.end()); + cg::OperatorNodeConfig config; + if (device.length() > 0) { + config.comp_node(CompNode::load(device)); + } + cg::OperatorNodeBase* opr = graph->insert_opr( + std::make_unique(inps, config)); + return opr; + }); } diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index ce99ed9a8483de3f75831a64d4a0cd15d294154d..a12b953a0b0273e054edbb4489196936cddbb7d2 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -10,12 +10,10 @@ */ #include "./ops.h" -#include #include "megbrain/imperative.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" -#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/autogen.h" namespace py = pybind11; @@ -45,9 +43,5 @@ void init_ops(py::module m) { return self.graph().interpret(f, c, inputs); }); - py::class_, OpDef>(m, "VirtualDep") - .def(py::init<>()) - .def(py::init()); - #include "opdef.py.inl" } diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp deleted file mode 100644 index 65632713b6acb41f54e50532ff8436aebdcb18ba..0000000000000000000000000000000000000000 --- a/imperative/src/impl/ops/utility.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/** - * \file imperative/src/impl/ops/utility.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#include "megbrain/imperative/ops/utility.h" -#include -#include "megbrain/comp_node.h" -#include "megbrain/imperative/ops/opr_attr.h" -#include "megbrain/opr/utility.h" -#include "../op_trait.h" - -namespace mgb::imperative { -namespace { - -cg::OperatorNodeBase* virtual_dep_apply_on_var_node( - const OpDef& def, const VarNodeArray& inputs) { - auto&& graph = inputs[0]->owner_graph(); - auto&& op = def.cast_final_safe(); - VarNodeArray inps(inputs.begin(), inputs.end()); - cg::OperatorNodeConfig config; - if (op.device.length() > 0) { - config.comp_node(CompNode::load(op.device)); - } - cg::OperatorNodeBase* opr = - graph->insert_opr(std::make_unique( - inps, config)); - return opr; -} - -OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) - .apply_on_var_node(virtual_dep_apply_on_var_node) - .fallback(); -} // namespace - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); - -} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h deleted file mode 100644 index 4caecfce0fbe17245c323c5ea371c90e994cd2e3..0000000000000000000000000000000000000000 --- a/imperative/src/include/megbrain/imperative/ops/utility.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/utility.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include -#include "megbrain/graph/operator_node.h" -#include "megbrain/imperative/op_def.h" - -#include "megbrain/utils/hash.h" - -namespace mgb::imperative { - -class VirtualDep : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - VirtualDep() = default; - VirtualDep(std::string dev) : device(dev) {} - - std::string device; - - size_t hash() const override { - return reinterpret_cast(dyn_typeinfo()); - } - - bool is_same_st(const Hashable& rhs) const override { - return true; - } -}; - -} // namespace mgb::imperative