提交 4101d5bc 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): add BackwardGraph.interpret

GitOrigin-RevId: bb3a59380ec937c7fd60daed161d3f41172da972
上级 afddefb6
......@@ -12,6 +12,7 @@ import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from .. import _imperative_rt
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
......@@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode):
return _wrap(outputs)
@apply.register()
def _(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args)
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
outputs = _imperative_rt.input_callback(
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
......
......@@ -40,6 +40,18 @@ void init_ops(py::module m) {
attr.param.insert(attr.param.end(), s.begin(), s.end());
});
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs));
};
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
};
return self.graph().interpret<py::object>(f, c, inputs);
});
py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape")
.def(py::init());
......@@ -98,7 +110,6 @@ void init_ops(py::module m) {
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph");
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake")
.def(py::init<>());
......
......@@ -18,34 +18,10 @@ namespace imperative {
SmallVector<TensorPtr>
BackwardGraph::InternalGraph::apply(
const SmallVector<TensorPtr>& inputs) const {
ThinHashMap<size_t, TensorPtr> node2tensor;
auto&& input_nodes = this->inputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2tensor[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
node2tensor[i.first] = i.second;
}
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& expr = exprs[i];
SmallVector<TensorPtr> inputs;
for (auto &&in : std::get<1>(expr)) {
inputs.push_back(node2tensor.at(in));
}
auto outputs = OpDef::apply_on_physical_tensor(
*std::get<0>(expr), inputs);
auto output_nodes = std::get<2>(expr);
mgb_assert(outputs.size() == output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++ i) {
node2tensor[output_nodes[i]] = outputs[i];
}
}
SmallVector<TensorPtr> ret;
for (auto &&i : outputs) {
ret.push_back(node2tensor.at(i));
}
return ret;
return interpret<TensorPtr>(
&OpDef::apply_on_physical_tensor,
[](const TensorPtr& x) {return x;},
inputs);
}
SmallVector<LogicalTensorDesc>
......
......@@ -40,6 +40,37 @@ public:
SmallVector<LogicalTensorDesc>
infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;
template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
ThinHashMap<size_t, T> node2tensor;
auto&& input_nodes = this->inputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2tensor[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
node2tensor[i.first] = c(i.second);
}
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& expr = exprs[i];
SmallVector<T> inputs;
for (auto &&in : std::get<1>(expr)) {
inputs.push_back(node2tensor.at(in));
}
auto&& outputs = f(*std::get<0>(expr), std::move(inputs));
auto&& output_nodes = std::get<2>(expr);
mgb_assert(outputs.size() == output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++ i) {
node2tensor[output_nodes[i]] = std::move(outputs[i]);
}
}
SmallVector<T> ret;
for (auto &&i : outputs) {
ret.push_back(node2tensor.at(i));
}
return ret;
}
};
const InternalGraph& graph() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册