提交 4e80cf52 编写于 作者: M Megvii Engine Team

refactor(mge/jit): make trace return any kind of output

GitOrigin-RevId: fd1265c661e7f4d750f2c13599113b874c949ba5
上级 7f06bb94
......@@ -170,9 +170,9 @@ class trace:
self._graph = None
self._need_reset_nodes = None
self._lazy_eval_graph = None
self._lazy_eval_tensors = {}
self._lazy_eval_tensors = set()
self._lazy_eval_links = None
self._active_tensors = {}
self._active_tensors = set()
self._tensor_remaps = None
self._inputs_to_restore = None
self._arg_bindings = None
......@@ -258,7 +258,7 @@ class trace:
y._compiled_info = CompiledTensorProxy(h)
y._mixin_handle = h
outputs += [y]
self._active_tensors[h] = TensorWeakRef(y)
self._active_tensors.add(TensorWeakRef(y))
self._output_handles.update(ohandles)
return outputs
......@@ -318,9 +318,9 @@ class trace:
x._mixin_handle = h
x._recording = True
x._trace_mixin_info = info
self._active_tensors[h] = TensorWeakRef(x)
self._active_tensors.add(TensorWeakRef(x))
if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x)
self._lazy_eval_tensors.add(TensorWeakRef(x))
self._seq.append((op, tuple(ihandles), tuple(ohandles)))
......@@ -345,7 +345,7 @@ class trace:
x._recording = True
x._trace_mixin_info = info
if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x)
self._lazy_eval_tensors.add(TensorWeakRef(x))
self._seq.append(("Const", tuple(), tuple(ohandles)))
def _set_active(self, active: bool):
......@@ -365,17 +365,14 @@ class trace:
self._lazy_eval_links = ()
def _take_escaped_tensors(self):
# Snapshot the weak refs of traced tensors that are still alive
# ("escaped" the trace region), then clear the active set.
# NOTE(review): the two `escaped_tensors = ...` assignments below are a
# scraped-diff artifact — the first (dict `.values()` form) is the
# pre-change line and the second (set form) is the post-change line of
# the same statement; only the second one takes effect here.
escaped_tensors = tuple(
filter(lambda x: x() is not None, self._active_tensors.values())
)
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
# Drop all weak refs so subsequent trace steps start from an empty set.
self._active_tensors.clear()
return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(
filter(lambda x: x() is not None, lazy_eval_tensors.values())
)
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
lazy_eval_tensors = [x() for x in lazy_eval_tensors]
lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
......@@ -383,8 +380,8 @@ class trace:
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
# get values from lazy_eval_graph and assign to lazy_eval tensor
x()._handle = RawTensor(r.op.get_value())._handle
x()._reset_varnode()
x._handle = RawTensor(r.op.get_value())._handle
x._reset_varnode()
@contextlib.contextmanager
def _setup(self):
......@@ -454,13 +451,14 @@ class trace:
raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced:
# reset output tensors
for x in self._active_tensors.values():
if x() is not None:
x()._dev_tensor()
x()._reset_varnode()
x()._mixin_handle = -1
x()._recording = False
x()._trace_mixin_info = None
for x in self._active_tensors.copy():
strong_x = x()
if strong_x is not None:
strong_x._dev_tensor()
strong_x._reset_varnode()
strong_x._mixin_handle = -1
strong_x._recording = False
strong_x._trace_mixin_info = None
try:
do_enter()
......@@ -482,15 +480,17 @@ class trace:
if self._untraced:
# conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read
for x in self._active_tensors.values():
if x():
info = self._tinfo[x()._mixin_handle]
for x in self._active_tensors:
strong_x = x()
if strong_x:
info = self._tinfo[strong_x._mixin_handle]
info.exported = True
info.data_read = True
else:
for x in self._active_tensors.values():
if x():
x()._dev_tensor()
for x in self._active_tensors:
strong_x = x()
if strong_x:
strong_x._dev_tensor()
def _apply_graph_options(self, graph):
......@@ -520,7 +520,6 @@ class trace:
graph = self._graph = G.Graph()
graph.options.async_exec_level = 0b100
self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
in_out_links = ()
......@@ -563,7 +562,7 @@ class trace:
if not hasattr(info, "varnode"):
assert info.external
if info.bound_data:
if hasattr(info, "is_const") and info.is_const:
if getattr(info, "is_const", False):
info.varnode = graph.make_const(
info.bound_data.numpy(),
info.bound_data.dtype,
......@@ -635,30 +634,12 @@ class trace:
opnode.reset()
def __call__(self, *args, **kwargs):
# Entry point of the traced function.
# NOTE(review): this is the removed side of a `-30/+12` diff hunk — it
# still references `o._copied`, which this commit deletes elsewhere;
# treat it as the pre-refactor implementation, not current code.
# Re-entrant call while already tracing: run the wrapped function
# directly so ops are recorded by the outer trace.
if is_tracing():
return self.__wrapped__(*args, **kwargs)
with self._setup():
if self._capture_as_const:
self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs)
if self._capture_as_const:
self._process_outputs(outputs)
# outputs could be None
if outputs is not None:
# Normalize outputs to a flat sequence: mappings are sorted by key
# for a deterministic order; a lone tensor becomes a 1-tuple.
list_outputs = outputs
if isinstance(outputs, collections.abc.Mapping):
_, list_outputs = zip(*sorted(outputs.items()))
elif not isinstance(outputs, collections.abc.Sequence):
list_outputs = (outputs,)
for o in list_outputs:
# if outputs are copied, then use the newest info in trace data structure
if o._copied:
self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._untraced and self._symbolic:
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
return outputs
def dump(
......
......@@ -9,11 +9,12 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp"
......@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) {
if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn]
outputs[i] = outputs[i]->copy();
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
outputs[i] = python::apply(op, outputs[i])[0];
}
// populate grad info of output tensor
auto& grad_info = outputs[i]->m_grad_info;
......
......@@ -12,6 +12,7 @@
#pragma once
#include "./tensor.h"
#include "megbrain/imperative/ops/utility.h"
#include <megbrain/utils/small_vector.h>
#include <memory>
......
......@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma
return apply(ctx);
}
// Gradient rule for FastpathCopy. The op is an identity on its single
// input, so the backward pass simply forwards the incoming gradient
// unchanged (or leaves the slot empty when no gradient flows).
apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    // One output; it need not be captured for the backward computation.
    maker.output_size(1).output_captured(0, false);
    maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t result(1);
        if (Tensor* incoming = grads[0]) {
            // Identity backward: pass the gradient through as-is.
            result[0] = incoming->shared_from_this();
        }
        return result;
    });
    return apply(ctx);
}
struct Init {
Init() {
auto& reg = grad_rule_registry();
......@@ -231,6 +246,7 @@ struct Init {
reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
}
} _;
......
......@@ -23,6 +23,7 @@
#include "./common.h"
#include "./ops.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/ops/utility.h"
namespace py = pybind11;
......
......@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) {
handles[i] = ctx.args[i]->m_handle.get();
}
apply_result_t outputs;
// fast copy without really applying
if (ctx.op->same_type<FastpathCopy>()) {
mgb_assert(ctx.nargs == 1);
outputs.reserve(ctx.nargs);
outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
return outputs;
}
auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);
apply_result_t outputs;
outputs.reserve(output_handles.size());
for (auto h : output_handles) {
outputs.emplace_back(std::make_shared<Tensor>(h));
......@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC
PyObject* TensorWrapper::copied() {
return py::cast(m_tensor->m_trace_info.copied).release().ptr();
}
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
if (m_tensor->m_trace_info.member) { \
......@@ -841,7 +845,6 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def<&TensorWrapper::_use_cnt>("_use_cnt")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
......
......@@ -10,6 +10,7 @@
*/
#pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include <variant>
......
......@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
for (size_t i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
......
......@@ -17,7 +17,6 @@ namespace mgb::imperative::python {
struct TraceInfo {
int64_t mixin_handle = -1;
bool recording = false;
bool copied = false;
// refer to CompiledTensorProxy in tracing.py, works from second trace step
PyObject* compiled_info = nullptr;
......@@ -35,7 +34,6 @@ struct TraceInfo {
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
copied = true;
return *this;
}
......
......@@ -18,4 +18,18 @@ namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
// Graph-building behavior for FastpathCopy: the op is a pure identity,
// so applying it on var nodes forwards the inputs as the outputs.
namespace { namespace fastpathcopy {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
// Identity: no new graph nodes are created.
return inputs;
}
// Register the trait; unspecified behaviors fall back to defaults.
OP_TRAIT_REG(FastpathCopy,FastpathCopy)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fastpathcopy
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
} // namespace mgb::imperative
......@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
// Stateless identity op used to copy a tensor without a real apply
// (e.g. to break the [Tensor <-> GradFn] reference cycle). All instances
// are interchangeable, so hashing and equality depend only on the
// dynamic type.
struct FastpathCopy final : OpDefImplBase<FastpathCopy> {
    FastpathCopy() = default;

    // Any two FastpathCopy ops compare equal: dynamic type is the only
    // identity this op carries.
    bool is_same_st(const Hashable& rhs) const override {
        return dyn_typeinfo() == rhs.dyn_typeinfo();
    }

    // Hash derives from the type alone; there is no per-instance state.
    size_t hash() const override {
        return mgb::hash(dyn_typeinfo());
    }

    MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册