未验证 提交 c22c2c58 编写于 作者: W wanghuancoder 提交者: GitHub

refine pylayer (#42572)

* refine pylayer

* refine
上级 c3b7bc61
...@@ -106,8 +106,6 @@ GradNodePyLayer::operator()( ...@@ -106,8 +106,6 @@ GradNodePyLayer::operator()(
pybind11::detail::error_string().c_str())); pybind11::detail::error_string().c_str()));
} }
outputs_ = outputs;
VLOG(6) << "PyLayer backward function finish..."; VLOG(6) << "PyLayer backward function finish...";
PyObject* outputs_tuple = nullptr; PyObject* outputs_tuple = nullptr;
...@@ -165,6 +163,9 @@ GradNodePyLayer::operator()( ...@@ -165,6 +163,9 @@ GradNodePyLayer::operator()(
if (!PyTuple_Check(outputs)) { if (!PyTuple_Check(outputs)) {
Py_XDECREF(outputs_tuple); Py_XDECREF(outputs_tuple);
} }
Py_XDECREF(outputs);
Py_XDECREF(ctx_);
ctx_ = nullptr;
return grad_out; return grad_out;
} }
......
...@@ -32,10 +32,7 @@ class GradNodePyLayer : public GradNodeBase { ...@@ -32,10 +32,7 @@ class GradNodePyLayer : public GradNodeBase {
ctx_ = ctx; ctx_ = ctx;
} }
~GradNodePyLayer() override { ~GradNodePyLayer() override { Py_XDECREF(ctx_); };
Py_DECREF(ctx_);
Py_XDECREF(outputs_);
};
virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>, virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize> kSlotSmallVectorSize>
...@@ -50,9 +47,6 @@ class GradNodePyLayer : public GradNodeBase { ...@@ -50,9 +47,6 @@ class GradNodePyLayer : public GradNodeBase {
return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name); return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name);
} }
// for paddle.grad get result
PyObject* GetMutableOutputs() { return outputs_; }
void SaveForwardOutputsMeta( void SaveForwardOutputsMeta(
const std::vector<std::vector<paddle::experimental::Tensor*>>& const std::vector<std::vector<paddle::experimental::Tensor*>>&
outputs_tensor) { outputs_tensor) {
...@@ -81,7 +75,6 @@ class GradNodePyLayer : public GradNodeBase { ...@@ -81,7 +75,6 @@ class GradNodePyLayer : public GradNodeBase {
private: private:
PyObject* ctx_{nullptr}; PyObject* ctx_{nullptr};
PyObject* outputs_{nullptr};
std::vector<std::vector<phi::DenseTensorMeta>> forward_outputs_meta_; std::vector<std::vector<phi::DenseTensorMeta>> forward_outputs_meta_;
std::vector<std::vector<paddle::platform::Place>> forward_outputs_place_; std::vector<std::vector<paddle::platform::Place>> forward_outputs_place_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册