From d94a17d3d180278f08c0b23301fc663cd43edd36 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 13 May 2021 12:53:01 +0800 Subject: [PATCH] refactor(mge/jit): using static global_enable for apply ctx insted of global variable GitOrigin-RevId: dd82b53faf55aa0d01ab181b54f48d63e143384c --- .../python/megengine/autodiff/grad_manager.py | 2 -- imperative/python/megengine/jit/tracing.py | 4 --- imperative/python/src/tensor.cpp | 25 +++++-------------- imperative/python/src/tensor.h | 8 ++---- imperative/python/src/trace.cpp | 7 +++--- 5 files changed, 12 insertions(+), 34 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 72c689fd4..d2d73f3dd 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -1,6 +1,4 @@ import weakref -from collections import defaultdict -from contextlib import contextmanager from typing import Callable, Iterable from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 53f91c60c..a8c222f89 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): if skip_tracing: - args = [ - RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x - for x in args - ] unset_tracing() ret = RawTensor(value, dtype, device, False, name) set_tracing() diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 1dfddc0fb..8d44ae05c 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) #undef REGISTE_APPLY_FUNC -bool is_tracing = false; - -#define SET_UNSET_PROP(mode) \ - void set_##mode() { \ - is_##mode = true; \ - } \ - void unset_##mode() { \ - is_##mode = false; \ - } \ - -SET_UNSET_PROP(tracing) +Tensor::flags_t ApplyContext::global_disable = 0; +Tensor::flags_t ApplyContext::global_enable = 0; -#undef SET_UNSET_PROP +void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; } +void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; } bool skip_tracing = false; -Tensor::flags_t ApplyContext::global_disable = 0; - apply_result_t apply(ApplyContext& ctx) { // emulating scalar should be put to specific op's apply, e.g., // elementwise, reduce, typecvt. Currently it's still handled at python // side. It could be move to C++ side if it has an impact on performance auto flags = ctx.flags & ~ApplyContext::global_disable; + flags = flags | ApplyContext::global_enable; if (flags & Tensor::Flags::SCALAR) { // TODO: emulate scalar @@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje } } - if (is_tracing) { - ctx.flags |= Tensor::Flags::TRACE; - } - auto outputs = apply(ctx); size_t nout = outputs.size(); auto ret = py::tuple(nout); @@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast(); // const op - if (is_const && is_tracing) { + if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) { auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); if (!py_ret) throw py::error_already_set(); auto py_list = py::reinterpret_steal(py_ret); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index fe9541155..89b505752 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje struct ApplyContext { static Tensor::flags_t global_disable; + static Tensor::flags_t global_enable; - Tensor::flags_t flags; + Tensor::flags_t flags = 0; std::shared_ptr op; Tensor*const* args; size_t nargs; @@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) { template constexpr bool is_all_tensor_ptr = (... && std::is_same_v())), Tensor*>); -extern bool is_tracing; // FIXME: should use ApplyContext::global_enable - template , int> = 0> apply_result_t apply(std::shared_ptr op, Args&&... args) { ApplyContext ctx; Tensor* arg_arr[] = {resolve_arrow(args)...}; ctx.flags = (0 | ... | args->m_flags); - ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0; ctx.args = arg_arr; ctx.nargs = sizeof...(args); ctx.op = std::move(op); @@ -256,7 +254,6 @@ auto apply(std::shared_ptr op, T&& tensors) apply_result_t> { ApplyContext ctx; ctx.op = std::move(op); - ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; ctx.nargs = tensors.size(); Tensor* args[ctx.nargs]; ctx.args = args; @@ -270,7 +267,6 @@ auto apply(std::shared_ptr op, T&& tensors) inline auto apply(std::shared_ptr op, Tensor*const* args, size_t nargs) { ApplyContext ctx; ctx.op = std::move(op); - ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; ctx.nargs = nargs; ctx.args = args; for (size_t i = 0; i < nargs; ++i) { diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index 853a498fe..30ddb78bd 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) { for (size_t i = 0; i < ctx.nargs; i++) { args[i + 1] = py::cast(ctx.args[i]->m_var); } - py::object ret = py::reinterpret_steal( + py::object pyout = py::reinterpret_steal( PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); - if (!ret) throw py::error_already_set(); + if (!pyout) throw py::error_already_set(); // assumption: python function always returns PyList - auto tup = py::reinterpret_borrow(ret); + auto tup = py::reinterpret_borrow(pyout); for (size_t i = 0; i < tup.size(); i++) { auto pitem = tup[i].cast(); outputs.emplace_back(std::make_shared(pitem)); @@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { } auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); if (!pyout) throw py::error_already_set(); + // assumption: python function always returns PyList auto tup = py::reinterpret_steal(pyout); for (size_t i = 0; i < tup.size(); i++) { -- GitLab