提交 d1c8cd26 编写于 作者: X xiongkun

function

上级 7aabdfd9
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <Python.h>
#include <code.h>
#include <frameobject.h>
#include <internal/pycore_frame.h>
#include <object.h>
#include <pystate.h>
......@@ -36,6 +37,12 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
#if PY_VERSION_HEX >= 0x030b0000
typedef _PyInterpreterFrame FrameObject;
#else
typedef PyFrameObject FrameObject;
#endif
#define unlikely(x) __builtin_expect((x), 0)
// Use static variable to save customed eval hook.
......@@ -56,7 +63,7 @@ inline static void eval_frame_callback_set(PyObject *obj) {
// call python default eval frame to interpret current frame.
inline static PyObject *eval_frame_default(PyThreadState *tstate,
PyFrameObject *frame,
FrameObject *frame,
int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000
if (tstate == NULL) {
......@@ -71,7 +78,7 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate,
// Start a new frame and run code in this frame.
// Execute a piece of code by default frame-hook.
inline static PyObject *eval_custom_code(PyThreadState *tstate,
PyFrameObject *frame,
FrameObject *frame,
PyCodeObject *code,
int throw_flag) {
Py_ssize_t ncells = 0;
......@@ -79,18 +86,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
Py_ssize_t nlocals_new = code->co_nlocals;
Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
if ((code->co_flags & CO_NOFREE) == 0) {
ncells = PyTuple_GET_SIZE(code->co_cellvars);
nfrees = PyTuple_GET_SIZE(code->co_freevars);
}
#if PY_VERSION_HEX >= 0x030b0000
ncells = code->co_ncellvars;
nfrees = code->co_nfreevars;
#else
ncells = PyTuple_GET_SIZE(code->co_cellvars);
nfrees = PyTuple_GET_SIZE(code->co_freevars);
#endif
PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
if (shadow == NULL) {
return NULL;
}
#if PY_VERSION_HEX >= 0x030b0000
PyObject **fastlocals_old = frame->localsplus;
PyObject **fastlocals_new = shadow->f_frame->localsplus;
#else
PyObject **fastlocals_old = frame->f_localsplus;
PyObject **fastlocals_new = shadow->f_localsplus;
#endif
for (Py_ssize_t i = 0; i < nlocals_old; i++) {
Py_XINCREF(fastlocals_old[i]);
......@@ -102,18 +117,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
}
#if PY_VERSION_HEX >= 0x030b0000
PyObject *result = eval_frame_default(tstate, shadow->f_frame, throw_flag);
#else
PyObject *result = eval_frame_default(tstate, shadow, throw_flag);
#endif
Py_DECREF(shadow);
return result;
}
static PyObject *_custom_eval_frame(PyThreadState *tstate,
PyFrameObject *frame,
FrameObject *frame,
int throw_flag,
PyObject *callback) {
// https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
// https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
// https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
// https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
#if PY_VERSION_HEX >= 0x030b0000
if (PyFrame_FastToLocalsWithError(frame->frame_obj) < 0) {
#else
if (PyFrame_FastToLocalsWithError(frame) < 0) {
#endif
return NULL;
}
......@@ -167,7 +190,7 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
}
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame,
FrameObject *frame,
int throw_flag) {
PyObject *callback = eval_frame_callback_get();
......@@ -180,12 +203,12 @@ static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
#if PY_VERSION_HEX >= 0x03090000
static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame,
FrameObject *frame,
int throw_flag) {
return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
#else
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) {
static PyObject *custom_eval_frame_shim(FrameObject *frame, int throw_flag) {
PyThreadState *tstate = PyThreadState_GET();
return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
......
......@@ -15,6 +15,7 @@
import collections
import unittest
from sys import version_info
import paddle
......@@ -27,6 +28,12 @@ class TestEvalFrame(unittest.TestCase):
pass
def test_eval_frame(self):
if version_info.major != 3 or (
version_info.minor <= 8 or version_info.minor >= 11
):
print("skip test_eval_frame, current only support 3.8 - 3.10")
return
CustomCode = collections.namedtuple(
"CustomCode", ["code", "disable_eval_frame"]
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册