提交 d94a17d3 编写于 作者: M Megvii Engine Team

refactor(mge/jit): using static global_enable for apply ctx insted of global variable

GitOrigin-RevId: dd82b53faf55aa0d01ab181b54f48d63e143384c
上级 9eacb9df
import weakref import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Iterable from typing import Callable, Iterable
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
......
...@@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): ...@@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
if skip_tracing: if skip_tracing:
args = [
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
for x in args
]
unset_tracing() unset_tracing()
ret = RawTensor(value, dtype, device, False, name) ret = RawTensor(value, dtype, device, False, name)
set_tracing() set_tracing()
......
...@@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) ...@@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC #undef REGISTE_APPLY_FUNC
bool is_tracing = false; Tensor::flags_t ApplyContext::global_disable = 0;
Tensor::flags_t ApplyContext::global_enable = 0;
#define SET_UNSET_PROP(mode) \
void set_##mode() { \
is_##mode = true; \
} \
void unset_##mode() { \
is_##mode = false; \
} \
SET_UNSET_PROP(tracing)
#undef SET_UNSET_PROP void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; }
void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; }
bool skip_tracing = false; bool skip_tracing = false;
Tensor::flags_t ApplyContext::global_disable = 0;
apply_result_t apply(ApplyContext& ctx) { apply_result_t apply(ApplyContext& ctx) {
// emulating scalar should be put to specific op's apply, e.g., // emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python // elementwise, reduce, typecvt. Currently it's still handled at python
// side. It could be move to C++ side if it has an impact on performance // side. It could be move to C++ side if it has an impact on performance
auto flags = ctx.flags & ~ApplyContext::global_disable; auto flags = ctx.flags & ~ApplyContext::global_disable;
flags = flags | ApplyContext::global_enable;
if (flags & Tensor::Flags::SCALAR) { if (flags & Tensor::Flags::SCALAR) {
// TODO: emulate scalar // TODO: emulate scalar
...@@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ...@@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
} }
} }
if (is_tracing) {
ctx.flags |= Tensor::Flags::TRACE;
}
auto outputs = apply(ctx); auto outputs = apply(ctx);
size_t nout = outputs.size(); size_t nout = outputs.size();
auto ret = py::tuple(nout); auto ret = py::tuple(nout);
...@@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>(); if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>();
// const op // const op
if (is_const && is_tracing) { if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) {
auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
if (!py_ret) throw py::error_already_set(); if (!py_ret) throw py::error_already_set();
auto py_list = py::reinterpret_steal<py::list>(py_ret); auto py_list = py::reinterpret_steal<py::list>(py_ret);
......
...@@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ...@@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
struct ApplyContext { struct ApplyContext {
static Tensor::flags_t global_disable; static Tensor::flags_t global_disable;
static Tensor::flags_t global_enable;
Tensor::flags_t flags; Tensor::flags_t flags = 0;
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
Tensor*const* args; Tensor*const* args;
size_t nargs; size_t nargs;
...@@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) { ...@@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) {
template <typename... Args> template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
ApplyContext ctx; ApplyContext ctx;
Tensor* arg_arr[] = {resolve_arrow(args)...}; Tensor* arg_arr[] = {resolve_arrow(args)...};
ctx.flags = (0 | ... | args->m_flags); ctx.flags = (0 | ... | args->m_flags);
ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
ctx.args = arg_arr; ctx.args = arg_arr;
ctx.nargs = sizeof...(args); ctx.nargs = sizeof...(args);
ctx.op = std::move(op); ctx.op = std::move(op);
...@@ -256,7 +254,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors) ...@@ -256,7 +254,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
apply_result_t> { apply_result_t> {
ApplyContext ctx; ApplyContext ctx;
ctx.op = std::move(op); ctx.op = std::move(op);
ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
ctx.nargs = tensors.size(); ctx.nargs = tensors.size();
Tensor* args[ctx.nargs]; Tensor* args[ctx.nargs];
ctx.args = args; ctx.args = args;
...@@ -270,7 +267,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors) ...@@ -270,7 +267,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx; ApplyContext ctx;
ctx.op = std::move(op); ctx.op = std::move(op);
ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
ctx.nargs = nargs; ctx.nargs = nargs;
ctx.args = args; ctx.args = args;
for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
......
...@@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) {
for (size_t i = 0; i < ctx.nargs; i++) { for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = py::cast(ctx.args[i]->m_var); args[i + 1] = py::cast(ctx.args[i]->m_var);
} }
py::object ret = py::reinterpret_steal<py::object>( py::object pyout = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!ret) throw py::error_already_set(); if (!pyout) throw py::error_already_set();
// assumption: python function always returns PyList // assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret); auto tup = py::reinterpret_borrow<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) { for (size_t i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode*>(); auto pitem = tup[i].cast<cg::VarNode*>();
outputs.emplace_back(std::make_shared<Tensor>(pitem)); outputs.emplace_back(std::make_shared<Tensor>(pitem));
...@@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
} }
auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr);
if (!pyout) throw py::error_already_set(); if (!pyout) throw py::error_already_set();
// assumption: python function always returns PyList // assumption: python function always returns PyList
auto tup = py::reinterpret_steal<py::list>(pyout); auto tup = py::reinterpret_steal<py::list>(pyout);
for (size_t i = 0; i < tup.size(); i++) { for (size_t i = 0; i < tup.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册