未验证 提交 3cc11a3d 编写于 作者: W WeiXin 提交者: GitHub

pylayer_op:release context after compute. (#32707)

上级 00268194
......@@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
}
}
py::object PyLayerApply(const platform::Place& place, const py::object& cls,
py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
const py::args args, const py::kwargs kwargs) {
py::gil_scoped_acquire guard;
auto bk_function = cls.attr("_backward_function");
auto context = bk_function();
auto forward = cls.attr("forward");
auto result_forward = forward(context, *args, **kwargs);
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.release().ptr());
std::make_shared<operators::PyLayerContext>(context.ptr());
// make inputs to varbase
std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
// process args,`input_vars` only collect `imperative::VarBase`
......
......@@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &op_ = ctx.GetOp();
auto pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
if (pylayer_op) {
auto py_layer_context = pylayer_op->GetPyLayerContext();
auto const_pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
if (const_pylayer_op) {
auto pylayer_op = const_cast<PyLayerOp *>(const_pylayer_op);
// Release contex after executing the compute
auto py_layer_context = pylayer_op->ReleasePyLayerContext();
py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true);
auto &input_vars = ctx.MultiInputVar("X");
auto output_vars = ctx.MultiOutputVar("Out");
......
......@@ -34,6 +34,10 @@ class PyLayerContext {
PyLayerContext() = delete;
PyObject* GetMutableCtx() { return context_; }
~PyLayerContext() {
py::gil_scoped_acquire guard;
Py_XDECREF(context_);
}
private:
PyObject* context_;
......@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel {
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context;
}
const std::shared_ptr<PyLayerContext>& GetPyLayerContext() const {
return py_context_;
std::shared_ptr<PyLayerContext> ReleasePyLayerContext() {
auto temp = py_context_;
py_context_.reset();
VLOG(3) << "`py_context_` in the PyLayerOp is released.";
return temp;
}
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册