未验证 提交 3cc11a3d 编写于 作者: W WeiXin 提交者: GitHub

pylayer_op:release context after compute. (#32707)

上级 00268194
...@@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode( ...@@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
} }
} }
py::object PyLayerApply(const platform::Place& place, const py::object& cls, py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
const py::args args, const py::kwargs kwargs) { const py::args args, const py::kwargs kwargs) {
py::gil_scoped_acquire guard;
auto bk_function = cls.attr("_backward_function"); auto bk_function = cls.attr("_backward_function");
auto context = bk_function(); auto context = bk_function();
auto forward = cls.attr("forward"); auto forward = cls.attr("forward");
auto result_forward = forward(context, *args, **kwargs); auto result_forward = forward(context, *args, **kwargs);
std::shared_ptr<operators::PyLayerContext> py_layer_ctx = std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.release().ptr()); std::make_shared<operators::PyLayerContext>(context.ptr());
// make inputs to varbase // make inputs to varbase
std::vector<std::shared_ptr<imperative::VarBase>> input_vars; std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
// process args,`input_vars` only collect `imperative::VarBase` // process args,`input_vars` only collect `imperative::VarBase`
......
...@@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> { ...@@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto &op_ = ctx.GetOp(); auto &op_ = ctx.GetOp();
auto pylayer_op = dynamic_cast<const PyLayerOp *>(&op_); auto const_pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
if (pylayer_op) { if (const_pylayer_op) {
auto py_layer_context = pylayer_op->GetPyLayerContext(); auto pylayer_op = const_cast<PyLayerOp *>(const_pylayer_op);
// Release contex after executing the compute
auto py_layer_context = pylayer_op->ReleasePyLayerContext();
py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true); py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true);
auto &input_vars = ctx.MultiInputVar("X"); auto &input_vars = ctx.MultiInputVar("X");
auto output_vars = ctx.MultiOutputVar("Out"); auto output_vars = ctx.MultiOutputVar("Out");
......
...@@ -34,6 +34,10 @@ class PyLayerContext { ...@@ -34,6 +34,10 @@ class PyLayerContext {
PyLayerContext() = delete; PyLayerContext() = delete;
PyObject* GetMutableCtx() { return context_; } PyObject* GetMutableCtx() { return context_; }
~PyLayerContext() {
py::gil_scoped_acquire guard;
Py_XDECREF(context_);
}
private: private:
PyObject* context_; PyObject* context_;
...@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel { ...@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel {
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) { void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context; py_context_ = py_context;
} }
const std::shared_ptr<PyLayerContext>& GetPyLayerContext() const { std::shared_ptr<PyLayerContext> ReleasePyLayerContext() {
return py_context_; auto temp = py_context_;
py_context_.reset();
VLOG(3) << "`py_context_` in the PyLayerOp is released.";
return temp;
} }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册