提交 d1c8cd26 编写于 作者: X xiongkun

function

上级 7aabdfd9
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <Python.h> #include <Python.h>
#include <code.h> #include <code.h>
#include <frameobject.h> #include <frameobject.h>
#include <internal/pycore_frame.h>
#include <object.h> #include <object.h>
#include <pystate.h> #include <pystate.h>
...@@ -36,6 +37,12 @@ namespace py = pybind11; ...@@ -36,6 +37,12 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
#if PY_VERSION_HEX >= 0x030b0000
typedef _PyInterpreterFrame FrameObject;
#else
typedef PyFrameObject FrameObject;
#endif
#define unlikely(x) __builtin_expect((x), 0) #define unlikely(x) __builtin_expect((x), 0)
// Use static variable to save customed eval hook. // Use static variable to save customed eval hook.
...@@ -56,7 +63,7 @@ inline static void eval_frame_callback_set(PyObject *obj) { ...@@ -56,7 +63,7 @@ inline static void eval_frame_callback_set(PyObject *obj) {
// call python default eval frame to interpret current frame. // call python default eval frame to interpret current frame.
inline static PyObject *eval_frame_default(PyThreadState *tstate, inline static PyObject *eval_frame_default(PyThreadState *tstate,
PyFrameObject *frame, FrameObject *frame,
int throw_flag) { int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000 #if PY_VERSION_HEX >= 0x03090000
if (tstate == NULL) { if (tstate == NULL) {
...@@ -71,7 +78,7 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate, ...@@ -71,7 +78,7 @@ inline static PyObject *eval_frame_default(PyThreadState *tstate,
// Start a new frame and run code in this frame. // Start a new frame and run code in this frame.
// Execute a piece of code by default frame-hook. // Execute a piece of code by default frame-hook.
inline static PyObject *eval_custom_code(PyThreadState *tstate, inline static PyObject *eval_custom_code(PyThreadState *tstate,
PyFrameObject *frame, FrameObject *frame,
PyCodeObject *code, PyCodeObject *code,
int throw_flag) { int throw_flag) {
Py_ssize_t ncells = 0; Py_ssize_t ncells = 0;
...@@ -79,18 +86,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate, ...@@ -79,18 +86,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
Py_ssize_t nlocals_new = code->co_nlocals; Py_ssize_t nlocals_new = code->co_nlocals;
Py_ssize_t nlocals_old = frame->f_code->co_nlocals; Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
if ((code->co_flags & CO_NOFREE) == 0) { #if PY_VERSION_HEX >= 0x030b0000
ncells = PyTuple_GET_SIZE(code->co_cellvars); ncells = code->co_ncellvars;
nfrees = PyTuple_GET_SIZE(code->co_freevars); nfrees = code->co_nfreevars;
} #else
ncells = PyTuple_GET_SIZE(code->co_cellvars);
nfrees = PyTuple_GET_SIZE(code->co_freevars);
#endif
PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL); PyFrameObject *shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
if (shadow == NULL) { if (shadow == NULL) {
return NULL; return NULL;
} }
#if PY_VERSION_HEX >= 0x030b0000
PyObject **fastlocals_old = frame->localsplus;
PyObject **fastlocals_new = shadow->f_frame->localsplus;
#else
PyObject **fastlocals_old = frame->f_localsplus; PyObject **fastlocals_old = frame->f_localsplus;
PyObject **fastlocals_new = shadow->f_localsplus; PyObject **fastlocals_new = shadow->f_localsplus;
#endif
for (Py_ssize_t i = 0; i < nlocals_old; i++) { for (Py_ssize_t i = 0; i < nlocals_old; i++) {
Py_XINCREF(fastlocals_old[i]); Py_XINCREF(fastlocals_old[i]);
...@@ -102,18 +117,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate, ...@@ -102,18 +117,26 @@ inline static PyObject *eval_custom_code(PyThreadState *tstate,
fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i]; fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
} }
#if PY_VERSION_HEX >= 0x030b0000
PyObject *result = eval_frame_default(tstate, shadow->f_frame, throw_flag);
#else
PyObject *result = eval_frame_default(tstate, shadow, throw_flag); PyObject *result = eval_frame_default(tstate, shadow, throw_flag);
#endif
Py_DECREF(shadow); Py_DECREF(shadow);
return result; return result;
} }
static PyObject *_custom_eval_frame(PyThreadState *tstate, static PyObject *_custom_eval_frame(PyThreadState *tstate,
PyFrameObject *frame, FrameObject *frame,
int throw_flag, int throw_flag,
PyObject *callback) { PyObject *callback) {
// https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details // https://peps.python.org/pep-0558/#fast-locals-proxy-implementation-details
// https://devguide.python.org/internals/interpreter/#all-sorts-of-variables // https://devguide.python.org/internals/interpreter/#all-sorts-of-variables
#if PY_VERSION_HEX >= 0x030b0000
if (PyFrame_FastToLocalsWithError(frame->frame_obj) < 0) {
#else
if (PyFrame_FastToLocalsWithError(frame) < 0) { if (PyFrame_FastToLocalsWithError(frame) < 0) {
#endif
return NULL; return NULL;
} }
...@@ -167,7 +190,7 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, ...@@ -167,7 +190,7 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate,
} }
static PyObject *_custom_eval_frame_shim(PyThreadState *tstate, static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame, FrameObject *frame,
int throw_flag) { int throw_flag) {
PyObject *callback = eval_frame_callback_get(); PyObject *callback = eval_frame_callback_get();
...@@ -180,12 +203,12 @@ static PyObject *_custom_eval_frame_shim(PyThreadState *tstate, ...@@ -180,12 +203,12 @@ static PyObject *_custom_eval_frame_shim(PyThreadState *tstate,
#if PY_VERSION_HEX >= 0x03090000 #if PY_VERSION_HEX >= 0x03090000
static PyObject *custom_eval_frame_shim(PyThreadState *tstate, static PyObject *custom_eval_frame_shim(PyThreadState *tstate,
PyFrameObject *frame, FrameObject *frame,
int throw_flag) { int throw_flag) {
return _custom_eval_frame_shim(tstate, frame, throw_flag); return _custom_eval_frame_shim(tstate, frame, throw_flag);
} }
#else #else
static PyObject *custom_eval_frame_shim(PyFrameObject *frame, int throw_flag) { static PyObject *custom_eval_frame_shim(FrameObject *frame, int throw_flag) {
PyThreadState *tstate = PyThreadState_GET(); PyThreadState *tstate = PyThreadState_GET();
return _custom_eval_frame_shim(tstate, frame, throw_flag); return _custom_eval_frame_shim(tstate, frame, throw_flag);
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import collections import collections
import unittest import unittest
from sys import version_info
import paddle import paddle
...@@ -27,6 +28,12 @@ class TestEvalFrame(unittest.TestCase): ...@@ -27,6 +28,12 @@ class TestEvalFrame(unittest.TestCase):
pass pass
def test_eval_frame(self): def test_eval_frame(self):
if version_info.major != 3 or (
version_info.minor <= 8 or version_info.minor >= 11
):
print("skip test_eval_frame, current only support 3.8 - 3.10")
return
CustomCode = collections.namedtuple( CustomCode = collections.namedtuple(
"CustomCode", ["code", "disable_eval_frame"] "CustomCode", ["code", "disable_eval_frame"]
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册