提交 241b35a6 编写于 作者: M Megvii Engine Team

refactor(ops): remove BackwardGraph op

GitOrigin-RevId: eda20e57606daad69790f6abbc7cd7fba2ba934c
上级 d2e33af5
......@@ -18,7 +18,6 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import TensorBase
......@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode):
return _wrap(outputs)
def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
outputs = op.interpret(
op,
lambda op, args: apply_normal_varnode(op, *args),
graph._make_const_for_backward,
args,
)
return outputs
set_cpp_apply_backward_varnode(apply_backward_varnode)
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
outputs = _imperative_rt.input_callback(
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
......
......@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import (
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
......@@ -587,9 +587,6 @@ class trace:
ivars.append(info.varnode)
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
if require_links and len(ovars) > 0:
......@@ -805,9 +802,6 @@ class trace:
name=info.name,
)
ivars.append(h2v[h])
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
if isinstance(op, BatchNorm):
assert (
op.fwd_mode == BatchNorm.FwdMode.INFERENCE
......@@ -1088,9 +1082,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
ivars[0] = opnode.outputs[0]
active_trace._lazy_eval_links = (ivars[0],)
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
outputs = [RawTensor(o) for o in ovars]
......
......@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
input_requires_grad[i] = python::input_requires_grad(ctx, i);
}
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = proxy_graph_detail::make_backward_graph(
auto bg = OpDef::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad);
if (bg.backward) {
if (!bg.backward.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
backward_graph_cache.emplace(key, ret);
......@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure {
size_t count = std::count_if(save_for_backward.begin(),
save_for_backward.end(),
ranges::identity{});
if (backward_graph->precomp) {
if (!backward_graph->precomp.empty()) {
auto&& irng = ranges::span(ctx.args, ctx.nargs);
auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
......
......@@ -30,26 +30,14 @@ using namespace imperative;
using namespace interpreter;
namespace {
std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>
make_backward_graph(
const OpDef& opdef, std::vector<LogicalTensorDesc> inputs,
std::vector<bool> input_requires_grad,
std::vector<bool> output_has_grad) {
auto res = OpDef::make_backward_graph(opdef,
SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()),
SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()),
SmallVector<bool>(output_has_grad.begin(), output_has_grad.end()));
if (res.backward) {
return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{
std::in_place, res.backward, res.save_for_backward, res.input_has_grad};
} else {
return {};
}
}
} // namespace
void init_imperative_rt(py::module m) {
m.def("make_backward_graph", &make_backward_graph);
auto make_backward_graph = [](
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad){
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
};
m.def("make_backward_graph", make_backward_graph);
}
......@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) {
}
/*********** begin of hand-write opdefs **************/
PyOpDefBegin(BackwardGraph) // {{
// };
PyOpDefEnd(BackwardGraph)
void _init_py_backward_graph(py::module m) {
using py_op = PyOp(BackwardGraph);
auto& py_type = PyOpType(BackwardGraph);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "BackwardGraph";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
mgb_assert(PyType_Ready(&py_type) >= 0);
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
auto interpret = py::cpp_function(
[](OpDef& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
};
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
};
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
});
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
PyType_Modified(&py_type);
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
}
struct PyOpBase : PyOpDef {
static PyTypeObject py_type;
......@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
void init_ops(py::module m) {
_init_py_op_def(m);
_init_py_backward_graph(m);
_init_py_op_base(m);
INIT_ALL_OP(m)
......
......@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ctx.args = &tensors[0];
ctx.nargs = nargs;
ctx.pytype = pytype;
if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true;
}
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
......
......@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
return apply(ctx);
}
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
apply_result_t> {
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = tensors.size();
Tensor* args[ctx.nargs];
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < ctx.nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = nargs;
ctx.args = args;
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
args[i] = resolve_arrow(tensors[i]);
}
return apply(ctx);
return apply(op, args, nargs);
}
inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
SmallVector<std::shared_ptr<Tensor>> inputs;
for (size_t i = 0; i < nargs; ++i) {
inputs.push_back(args[i]->shared_from_this());
}
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) {
return apply(op, inputs);
};
auto const_functor = [](imperative::TensorPtr value) {
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
};
return graph.apply(inputs, apply_functor, const_functor);
}
template <typename T>
auto apply(Subgraph graph, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>,
apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(graph, args, nargs);
}
void init_tensor(pybind11::module);
......
......@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;
if (ctx.backward) {
// call megbrain_graph.py apply(BackwardGraph, *args)
// reach here when compiled=True
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {
......
......@@ -18,24 +18,22 @@ using namespace imperative;
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
: input_has_grad(src.input_has_grad) {
if (!src.backward->same_type<BackwardGraph>()) {
if (src.backward.exprs.size() <= 1) {
// backward graph only contains a single op
backward = src.backward;
save_for_backward = src.save_for_backward;
return;
}
save_for_backward.resize(src.save_for_backward.size(), false);
precomp.reset(new BackwardGraph);
backward.reset(new BackwardGraph);
auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
auto&& graph = src.backward;
auto&& mask = src.save_for_backward;
size_t input_size = src.input_has_grad.size();
size_t output_size = (mask.size() - input_size) / 2;
mgb_assert(input_size + output_size * 2 == mask.size());
auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
auto& bgraph = backward->cast_final<BackwardGraph>().graph();
auto& fgraph = precomp;
auto& bgraph = backward;
// optimization: move ops (e.g. GetVarShape) to forward to
// reduce memory footprint
......@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
}
if (!fgraph.outputs.size()) {
precomp.reset();
precomp = {};
}
}
......@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
op_type == RemoteSend::typeinfo() ||
op_type == CollectiveComm::typeinfo() ||
op_type == opr::InputCallback::typeinfo() ||
op_type == opr::OutputCallback::typeinfo() ||
op_type == BackwardGraph::typeinfo()) {
op_type == opr::OutputCallback::typeinfo()) {
return m_commands.end();
}
} else if constexpr (std::is_same_v<T, GetValue>) {
......
......@@ -10,6 +10,9 @@
*/
#include "megbrain/imperative/op_def.h"
#include <sstream>
#include "megbrain/imperative/ops/opr_attr.h"
#include "./op_trait.h"
......@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const {
return m_scope + "." + trait()->make_name(*this);
}
std::string Subgraph::repr() const {
std::ostringstream buf;
buf << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << inputs[i];
}
buf << ") => {\n";
auto fmt_const = [](size_t i, const TensorPtr& t) {
if (t->shape().ndim == 1 && t->shape()[0] == 1) {
auto&& v = t->get_value();
if (v.dtype() == dtype::Float32{}) {
return std::to_string(*v.ptr<dt_float32>());
} else if (v.dtype() == dtype::Int32{}) {
return std::to_string(*v.ptr<int32_t>());
}
}
return std::string("%c") + std::to_string(i);
};
std::unordered_map<size_t, std::string> const_reps;
for (auto&& [i, t] : constants) {
const_reps.emplace(i, fmt_const(i, t));
}
for (auto& [op, ins, outs] : exprs) {
buf << " ";
if (outs.size()) {
for (size_t i = 0; i < outs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outs[i];
}
buf << " = ";
}
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->dyn_typeinfo()->name;
}
for (size_t i : ins) {
buf << " ";
auto&& it = const_reps.find(i);
if (it != const_reps.end()) {
buf << it->second;
} else {
buf << "%" << i;
}
}
buf << "\n";
}
buf << " ";
if (outputs.size()) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outputs[i];
}
} else {
buf << "()";
}
buf << "\n}\n";
return buf.str();
}
} // namespace imperative
} // namespace mgb
......
......@@ -19,147 +19,6 @@
namespace mgb {
namespace imperative {
SmallVector<TensorPtr>
BackwardGraph::InternalGraph::apply(
const SmallVector<TensorPtr>& inputs) const {
return interpret<TensorPtr>(
&OpDef::apply_on_physical_tensor,
[](const TensorPtr& x) {return x;},
inputs);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const {
using TensorAttr = LogicalTensorDesc;
ThinHashMap<size_t, TensorAttr> node2attr;
auto&& input_nodes = this->inputs;
auto&& output_nodes = this->outputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2attr[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
auto* value = i.second->try_get_value();
mgb_assert(value);
node2attr[i.first] = TensorAttr{
i.second->layout(), i.second->comp_node(),
value->proxy_to_default_cpu()};
}
bool validated = true;
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
SmallVector<TensorAttr> expr_input_descs;
for (auto &&inp : expr_inps) {
expr_input_descs.push_back(node2attr.at(inp));
}
auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
*expr_op, expr_input_descs);
validated = validated && expr_validated;
mgb_assert(expr_output_descs.size() == expr_oups.size());
for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
node2attr[expr_oups[i]] = expr_output_descs[i];
}
}
SmallVector<TensorAttr> ret;
for (auto &&i : output_nodes) {
ret.push_back(node2attr.at(i));
}
return {ret, validated};
}
std::string BackwardGraph::InternalGraph::repr() {
std::ostringstream buf;
buf << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << inputs[i];
}
buf << ") => {\n";
auto fmt_const = [](size_t i, TensorPtr& t) {
if (t->shape().ndim == 1 && t->shape()[0] == 1) {
auto&& v = t->get_value();
if (v.dtype() == dtype::Float32{}) {
return std::to_string(*v.ptr<dt_float32>());
} else if (v.dtype() == dtype::Int32{}) {
return std::to_string(*v.ptr<int32_t>());
}
}
return std::string("%c") + std::to_string(i);
};
std::unordered_map<size_t, std::string> const_reps;
for (auto&& [i, t] : constants) {
const_reps.emplace(i, fmt_const(i, t));
}
for (auto& [op, ins, outs] : exprs) {
buf << " ";
if (outs.size()) {
for (size_t i = 0; i < outs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outs[i];
}
buf << " = ";
}
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->dyn_typeinfo()->name;
}
for (size_t i : ins) {
buf << " ";
auto&& it = const_reps.find(i);
if (it != const_reps.end()) {
buf << it->second;
} else {
buf << "%" << i;
}
}
buf << "\n";
}
buf << " ";
if (outputs.size()) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outputs[i];
}
} else {
buf << "()";
}
buf << "\n}\n";
return buf.str();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
namespace {
SmallVector<TensorPtr> backward_impl(
const OpDef& backward_graph,
const SmallVector<TensorPtr>& tensors) {
return backward_graph.cast_final_safe<BackwardGraph>()
.graph().apply(tensors);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
const OpDef& backward_graph,
const SmallVector<LogicalTensorDesc> inputs) {
return backward_graph.cast_final_safe<BackwardGraph>()
.graph().infer_attrs(inputs);
}
std::vector<std::pair<const char*, std::string>> props(
const OpDef& backward_graph) {
return {};
}
OP_TRAIT_REG(BackwardGraph, BackwardGraph)
.apply_on_physical_tensor(backward_impl)
.infer_output_attrs_fallible(infer_tensor_attrs)
.props(props)
.fallback();
} // anonymous namespace
} // namespace imperative
} // namespace mgb
......
......@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph(
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
BackwardGraphResult result;
auto&& backward = BackwardGraph::make();
auto&& igraph = backward->cast_final_safe<BackwardGraph>().graph();
auto&& igraph = result.backward;
size_t nr_backward_graph_inputs = 0;
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
......@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph(
++ nr_backward_graph_inputs;
push(op->output(0));
} else {
std::vector<size_t> inputs, outputs;
SmallVector<size_t> inputs, outputs;
for (auto &&i : op->input()) {
if (i->owner_opr() == fwd) {
if (var2idx.find(i) == var2idx.end()) {
......@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph(
for (auto &&i : op->usable_output()) {
outputs.push_back(push(i));
}
igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs);
igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
}
};
......@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph(
write_inputs(outputs);
write_inputs(output_grads);
mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
auto treat_as_single = [](auto&& igraph) {
if (igraph.exprs.size() != 1)
return false;
auto&& expr = igraph.exprs[0];
auto&& expr_inputs = std::get<1>(expr);
if (expr_inputs.size() != igraph.inputs.size()) {
return false;
}
for (size_t i = 0; i < expr_inputs.size(); ++ i) {
if (igraph.inputs[i] != expr_inputs[i]) {
return false;
}
}
auto&& expr_outputs = std::get<2>(expr);
if (expr_outputs.size() != igraph.outputs.size()) {
return false;
}
for (size_t i = 0; i < expr_outputs.size(); ++ i) {
if (igraph.outputs[i] != expr_outputs[i]) {
return false;
}
}
return true;
};
if (treat_as_single(igraph)) {
result.backward = std::get<0>(igraph.exprs[0]);
} else {
result.backward = backward;
}
return result;
}
......
......@@ -65,7 +65,7 @@ private:
class InputPlaceholder;
struct ProxyGraphInst;
struct GradGraph;
struct CurOprGuard;
class CurOprGuard;
void reset();
......
......@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph {
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct ProxyGraph {
struct InputPlaceholder;
struct MiniGraph;
class MiniGraph;
};
} // namespace mgb::imperative::proxy_graph
......@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def,
auto output_descs = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs(output_descs.size(), {});
for (size_t i = 0; i < outputs.size(); i++) {
auto& output = outputs[i];
auto& output_desc = output_descs[i];
if (def.same_type<Elemwise>()) {
for (size_t j = 0; j < inputs.size(); j++) {
// TODO: reindex inputs to support inplace exprs like 'y = x op x'.
auto& input = inputs[j];
// Because we pass inputs by value, if input and input->blob() are all unique,
// their ownerships are on the stack, thus we can reuse them safely.
// @see: interpreter::intl::ChannelImpl::process_one_task
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
input->layout().dtype == output_desc.layout.dtype &&
input->layout().eq_layout(output_desc.layout) &&
input->comp_node() == output_desc.comp_node) {
static std::atomic_llong inplace_count = 0;
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
output_desc.layout.to_string().c_str(), ++inplace_count);
output = Tensor::make(input->blob(), input->layout(), input->offset());
break;
}
}
}
if (!output) {
output = Tensor::make(output_desc.layout, output_desc.comp_node);
}
outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
exec(def, inputs, outputs);
auto async_error = ProxyGraph::get_async_error();
......
......@@ -14,10 +14,10 @@
namespace mgb::imperative {
struct OptimizedBackwardGraphResult {
std::shared_ptr<OpDef> precomp;
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
std::vector<bool> input_has_grad;
Subgraph precomp;
Subgraph backward;
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
};
......
......@@ -26,10 +26,60 @@ enum DispatchMode {
KERNEL = 1
};
using SharedOp = std::shared_ptr<OpDef>;
template <typename T>
struct Expr {
std::shared_ptr<OpDef> op;
SmallVector<T> inputs;
SmallVector<T> outputs;
};
struct Subgraph {
SmallVector<size_t> inputs;
SmallVector<std::pair<size_t, TensorPtr>> constants;
SmallVector<size_t> outputs;
SmallVector<Expr<size_t>> exprs;
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
idx2var[inputs[i]] = input_vars[i];
}
for (auto&& [idx, val]: constants) {
idx2var[idx] = c(val);
}
for (auto& expr: exprs) {
SmallVector<T> expr_inputs;
for (auto idx: expr.inputs) {
expr_inputs.push_back(idx2var[idx]);
}
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs));
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) {
idx2var[expr.outputs[i]] = expr_outputs[i];
}
}
SmallVector<T> output_vars;
for (auto idx: outputs) {
output_vars.push_back(idx2var[idx]);
}
return output_vars;
}
bool empty() const {
return outputs.size() == 0;
}
std::string repr() const;
};
struct BackwardGraphResult {
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
std::vector<bool> input_has_grad;
Subgraph backward;
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;
};
class OpDef : public Hashable,
......
......@@ -15,92 +15,6 @@
namespace mgb {
namespace imperative {
// a special OpDef used for taking gradient on physical tensor
struct BackwardGraph final : public OpDefImplBase<BackwardGraph> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
struct InternalGraph {
// op, inputs, outputs
using Expr = std::tuple<std::shared_ptr<OpDef>,
std::vector<size_t>, std::vector<size_t>>;
std::vector<Expr> exprs;
// index array of input nodes
std::vector<size_t> inputs;
// index array of output nodes
std::vector<size_t> outputs;
// pair of (node index, correspending constant)
std::vector<std::pair<size_t, TensorPtr>> constants;
SmallVector<TensorPtr>
apply(const SmallVector<TensorPtr>& inputs) const;
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const;
template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
ThinHashMap<size_t, T> node2tensor;
auto&& input_nodes = this->inputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2tensor[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
node2tensor[i.first] = c(i.second);
}
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& expr = exprs[i];
SmallVector<T> inputs;
for (auto &&in : std::get<1>(expr)) {
inputs.push_back(node2tensor.at(in));
}
auto&& outputs = f(*std::get<0>(expr), std::move(inputs));
auto&& output_nodes = std::get<2>(expr);
mgb_assert(outputs.size() == output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++ i) {
node2tensor[output_nodes[i]] = std::move(outputs[i]);
}
}
SmallVector<T> ret;
for (auto &&i : outputs) {
ret.push_back(node2tensor.at(i));
}
return ret;
}
std::string repr();
};
const InternalGraph& graph() const {
return m_graph;
}
InternalGraph& graph() {
return m_graph;
}
bool is_same_st(const Hashable& rhs) const override {
if (!rhs.same_type<BackwardGraph>()) {
return false;
}
auto& other = rhs.cast_final_safe<BackwardGraph>();
if (this == &other) {
return true;
}
// FIXME
return false;
}
std::string repr() {return m_graph.repr();}
private:
InternalGraph m_graph;
};
} // namespace imperative
} // namespace mgb
......
......@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
}
bool is_same_st(const Hashable& rhs) const override {
return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
return obj.equal(rhs.cast_final<GenericPyOp>().obj);
}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons
return ret;
}
SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
}
TEST(TestImperative, BackwardGraphBasic) {
HostTensorGenerator<> gen;
SmallVector<HostTensorND> hvs;
......@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs.clear();
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
auto input_grads = result.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
......@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs.clear();
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
auto input_grads = result.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
......@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
auto grads = expand_grads(bg, bg.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));
auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
auto precomp = obg.precomp.apply(
SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
ASSERT_EQ(precomp.size(), 2);
ASSERT_EQ(precomp[0]->shape().ndim, 1);
ASSERT_LE(precomp[0]->shape()[0], 2);
......@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
ASSERT_LE(precomp[1]->shape()[0], 2);
auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
auto grads2 = expand_grads(obg, obg.backward.apply(
backward_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));
ASSERT_EQ(grads2.size(), 2);
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册