提交 4f240ec2 编写于 作者: M Megvii Engine Team

refactor(mge/jit): make trace return any kind of output

GitOrigin-RevId: fd1265c661e7f4d750f2c13599113b874c949ba5
上级 c6b552cf
...@@ -170,9 +170,9 @@ class trace: ...@@ -170,9 +170,9 @@ class trace:
self._graph = None self._graph = None
self._need_reset_nodes = None self._need_reset_nodes = None
self._lazy_eval_graph = None self._lazy_eval_graph = None
self._lazy_eval_tensors = {} self._lazy_eval_tensors = set()
self._lazy_eval_links = None self._lazy_eval_links = None
self._active_tensors = {} self._active_tensors = set()
self._tensor_remaps = None self._tensor_remaps = None
self._inputs_to_restore = None self._inputs_to_restore = None
self._arg_bindings = None self._arg_bindings = None
...@@ -258,7 +258,7 @@ class trace: ...@@ -258,7 +258,7 @@ class trace:
y._compiled_info = CompiledTensorProxy(h) y._compiled_info = CompiledTensorProxy(h)
y._mixin_handle = h y._mixin_handle = h
outputs += [y] outputs += [y]
self._active_tensors[h] = TensorWeakRef(y) self._active_tensors.add(TensorWeakRef(y))
self._output_handles.update(ohandles) self._output_handles.update(ohandles)
return outputs return outputs
...@@ -318,9 +318,9 @@ class trace: ...@@ -318,9 +318,9 @@ class trace:
x._mixin_handle = h x._mixin_handle = h
x._recording = True x._recording = True
x._trace_mixin_info = info x._trace_mixin_info = info
self._active_tensors[h] = TensorWeakRef(x) self._active_tensors.add(TensorWeakRef(x))
if self._symbolic: if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x) self._lazy_eval_tensors.add(TensorWeakRef(x))
self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._seq.append((op, tuple(ihandles), tuple(ohandles)))
...@@ -345,7 +345,7 @@ class trace: ...@@ -345,7 +345,7 @@ class trace:
x._recording = True x._recording = True
x._trace_mixin_info = info x._trace_mixin_info = info
if self._symbolic: if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x) self._lazy_eval_tensors.add(TensorWeakRef(x))
self._seq.append(("Const", tuple(), tuple(ohandles))) self._seq.append(("Const", tuple(), tuple(ohandles)))
def _set_active(self, active: bool): def _set_active(self, active: bool):
...@@ -365,17 +365,14 @@ class trace: ...@@ -365,17 +365,14 @@ class trace:
self._lazy_eval_links = () self._lazy_eval_links = ()
def _take_escaped_tensors(self): def _take_escaped_tensors(self):
escaped_tensors = tuple( escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
filter(lambda x: x() is not None, self._active_tensors.values())
)
self._active_tensors.clear() self._active_tensors.clear()
return escaped_tensors return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list( lazy_eval_tensors = [x() for x in lazy_eval_tensors]
filter(lambda x: x() is not None, lazy_eval_tensors.values()) lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
) readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph) self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
...@@ -383,8 +380,8 @@ class trace: ...@@ -383,8 +380,8 @@ class trace:
lazy_eval_graph() lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors): for r, x in zip(readers, lazy_eval_tensors):
# get values from lazy_eval_graph and assign to lazy_eval tensor # get values from lazy_eval_graph and assign to lazy_eval tensor
x()._handle = RawTensor(r.op.get_value())._handle x._handle = RawTensor(r.op.get_value())._handle
x()._reset_varnode() x._reset_varnode()
@contextlib.contextmanager @contextlib.contextmanager
def _setup(self): def _setup(self):
...@@ -454,13 +451,14 @@ class trace: ...@@ -454,13 +451,14 @@ class trace:
raise TraceMismatchError("premature end") raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced: if not self._symbolic or not self._untraced:
# reset output tensors # reset output tensors
for x in self._active_tensors.values(): for x in self._active_tensors.copy():
if x() is not None: strong_x = x()
x()._dev_tensor() if strong_x is not None:
x()._reset_varnode() strong_x._dev_tensor()
x()._mixin_handle = -1 strong_x._reset_varnode()
x()._recording = False strong_x._mixin_handle = -1
x()._trace_mixin_info = None strong_x._recording = False
strong_x._trace_mixin_info = None
try: try:
do_enter() do_enter()
...@@ -482,15 +480,17 @@ class trace: ...@@ -482,15 +480,17 @@ class trace:
if self._untraced: if self._untraced:
# conditionally reading a compiled tensor in excluded region # conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read # is permitted, so we have to assume every tensor might be read
for x in self._active_tensors.values(): for x in self._active_tensors:
if x(): strong_x = x()
info = self._tinfo[x()._mixin_handle] if strong_x:
info = self._tinfo[strong_x._mixin_handle]
info.exported = True info.exported = True
info.data_read = True info.data_read = True
else: else:
for x in self._active_tensors.values(): for x in self._active_tensors:
if x(): strong_x = x()
x()._dev_tensor() if strong_x:
strong_x._dev_tensor()
def _apply_graph_options(self, graph): def _apply_graph_options(self, graph):
...@@ -520,7 +520,6 @@ class trace: ...@@ -520,7 +520,6 @@ class trace:
graph = self._graph = G.Graph() graph = self._graph = G.Graph()
graph.options.async_exec_level = 0b100 graph.options.async_exec_level = 0b100
self._apply_graph_options(graph) self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = [] need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes # links enforce ordering of I/O nodes
in_out_links = () in_out_links = ()
...@@ -563,7 +562,7 @@ class trace: ...@@ -563,7 +562,7 @@ class trace:
if not hasattr(info, "varnode"): if not hasattr(info, "varnode"):
assert info.external assert info.external
if info.bound_data: if info.bound_data:
if hasattr(info, "is_const") and info.is_const: if getattr(info, "is_const", False):
info.varnode = graph.make_const( info.varnode = graph.make_const(
info.bound_data.numpy(), info.bound_data.numpy(),
info.bound_data.dtype, info.bound_data.dtype,
...@@ -635,30 +634,12 @@ class trace: ...@@ -635,30 +634,12 @@ class trace:
opnode.reset() opnode.reset()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if is_tracing():
return self.__wrapped__(*args, **kwargs)
with self._setup(): with self._setup():
if self._capture_as_const: if self._capture_as_const:
self._process_inputs(*args, **kwargs) self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs) outputs = self.__wrapped__(*args, **kwargs)
if self._capture_as_const: if self._capture_as_const:
self._process_outputs(outputs) self._process_outputs(outputs)
# outputs could be None
if outputs is not None:
list_outputs = outputs
if isinstance(outputs, collections.abc.Mapping):
_, list_outputs = zip(*sorted(outputs.items()))
elif not isinstance(outputs, collections.abc.Sequence):
list_outputs = (outputs,)
for o in list_outputs:
# if outputs are copied, then use the newest info in trace data structure
if o._copied:
self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._untraced and self._symbolic:
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
return outputs return outputs
def dump( def dump(
......
...@@ -9,11 +9,12 @@ ...@@ -9,11 +9,12 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "./grad.h" #include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h" #include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp" #include "range/v3/all.hpp"
...@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) { ...@@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) {
if (backward.output_requires_grad(i)) { if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) { if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn] // avoid reference cycle [Tensor <-> GradFn]
outputs[i] = outputs[i]->copy(); static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
outputs[i] = python::apply(op, outputs[i])[0];
} }
// populate grad info of output tensor // populate grad info of output tensor
auto& grad_info = outputs[i]->m_grad_info; auto& grad_info = outputs[i]->m_grad_info;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#pragma once #pragma once
#include "./tensor.h" #include "./tensor.h"
#include "megbrain/imperative/ops/utility.h"
#include <megbrain/utils/small_vector.h> #include <megbrain/utils/small_vector.h>
#include <memory> #include <memory>
......
...@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma ...@@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma
return apply(ctx); return apply(ctx);
} }
// Gradient rule for the FastpathCopy op.
//
// Configures `maker` for a single, non-captured output and installs a backward
// function that forwards the one incoming gradient unchanged; the forward
// computation itself is delegated to apply(ctx).
apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    // The rule only handles the single-argument case.
    mgb_assert(ctx.nargs == 1);
    // One output; it is marked as not captured by the backward function.
    maker.output_size(1).output_captured(0, false);
    maker.backward([](BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        // Single-slot result; the slot keeps its default-constructed value
        // when no gradient was supplied for the output.
        apply_result_t input_grads(1);
        if (Tensor* incoming = grads[0]) {
            input_grads[0] = incoming->shared_from_this();
        }
        return input_grads;
    });
    return apply(ctx);
}
struct Init { struct Init {
Init() { Init() {
auto& reg = grad_rule_registry(); auto& reg = grad_rule_registry();
...@@ -231,6 +246,7 @@ struct Init { ...@@ -231,6 +246,7 @@ struct Init {
reg.emplace(Reduce::typeinfo(), reduce_grad_rule); reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
} }
} _; } _;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "./common.h" #include "./common.h"
#include "./ops.h" #include "./ops.h"
#include "megbrain/gopt/inference.h" #include "megbrain/gopt/inference.h"
#include "megbrain/imperative/ops/utility.h"
namespace py = pybind11; namespace py = pybind11;
......
...@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) { ...@@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) {
handles[i] = ctx.args[i]->m_handle.get(); handles[i] = ctx.args[i]->m_handle.get();
} }
apply_result_t outputs;
// fast copy without really applying
if (ctx.op->same_type<FastpathCopy>()) {
mgb_assert(ctx.nargs == 1);
outputs.reserve(ctx.nargs);
outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
return outputs;
}
auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);
apply_result_t outputs;
outputs.reserve(output_handles.size()); outputs.reserve(output_handles.size());
for (auto h : output_handles) { for (auto h : output_handles) {
outputs.emplace_back(std::make_shared<Tensor>(h)); outputs.emplace_back(std::make_shared<Tensor>(h));
...@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) ...@@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC #undef REGISTE_TENSORWRAPPER_FUNC
// Getter that exposes the tensor's trace-info `copied` flag to Python.
// The flag is converted with py::cast and the resulting pybind11 object is
// released, so the caller receives a raw PyObject*.
// NOTE(review): presumably the caller takes ownership of the returned
// reference — confirm against pybind11's object::release() semantics.
PyObject* TensorWrapper::copied() {
return py::cast(m_tensor->m_trace_info.copied).release().ptr();
}
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \ PyObject* TensorWrapper::member() { \
if (m_tensor->m_trace_info.member) { \ if (m_tensor->m_trace_info.member) { \
...@@ -841,7 +845,6 @@ void init_tensor(py::module m) { ...@@ -841,7 +845,6 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def<&TensorWrapper::_use_cnt>("_use_cnt") .def<&TensorWrapper::_use_cnt>("_use_cnt")
.def_getset<&TensorWrapper::varnode>("_varnode") .def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
*/ */
#pragma once #pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include <variant> #include <variant>
......
...@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList // assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret); auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) { for (size_t i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode*>(); auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem)); outputs.emplace_back(std::make_shared<Tensor>(pitem));
} }
......
...@@ -17,7 +17,6 @@ namespace mgb::imperative::python { ...@@ -17,7 +17,6 @@ namespace mgb::imperative::python {
struct TraceInfo { struct TraceInfo {
int64_t mixin_handle = -1; int64_t mixin_handle = -1;
bool recording = false; bool recording = false;
bool copied = false;
// refer to CompiledTensorProxy in tracing.py, works from second trace step // refer to CompiledTensorProxy in tracing.py, works from second trace step
PyObject* compiled_info = nullptr; PyObject* compiled_info = nullptr;
...@@ -35,7 +34,6 @@ struct TraceInfo { ...@@ -35,7 +34,6 @@ struct TraceInfo {
compiled_info = that.compiled_info; compiled_info = that.compiled_info;
Py_XINCREF(compiled_info); Py_XINCREF(compiled_info);
copied = true;
return *this; return *this;
} }
......
...@@ -18,4 +18,18 @@ namespace mgb::imperative { ...@@ -18,4 +18,18 @@ namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
// Unnamed namespace wrapping a per-op namespace, so the helper below has
// internal linkage and does not clash with same-named helpers of other ops.
namespace { namespace fastpathcopy {
// Var-node implementation of FastpathCopy: returns the input var nodes as-is.
// `def` is not consulted. Because the return type is deduced with plain
// `auto`, the result is a copy of `inputs`, not a reference to it.
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
return inputs;
}
// Register the op trait for FastpathCopy, wiring in the function above.
// NOTE(review): .fallback() presumably fills the remaining trait entries with
// defaults — confirm against the OP_TRAIT_REG definition.
OP_TRAIT_REG(FastpathCopy,FastpathCopy)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // fastpathcopy
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
} // namespace mgb::imperative } // namespace mgb::imperative
...@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { ...@@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
MGB_DYN_TYPE_OBJ_FINAL_DECL; MGB_DYN_TYPE_OBJ_FINAL_DECL;
}; };
// Op definition for a "fast path" copy. The struct declares no data members
// of its own, so hashing and equality below depend only on the dynamic type.
struct FastpathCopy final : OpDefImplBase<FastpathCopy> {
FastpathCopy() = default;
// Hash is computed from the dynamic type info alone.
size_t hash() const override {
return mgb::hash(this->dyn_typeinfo());
}
// Equal to `rhs` whenever both report the same dynamic type info.
bool is_same_st(const Hashable& rhs) const override {
return this->dyn_typeinfo() == rhs.dyn_typeinfo();
}
// NOTE(review): presumably declares dyn_typeinfo()/typeinfo() for this
// final type — confirm against the macro definition.
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
} // namespace mgb::imperative } // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册