You need to sign in or sign up before continuing.
未验证 提交 7b1695af 编写于 作者: X xiongkun 提交者: GitHub

[Dy2static-Fallback] add set_eval_frame function in pybind. (#52006)

* [Dy2static-Fallback] add set_eval_frame function in pybind.
1. add set_eval_frame function in pybind.

* add unittest for eval frame hooker.

* [support py38]

* fix-GeneratorExit error in eval frame hooker

* support python == 3.9

* support 3.10

* fix some comments
上级 2d0c6948
......@@ -14,20 +14,234 @@ limitations under the License. */
#include "paddle/fluid/pybind/jit.h"
#include <Python.h>
#include <code.h>
#include <frameobject.h>
#include <object.h>
#include <pystate.h>
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/place.h"
#include "glog/logging.h"
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/utils/pybind.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
#define unlikely(x) __builtin_expect((x), 0)
// Use static variable to save customed eval hook.
static Py_tss_t eval_frame_callback_key = {0, 0};
inline static PyObject *eval_frame_callback_get(void) {
void *result = PyThread_tss_get(&eval_frame_callback_key);
if (unlikely(result == NULL)) {
Py_RETURN_NONE;
} else {
return reinterpret_cast<PyObject *>(result);
}
}
inline static void eval_frame_callback_set(PyObject *obj) {
PyThread_tss_set(&eval_frame_callback_key, obj);
}
// call python default eval frame to interpret current frame.
inline static PyObject *eval_frame_default(PyThreadState *tstate,
PyFrameObject *frame,
int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000
if (tstate == NULL) {
tstate = PyThreadState_GET();
}
return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
#else
return _PyEval_EvalFrameDefault(frame, throw_flag);
#endif
}
// Start a new frame and run code in this frame.
// Execute a piece of code by default frame-hook.
inline static PyObject *eval_custom_code(PyThreadState *tstate,
PyFrameObject *frame,
PyCodeObject *code,
int throw_flag) {
Py_ssize_t ncells = 0;
Py_ssize_t nfrees = 0;
Py_ssize_t nlocals_new = code->co_nlocals;
Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
if ((code->co_flags & CO_NOFREE) == 0) {
ncells = PyTuple_GET_SIZE(code->co_cellvars);
nfrees = PyTuple_GET_SIZE(code->co_freevars);
}
PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
if (shadow == NULL) {
return NULL;
}
PyObject **fastlocals_old = frame->f_localsplus;
PyObject **fastlocals_new = shadow->f_localsplus;
for (Py_ssize_t i = 0; i < nlocals_old; i++) {
Py_XINCREF(fastlocals_old[i]);
fastlocals_new[i] = fastlocals_old[i];
}
for (Py_ssize_t i = 0; i < ncells + nfrees; i++) {
Py_XINCREF(fastlocals_old[nlocals_old + i]);
fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
}
PyObject *result = eval_frame_default(tstate, shadow, throw_flag);
Py_DECREF(shadow);
return result;
}
static PyObject *_custom_eval_frame(PyThreadState *tstate,
PyFrameObject *frame,
int throw_flag,
PyObject *callback) {
// https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
// https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
if (PyFrame_FastToLocalsWithError(frame) < 0) {
return NULL;
}
// NOTE:(xiongkun): Handle GeneratorExit exception: (Spend a day)
// In Python, gen close is also a Python function call that will enter this
// function with GeneratorExit set, which will cause the PyObject_CallObject
// raise SystemError. So we disable the custom behavior for GeneratorExit. def
// func():
// iter = iter([1, 2, 3])
// for i in iter:
// return i # <--- Early return, cause a GeneratorExit thrown,
// # <--- which Cause the PyObject_CallObject raise
// SystemError.
if (PyErr_ExceptionMatches(PyExc_GeneratorExit)) {
return eval_frame_default(tstate, frame, throw_flag);
}
// We don't run the current custom_eval_frame behavior for guards.
// So we temporarily set the callback to Py_None to drive the correct behavior
// in the shim.
eval_frame_callback_set(Py_None);
PyObject *args = Py_BuildValue("(O)", frame);
PyObject *result = PyObject_CallObject(callback, args);
// result: GuardedCode
if (result == NULL) {
// internal exception
return NULL;
} else if (result != Py_None) {
// NOTE: Cache is not supported now
PyCodeObject *code = reinterpret_cast<PyCodeObject *>(
PyObject_GetAttrString(result, "code"));
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, code, throw_flag);
} else {
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_frame_default(tstate, frame, throw_flag);
}
}
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame,
int throw_flag) {
PyObject *callback = eval_frame_callback_get();
if (callback == Py_None) {
return eval_frame_default(tstate, frame, throw_flag);
}
return _custom_eval_frame(tstate, frame, throw_flag, callback);
}
#if PY_VERSION_HEX >= 0x03090000
static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame,
int throw_flag) {
return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
#else
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) {
PyThreadState *tstate = PyThreadState_GET();
return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
#endif
static PyObject *set_eval_frame(PyObject *new_callback, PyThreadState *tstate) {
// Change the eval frame callback and return the old one
// - None: disables: diable custom callback.
// - Python callable(): enables custom callback.
// NOTE: Cache is not supported now
PyObject *old_callback = eval_frame_callback_get();
#if PY_VERSION_HEX >= 0x03090000
auto *old_eval_frame = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
#else
// Function pointer.
_PyFrameEvalFunction old_eval_frame = tstate->interp->eval_frame;
#endif
// NOTE: multi-threading is not supported now
if (old_callback != Py_None && new_callback == Py_None) {
if (old_eval_frame != &_PyEval_EvalFrameDefault) {
VLOG(7) << "set _PyEval_EvalFrameDefault";
#if PY_VERSION_HEX >= 0x03090000
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
&_PyEval_EvalFrameDefault);
#else
tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
}
} else if (old_callback == Py_None && new_callback != Py_None) {
if (old_eval_frame != &custom_eval_frame_shim) {
VLOG(7) << "set custom_eval_frame_shim";
#if PY_VERSION_HEX >= 0x03090000
_PyInterpreterState_SetEvalFrameFunc(tstate->interp,
&custom_eval_frame_shim);
#else
tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
}
}
Py_INCREF(new_callback);
eval_frame_callback_set(new_callback);
return old_callback;
}
static PyObject *set_eval_frame_py(PyObject *callback) {
if (callback != Py_None && !PyCallable_Check(callback)) {
VLOG(7) << "callback is not a callable or none, invalid arguments.";
RETURN_PY_NONE
}
return set_eval_frame(callback, PyThreadState_GET());
}
PyMODINIT_FUNC PyInit__eval_frame(void) {
int result = PyThread_tss_create(&eval_frame_callback_key);
VLOG(7) << "Set PyThread_tss_create return: " << result;
Py_INCREF(Py_None);
eval_frame_callback_set(Py_None);
return NULL;
}
PyTypeObject *g_jit_function_pytype = nullptr;
using Variable = paddle::framework::Variable;
......@@ -58,5 +272,18 @@ void BindJit(pybind11::module *m) {
});
}
void BindEvalFrame(pybind11::module *m) {
PyInit__eval_frame();
m->def(
"set_eval_frame",
[](const py::object &py_func) {
VLOG(5) << "start call set_eval_frame_py.";
auto ret = set_eval_frame_py(py_func.ptr());
auto obj = py::reinterpret_borrow<py::object>(ret);
return obj;
},
py::arg("callback"));
}
} // namespace pybind
} // namespace paddle
......@@ -22,10 +22,102 @@ limitations under the License. */
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
// see https://bugs.python.org/issue35886
// If py_version==3.8.*, we need to redefine _PyEvalFrameFunc and the
// related functions and structs.
#if PY_VERSION_HEX >= 0x03080000 && PY_VERSION_HEX < 0x3090000
typedef PyObject *(*_PyFrameEvalFunction)(struct _frame *, int);
struct _warnings_runtime_state {
/* Both 'filters' and 'onceregistry' can be set in warnings.py;
get_warnings_attr() will reset these variables accordingly. */
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
long filters_version; // NOLINT
};
struct _is {
struct _is *next;
struct _ts *tstate_head;
int64_t id;
int64_t id_refcount;
int requires_idref;
PyThread_type_lock id_mutex;
int finalizing;
PyObject *modules;
PyObject *modules_by_index;
PyObject *sysdict;
PyObject *builtins;
PyObject *importlib;
/* Used in Python/sysmodule.c. */
int check_interval;
/* Used in Modules/_threadmodule.c. */
long num_threads; // NOLINT
/* Support for runtime thread stack size tuning.
A value of 0 means using the platform's default stack size
or the size specified by the THREAD_STACK_SIZE macro. */
/* Used in Python/thread.c. */
size_t pythread_stacksize;
PyObject *codec_search_path;
PyObject *codec_search_cache;
PyObject *codec_error_registry;
int codecs_initialized;
/* fs_codec.encoding is initialized to NULL.
Later, it is set to a non-NULL string by _PyUnicode_InitEncodings(). */
struct {
char *encoding; /* Filesystem encoding (encoded to UTF-8) */
char *errors; /* Filesystem errors (encoded to UTF-8) */
_Py_error_handler error_handler;
} fs_codec;
PyConfig config;
#ifdef HAVE_DLOPEN
int dlopenflags;
#endif
PyObject *dict; /* Stores per-interpreter state */
PyObject *builtins_copy;
PyObject *import_func;
/* Initialized to PyEval_EvalFrameDefault(). */
_PyFrameEvalFunction eval_frame;
Py_ssize_t co_extra_user_count;
freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];
#ifdef HAVE_FORK
PyObject *before_forkers;
PyObject *after_forkers_parent;
PyObject *after_forkers_child;
#endif
/* AtExit module */
void (*pyexitfunc)(PyObject *);
PyObject *pyexitmodule;
uint64_t tstate_next_unique_id;
struct _warnings_runtime_state warnings;
PyObject *audit_hooks;
};
#endif
namespace paddle {
namespace pybind {
void BindJit(pybind11::module* m);
void BindJit(pybind11::module *m);
void BindEvalFrame(pybind11::module *m);
} // namespace pybind
} // namespace paddle
......@@ -676,6 +676,7 @@ PYBIND11_MODULE(libpaddle, m) {
BindCudaStream(&m);
BindXpuStream(&m);
BindJit(&m);
BindEvalFrame(&m);
BindCustomDevicePy(&m);
// Not used, just make sure cpu_info.cc is linked.
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import unittest
import paddle
class TestEvalFrame(unittest.TestCase):
def setUp(self):
self.x = paddle.to_tensor(2).astype('int')
def tearDown(self):
pass
def test_eval_frame(self):
CustomCode = collections.namedtuple("CustomCode", ["code"])
def mul(a, b):
return a * b
code = CustomCode(mul.__code__)
def callback(frame_obj):
# Do your callback function here and return a object with `.code`
if frame_obj.f_code.co_name == "add":
return code
return CustomCode(code=frame_obj.f_code) # do nothing.
def add(a, b):
return a + b
x = 1
y = 2
paddle.fluid.core.set_eval_frame(callback)
assert add(x, y) == 2, "should be 2"
paddle.fluid.core.set_eval_frame(None)
assert add(x, y) == 3, "should be 3"
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册