Commit 9eacb9df authored by Megvii Engine Team

refactor(mge/jit): remove is_compiled flag in cpp tensor

GitOrigin-RevId: 15f90af735e6e5e53d2ac823bd4b9382d8144144
Parent 4e80cf52
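For context, a minimal sketch (plain Python, not MegEngine code) of what the refactor amounts to: instead of mirroring trace state in a global is_compiled flag that must be set and unset around compilation, the tracing entry points now check whether the active trace already owns a compiled graph (active_trace._graph in the diff below) and delegate to the compiled-mode apply path when it does. The names _Trace, apply_with_flag, and the returned strings are illustrative stand-ins only.

# Sketch under the above assumptions; not the MegEngine implementation.
class _Trace:
    def __init__(self):
        self._graph = None  # set once the trace has been compiled

    def compile(self):
        self._graph = object()  # stand-in for a compiled graph handle


# Before: a separate global flag had to be kept in sync via set_compiled()/unset_compiled().
is_compiled = False

def apply_with_flag(op, *args):
    if is_compiled:
        return "compiled-mode apply of " + op
    return "tracing-mode apply of " + op


# After: "compiled" is derived from whether the active trace already owns a graph,
# so there is no extra global state to maintain.
active_trace = _Trace()

def apply_with_tracing(op, *args):
    if active_trace._graph:
        return "compiled-mode apply of " + op
    return "tracing-mode apply of " + op


if __name__ == "__main__":
    print(apply_with_tracing("add"))   # tracing-mode apply of add
    active_trace.compile()
    print(apply_with_tracing("add"))   # compiled-mode apply of add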
@@ -7,15 +7,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..core._imperative_rt.core2 import (
-    set_cpp_apply_compiled_mode,
-    set_cpp_apply_const_compiled_mode,
     set_cpp_apply_const_with_tracing,
     set_cpp_apply_with_tracing,
 )
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import (
-    apply_compiled_mode,
-    apply_const_compiled_mode,
     apply_const_with_tracing,
     apply_with_tracing,
     exclude_from_trace,
@@ -24,5 +20,3 @@ from .tracing import (
 set_cpp_apply_with_tracing(apply_with_tracing)
 set_cpp_apply_const_with_tracing(apply_const_with_tracing)
-set_cpp_apply_compiled_mode(apply_compiled_mode)
-set_cpp_apply_const_compiled_mode(apply_const_compiled_mode)
@@ -12,20 +12,16 @@ import functools
 import itertools
 import json
 import os
-import typing
-import weakref

 import numpy as np

-from ..core._imperative_rt import GraphProfiler, common
+from ..core._imperative_rt import GraphProfiler
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     TensorWeakRef,
     apply,
-    set_compiled,
     set_tracing,
     skip_tracing,
-    unset_compiled,
     unset_tracing,
 )
 from ..core._imperative_rt.ops import (
@@ -394,7 +390,6 @@ class trace:
         if self._untraced:
             self._init_trace(self._symbolic)
         else:
-            set_compiled()
             if self._graph is None:
                 self._compile()
             self._graph.execute()
@@ -442,7 +437,6 @@ class trace:
             self._tensor_remaps = None
             self._set_active(False)
             set_symbolic_shape(self._save_symbolic_shape)
-            unset_compiled()
             unset_tracing()

         def do_exit():
@@ -989,11 +983,6 @@ class trace:
             raise RuntimeError("trace is not set with profiling=True")
         return json.loads(self._profiler.get())

-    def __del__(self):
-        for x in self._tinfo:
-            if getattr(x, "bound_data", None):
-                x.bound_data = None
-
     def trace(self, *args, **kwargs):
         raise NotImplementedError(
             "trace is deemed unbeneficial with the new "
@@ -1148,6 +1137,9 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):


 def apply_with_tracing(op: OpDef, *args: RawTensor):
+    if active_trace._graph:
+        # if member _graph exits, then is_compiled
+        return apply_compiled_mode(op, *args)
     if hasattr(op, "scope"):
         op.scope = AutoNaming.get_scope()
     if active_trace._symbolic:
@@ -1162,11 +1154,16 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):


 def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
+    if active_trace._graph:
+        return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
     if active_trace._symbolic:
         outputs = apply_const_symbolic_mode(value, dtype, device, name)
     else:
         unset_tracing()
-        outputs = (RawTensor(value, dtype, device, False, name),)
+        outputs = RawTensor(value, dtype, device, False, name)
+        if np.array(value).ndim == 0:
+            setscalar(outputs)
+        outputs = (outputs,)
         set_tracing()
     active_trace._record_const(outputs)
     return list(outputs)
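A minimal sketch, not MegEngine code, of the constant handling that the last hunk introduces: the tensor is built first, tagged as a scalar when the Python value is 0-dimensional, and only then wrapped in a one-element tuple. FakeTensor and the local setscalar are stand-ins for RawTensor and core2.setscalar.

import numpy as np

class FakeTensor:
    def __init__(self, value):
        self.data = np.asarray(value)
        self.is_scalar = False

def setscalar(t):
    # stand-in for the real setscalar, which tags a tensor as a scalar
    t.is_scalar = True

def make_const(value):
    out = FakeTensor(value)
    if np.array(value).ndim == 0:   # same 0-dim check the patch adds
        setscalar(out)
    return (out,)

print(make_const(3.14)[0].is_scalar)       # True: a Python float traces as a scalar
print(make_const([1, 2, 3])[0].is_scalar)  # False: a 1-d constant stays a tensor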
@@ -23,7 +23,6 @@
 #include "./common.h"
 #include "./ops.h"
 #include "megbrain/gopt/inference.h"
-#include "megbrain/imperative/ops/utility.h"

 namespace py = pybind11;
...
@@ -36,27 +36,21 @@ namespace mgb::imperative::python {

 interpreter::Interpreter::Channel* interpreter_for_py;

-PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
-        *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
+PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
 PyObject *cpp_apply_backward_varnode;

 #define REGISTE_APPLY_FUNC(mode) \
     void set_##mode(py::object pyf) { \
         mode = pyf.ptr(); \
     }

 REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
 REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
-REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
-REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
 REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)

 #undef REGISTE_APPLY_FUNC

 bool is_tracing = false;
-bool is_compiled = false;

 #define SET_UNSET_PROP(mode) \
     void set_##mode() { \
@@ -67,7 +61,6 @@ bool is_compiled = false;
     } \

 SET_UNSET_PROP(tracing)
-SET_UNSET_PROP(compiled)

 #undef SET_UNSET_PROP

@@ -263,14 +256,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
     // const op
     if (is_const && is_tracing) {
-        PyObject *pyf;
-        if (is_compiled) {
-            pyf = cpp_apply_const_compiled_mode;
-        } else {
-            pyf = cpp_apply_const_with_tracing;
-        }
-
-        auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr);
+        auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
         if (!py_ret) throw py::error_already_set();
         auto py_list = py::reinterpret_steal<py::list>(py_ret);
         if (auto* t = try_cast(py_list[0].ptr())) {
@@ -961,8 +947,6 @@ void init_tensor(py::module m) {
     m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
     m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
-    m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
-    m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
     m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);

     m.attr("skip_tracing") = &skip_tracing;
@@ -979,8 +963,6 @@ void init_tensor(py::module m) {
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
-    m.def("set_compiled", &set_compiled);
-    m.def("unset_compiled", &unset_compiled);
 }

 #undef MGE_PY_INTERFACE
...
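A minimal sketch, in plain Python with hypothetical names, of the registration pattern the C++ side implements above with REGISTE_APPLY_FUNC and the m.def bindings: the extension keeps one slot per callback, the Python package fills it at import time via set_cpp_apply_with_tracing, and the dispatcher later calls whatever was stored.

_callbacks = {"apply_with_tracing": None}

def set_cpp_apply_with_tracing(fn):
    # analogue of m.def("set_cpp_apply_with_tracing", ...): store the Python callback
    _callbacks["apply_with_tracing"] = fn

def dispatch(op, *args):
    # analogue of PyObject_Call(cpp_apply_with_tracing, ...): invoke the stored callback
    return _callbacks["apply_with_tracing"](op, *args)

set_cpp_apply_with_tracing(lambda op, *args: ["traced " + op])
print(dispatch("add"))  # ['traced add']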
@@ -237,7 +237,6 @@ template <typename... Args>
 constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

 extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
-extern bool is_compiled;

 template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
 apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
@@ -282,7 +281,7 @@ inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
 void init_tensor(pybind11::module);

-extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
+extern PyObject *cpp_apply_with_tracing;
 extern PyObject *cpp_apply_backward_varnode;

 } // namespace mgb::imperative::python
...
@@ -22,7 +22,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
     apply_result_t outputs;

     if (ctx.backward) {
-        // reach here when compiled=True
         // call megbrain_graph.py apply(BackwardGraph, *args)
         auto args = py::tuple(ctx.nargs + 1);
         args[0] = py::cast(ctx.op);
@@ -42,27 +41,16 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         return outputs;
     }

-    PyObject* pyf;
-    if (is_compiled) {
-        // run apply in compiled mode, step 2, 3, etc
-        pyf = cpp_apply_compiled_mode;
-    } else {
-        // run first step, both symbolic and non symbolic
-        pyf = cpp_apply_with_tracing;
-    }
-
     auto args = py::tuple(ctx.nargs + 1);
     args[0] = py::cast(ctx.op);
     for (size_t i = 0; i < ctx.nargs; i++) {
         args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
     }

-    auto pyout = PyObject_Call(pyf, args.ptr(), nullptr);
+    auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
     if (!pyout) throw py::error_already_set();
-    auto ret = py::reinterpret_steal<py::object>(pyout);

     // assumption: python function always returns PyList
-    auto tup = py::reinterpret_borrow<py::list>(ret);
-    for (auto i = 0; i < tup.size(); i++) {
+    auto tup = py::reinterpret_steal<py::list>(pyout);
+    for (size_t i = 0; i < tup.size(); i++) {
         auto tw = TensorWrapper::try_cast(tup[i].ptr());
         outputs.emplace_back(tw->m_tensor);
     }
...
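A minimal sketch, not MegEngine code, of the contract the hunk above relies on (the "python function always returns PyList" comment): whichever callback is registered as cpp_apply_with_tracing must return a list with one entry per output tensor, so apply_trace can iterate it directly. The function below and its behavior are illustrative only.

def my_apply_with_tracing(op, *inputs):
    # record (op, inputs) into the active trace here, then hand back the outputs
    outputs = [x for x in inputs]  # placeholder: pretend the op is identity
    return list(outputs)           # always a list, never a tuple or a bare tensor

print(my_apply_with_tracing("Identity", "tensor0"))  # ['tensor0']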