From 4f240ec2d3e65dc3b9c1bb0d813073f1181bb5ef Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 6 Apr 2021 15:09:40 +0800 Subject: [PATCH] refactor(mge/jit): make trace return any kind of output GitOrigin-RevId: fd1265c661e7f4d750f2c13599113b874c949ba5 --- imperative/python/megengine/jit/tracing.py | 77 +++++++------------ imperative/python/src/grad.cpp | 6 +- imperative/python/src/grad.h | 1 + imperative/python/src/grad_override.cpp | 16 ++++ imperative/python/src/graph_rt.cpp | 1 + imperative/python/src/tensor.cpp | 17 ++-- imperative/python/src/tensor.h | 1 + imperative/python/src/trace.cpp | 2 +- imperative/python/src/trace_info.h | 2 - imperative/src/impl/ops/utility.cpp | 14 ++++ .../include/megbrain/imperative/ops/utility.h | 14 ++++ 11 files changed, 91 insertions(+), 60 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index c95f5402f..2d3c0ca3d 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -170,9 +170,9 @@ class trace: self._graph = None self._need_reset_nodes = None self._lazy_eval_graph = None - self._lazy_eval_tensors = {} + self._lazy_eval_tensors = set() self._lazy_eval_links = None - self._active_tensors = {} + self._active_tensors = set() self._tensor_remaps = None self._inputs_to_restore = None self._arg_bindings = None @@ -258,7 +258,7 @@ class trace: y._compiled_info = CompiledTensorProxy(h) y._mixin_handle = h outputs += [y] - self._active_tensors[h] = TensorWeakRef(y) + self._active_tensors.add(TensorWeakRef(y)) self._output_handles.update(ohandles) return outputs @@ -318,9 +318,9 @@ class trace: x._mixin_handle = h x._recording = True x._trace_mixin_info = info - self._active_tensors[h] = TensorWeakRef(x) + self._active_tensors.add(TensorWeakRef(x)) if self._symbolic: - self._lazy_eval_tensors[h] = TensorWeakRef(x) + self._lazy_eval_tensors.add(TensorWeakRef(x)) self._seq.append((op, tuple(ihandles), tuple(ohandles))) @@ -345,7 +345,7 @@ class trace: x._recording = True x._trace_mixin_info = info if self._symbolic: - self._lazy_eval_tensors[h] = TensorWeakRef(x) + self._lazy_eval_tensors.add(TensorWeakRef(x)) self._seq.append(("Const", tuple(), tuple(ohandles))) def _set_active(self, active: bool): @@ -365,17 +365,14 @@ class trace: self._lazy_eval_links = () def _take_escaped_tensors(self): - escaped_tensors = tuple( - filter(lambda x: x() is not None, self._active_tensors.values()) - ) + escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) self._active_tensors.clear() return escaped_tensors def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): - lazy_eval_tensors = list( - filter(lambda x: x() is not None, lazy_eval_tensors.values()) - ) - readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] + lazy_eval_tensors = [x() for x in lazy_eval_tensors] + lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None] + readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors] self._apply_graph_options(lazy_eval_graph) lazy_eval_graph.options.graph_opt_level = self._graph_opt_level lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers]) @@ -383,8 +380,8 @@ class trace: lazy_eval_graph() for r, x in zip(readers, lazy_eval_tensors): # get values from lazy_eval_graph and assign to lazy_eval tensor - x()._handle = RawTensor(r.op.get_value())._handle - x()._reset_varnode() + x._handle = RawTensor(r.op.get_value())._handle + x._reset_varnode() @contextlib.contextmanager def _setup(self): @@ -454,13 +451,14 @@ class trace: raise TraceMismatchError("premature end") if not self._symbolic or not self._untraced: # reset output tensors - for x in self._active_tensors.values(): - if x() is not None: - x()._dev_tensor() - x()._reset_varnode() - x()._mixin_handle = -1 - x()._recording = False - x()._trace_mixin_info = None + for x in self._active_tensors.copy(): + strong_x = x() + if strong_x is not None: + strong_x._dev_tensor() + strong_x._reset_varnode() + strong_x._mixin_handle = -1 + strong_x._recording = False + strong_x._trace_mixin_info = None try: do_enter() @@ -482,15 +480,17 @@ class trace: if self._untraced: # conditionally reading a compiled tensor in excluded region # is permitted, so we have to assume every tensor might be read - for x in self._active_tensors.values(): - if x(): - info = self._tinfo[x()._mixin_handle] + for x in self._active_tensors: + strong_x = x() + if strong_x: + info = self._tinfo[strong_x._mixin_handle] info.exported = True info.data_read = True else: - for x in self._active_tensors.values(): - if x(): - x()._dev_tensor() + for x in self._active_tensors: + strong_x = x() + if strong_x: + strong_x._dev_tensor() def _apply_graph_options(self, graph): @@ -520,7 +520,6 @@ class trace: graph = self._graph = G.Graph() graph.options.async_exec_level = 0b100 self._apply_graph_options(graph) - # graph.options.graph_opt_level = 0 need_reset_nodes = self._need_reset_nodes = [] # links enforce ordering of I/O nodes in_out_links = () @@ -563,7 +562,7 @@ class trace: if not hasattr(info, "varnode"): assert info.external if info.bound_data: - if hasattr(info, "is_const") and info.is_const: + if getattr(info, "is_const", False): info.varnode = graph.make_const( info.bound_data.numpy(), info.bound_data.dtype, @@ -635,30 +634,12 @@ class trace: opnode.reset() def __call__(self, *args, **kwargs): - if is_tracing(): - return self.__wrapped__(*args, **kwargs) with self._setup(): if self._capture_as_const: self._process_inputs(*args, **kwargs) outputs = self.__wrapped__(*args, **kwargs) if self._capture_as_const: self._process_outputs(outputs) - - # outputs could be None - if outputs is not None: - list_outputs = outputs - if isinstance(outputs, collections.abc.Mapping): - _, list_outputs = zip(*sorted(outputs.items())) - elif not isinstance(outputs, collections.abc.Sequence): - list_outputs = (outputs,) - - for o in list_outputs: - # if outputs are copied, then use the newest info in trace data structure - if o._copied: - self._active_tensors[o._mixin_handle] = TensorWeakRef(o) - if self._untraced and self._symbolic: - self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o) - return outputs def dump( diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 80e37376b..cffdbc571 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -9,11 +9,12 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" + #include "./grad.h" #include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/backward_graph_opt.h" #include "megbrain/imperative/ops/autogen.h" -#include "megbrain/imperative/ops/utility.h" #include "megbrain/utils/mempool.h" #include "range/v3/all.hpp" @@ -434,7 +435,8 @@ apply_result_t apply_grad(ApplyContext& ctx) { if (backward.output_requires_grad(i)) { if (backward.output_captured(i)) { // avoid reference cycle [Tensor <-> GradFn] - outputs[i] = outputs[i]->copy(); + static std::shared_ptr op = std::shared_ptr(new FastpathCopy()); + outputs[i] = python::apply(op, outputs[i])[0]; } // populate grad info of output tensor auto& grad_info = outputs[i]->m_grad_info; diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h index 1780f311c..a3fb58e18 100644 --- a/imperative/python/src/grad.h +++ b/imperative/python/src/grad.h @@ -12,6 +12,7 @@ #pragma once #include "./tensor.h" +#include "megbrain/imperative/ops/utility.h" #include #include diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index ccc5a5078..5a54623ee 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -221,6 +221,21 @@ apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& ma return apply(ctx); } +apply_result_t fastpathcopy_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { + mgb_assert(ctx.nargs == 1); + maker.output_size(1).output_captured(0, false); + maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) { + mgb_assert(ngrads == 1); + Tensor* grad = grads[0]; + apply_result_t ret(1); + if (grad) { + ret[0] = grad->shared_from_this(); + } + return ret; + }); + return apply(ctx); +} + struct Init { Init() { auto& reg = grad_rule_registry(); @@ -231,6 +246,7 @@ struct Init { reg.emplace(Reduce::typeinfo(), reduce_grad_rule); reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule); reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule); + reg.emplace(FastpathCopy::typeinfo(), fastpathcopy_grad_rule); } } _; diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 15768f7ec..6369b4903 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -23,6 +23,7 @@ #include "./common.h" #include "./ops.h" #include "megbrain/gopt/inference.h" +#include "megbrain/imperative/ops/utility.h" namespace py = pybind11; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 5a3da4112..9ba144fc0 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -118,9 +118,18 @@ apply_result_t apply(ApplyContext& ctx) { handles[i] = ctx.args[i]->m_handle.get(); } + apply_result_t outputs; + + // fast copy without really applying + if (ctx.op->same_type()) { + mgb_assert(ctx.nargs == 1); + outputs.reserve(ctx.nargs); + outputs.emplace_back(std::make_shared(ctx.args[0]->m_handle)); + return outputs; + } + auto output_handles = interpreter_for_py->apply_op(ctx.op, handles); - apply_result_t outputs; outputs.reserve(output_handles.size()); for (auto h : output_handles) { outputs.emplace_back(std::make_shared(h)); @@ -303,11 +312,6 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) #undef REGISTE_TENSORWRAPPER_FUNC -PyObject* TensorWrapper::copied() { - return py::cast(m_tensor->m_trace_info.copied).release().ptr(); -} - - #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ PyObject* TensorWrapper::member() { \ if (m_tensor->m_trace_info.member) { \ @@ -841,7 +845,6 @@ void init_tensor(py::module m) { .def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::_use_cnt>("_use_cnt") .def_getset<&TensorWrapper::varnode>("_varnode") - .def_getset<&TensorWrapper::copied>("_copied") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle") .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording") .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index f2a365689..2b0d0be10 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -10,6 +10,7 @@ */ #pragma once +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" #include diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index 9571c5cda..8f597f1b1 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -35,7 +35,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { // assumption: python function always returns PyList auto tup = py::reinterpret_borrow(ret); - for (auto i = 0; i < tup.size(); i++) { + for (size_t i = 0; i < tup.size(); i++) { auto pitem = tup[i].cast(); outputs.emplace_back(std::make_shared(pitem)); } diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h index 11736ca7b..e394d344d 100644 --- a/imperative/python/src/trace_info.h +++ b/imperative/python/src/trace_info.h @@ -17,7 +17,6 @@ namespace mgb::imperative::python { struct TraceInfo { int64_t mixin_handle = -1; bool recording = false; - bool copied = false; // refer to CompiledTensorProxy in tracing.py, works from second trace step PyObject* compiled_info = nullptr; @@ -35,7 +34,6 @@ struct TraceInfo { compiled_info = that.compiled_info; Py_XINCREF(compiled_info); - copied = true; return *this; } diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index d6084e339..72e72abc8 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -18,4 +18,18 @@ namespace mgb::imperative { MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); +namespace { namespace fastpathcopy { + auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + return inputs; + } + +OP_TRAIT_REG(FastpathCopy,FastpathCopy) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // fastpathcopy + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy); + } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h index 609995a4e..ba85f3665 100644 --- a/imperative/src/include/megbrain/imperative/ops/utility.h +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -35,4 +35,18 @@ struct GenericPyOp final : OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; }; +struct FastpathCopy final : OpDefImplBase { + FastpathCopy() = default; + + size_t hash() const override { + return mgb::hash(this->dyn_typeinfo()); + } + + bool is_same_st(const Hashable& rhs) const override { + return this->dyn_typeinfo() == rhs.dyn_typeinfo(); + } + + MGB_DYN_TYPE_OBJ_FINAL_DECL; +}; + } // namespace mgb::imperative -- GitLab