diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 37a863690bdd6cdb3e575cfefcae2576833a1054..5ad3391e64154cfcdfd038dca4bbc2f33b2c8fe2 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -12,6 +12,7 @@ import weakref from concurrent.futures import Future, ThreadPoolExecutor from .. import _imperative_rt +from .._imperative_rt.ops import BackwardGraph from .._wrap import device as as_device from ..ops.builtin import OpDef from .core import OpBase, TensorBase, apply @@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode): return _wrap(outputs) +@apply.register() +def _(op: BackwardGraph, *args: VarNode): + assert args + graph = args[0].graph + return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args) + + def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): outputs = _imperative_rt.input_callback( callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index abedeeea0ac49f113d0b622a9a82b218ef090b73..c64537959f4675148e2aada54ce3459df28b3fcb 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -40,6 +40,18 @@ void init_ops(py::module m) { attr.param.insert(attr.param.end(), s.begin(), s.end()); }); + py::class_, OpDef>(m, "BackwardGraph") + .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, + const mgb::SmallVector& inputs) { + auto f = [pyf](OpDef& op, const mgb::SmallVector& inputs) { + return py::cast>(pyf(op.copy(), inputs)); + }; + auto c = [pyc](const TensorPtr& tensor) { + return pyc(tensor->dev_tensor()); + }; + return self.graph().interpret(f, c, inputs); + }); + py::class_, OpDef>(m, "GetVarShape") .def(py::init()); @@ -98,7 +110,6 @@ void init_ops(py::module m) { .def(py::init<>()) .def_readwrite("offsets", &ParamPackConcat::offsets); - py::class_, OpDef>(m, "BackwardGraph"); py::class_, OpDef>(m, "CondTake") .def(py::init<>()); diff --git a/imperative/src/impl/ops/backward_graph.cpp b/imperative/src/impl/ops/backward_graph.cpp index e844d481ad0b8a2274c460b7d59e7a5ec8f511a7..e452ac5902f04f9679756afb3501eee471fbc2f2 100644 --- a/imperative/src/impl/ops/backward_graph.cpp +++ b/imperative/src/impl/ops/backward_graph.cpp @@ -18,34 +18,10 @@ namespace imperative { SmallVector BackwardGraph::InternalGraph::apply( const SmallVector& inputs) const { - ThinHashMap node2tensor; - auto&& input_nodes = this->inputs; - mgb_assert(inputs.size() == input_nodes.size()); - for (size_t i = 0; i < inputs.size(); ++ i) { - node2tensor[input_nodes[i]] = inputs[i]; - } - for (auto &&i : constants) { - node2tensor[i.first] = i.second; - } - for (size_t i = 0; i < exprs.size(); ++ i) { - auto&& expr = exprs[i]; - SmallVector inputs; - for (auto &&in : std::get<1>(expr)) { - inputs.push_back(node2tensor.at(in)); - } - auto outputs = OpDef::apply_on_physical_tensor( - *std::get<0>(expr), inputs); - auto output_nodes = std::get<2>(expr); - mgb_assert(outputs.size() == output_nodes.size()); - for (size_t i = 0; i < outputs.size(); ++ i) { - node2tensor[output_nodes[i]] = outputs[i]; - } - } - SmallVector ret; - for (auto &&i : outputs) { - ret.push_back(node2tensor.at(i)); - } - return ret; + return interpret( + &OpDef::apply_on_physical_tensor, + [](const TensorPtr& x) {return x;}, + inputs); } SmallVector diff --git a/imperative/src/include/megbrain/imperative/ops/backward_graph.h b/imperative/src/include/megbrain/imperative/ops/backward_graph.h index 3a8f00bf69863627afdb39dec8de9d954797a18c..e1d8876828f598587e4256b23107a1235ea2df92 100644 --- a/imperative/src/include/megbrain/imperative/ops/backward_graph.h +++ b/imperative/src/include/megbrain/imperative/ops/backward_graph.h @@ -40,6 +40,37 @@ public: SmallVector infer_attrs(const SmallVector& inputs) const; + + template + SmallVector interpret(F&& f, C&& c, const SmallVector& inputs) const { + ThinHashMap node2tensor; + auto&& input_nodes = this->inputs; + mgb_assert(inputs.size() == input_nodes.size()); + for (size_t i = 0; i < inputs.size(); ++ i) { + node2tensor[input_nodes[i]] = inputs[i]; + } + for (auto &&i : constants) { + node2tensor[i.first] = c(i.second); + } + for (size_t i = 0; i < exprs.size(); ++ i) { + auto&& expr = exprs[i]; + SmallVector inputs; + for (auto &&in : std::get<1>(expr)) { + inputs.push_back(node2tensor.at(in)); + } + auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); + auto&& output_nodes = std::get<2>(expr); + mgb_assert(outputs.size() == output_nodes.size()); + for (size_t i = 0; i < outputs.size(); ++ i) { + node2tensor[output_nodes[i]] = std::move(outputs[i]); + } + } + SmallVector ret; + for (auto &&i : outputs) { + ret.push_back(node2tensor.at(i)); + } + return ret; + } }; const InternalGraph& graph() const {