From 687171d22b14ba37cac7005af7681c354c16fc00 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 10 Jan 2019 22:07:53 +0800 Subject: [PATCH] Move from shared_ptr to raw pointer test=develop --- paddle/fluid/imperative/layer.h | 14 +++++++++++--- paddle/fluid/imperative/tracer.cc | 16 ++++++++-------- paddle/fluid/pybind/pybind.cc | 5 ++--- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 505056403..67b59d3a3 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -105,7 +105,15 @@ class VarBase { grads_(stop_gradient ? nullptr : new VarBase(true)), stop_gradient_(stop_gradient) {} - virtual ~VarBase() {} + virtual ~VarBase() { + if (var_) { + delete var_; + } + + if (grads_) { + delete grads_; + } + } void RunBackward(); @@ -124,8 +132,8 @@ class VarBase { framework::VarDesc* var_desc_; - std::shared_ptr var_; - std::shared_ptr grads_; + framework::Variable* var_; + VarBase* grads_; bool stop_gradient_; }; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 8e617e008..ead1ed5e3 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -58,10 +58,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, for (auto it : op->input_vars_) { auto& invars = invars_map[it.first]; for (VarBase* inp : it.second) { - PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr", + PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->op_desc_->Type(), inp->var_desc_->Name()); - invars.push_back(inp->var_.get()); + invars.push_back(inp->var_); vars[inp->var_desc_->Name()] = inp; if (inp->pre_op_) { op->pre_ops_[it.first].push_back(inp->pre_op_); @@ -80,7 +80,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, const std::vector& outputs = it.second; for (size_t i = 0; i < outputs.size(); ++i) { VarBase* out = outputs[i]; - outvars.push_back(out->var_.get()); + outvars.push_back(out->var_); vars[out->var_desc_->Name()] = out; framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name()); @@ -127,13 +127,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, if (var_it == grad_to_var->end()) { auto fwd_var_it = vars.find(grad_invar); PADDLE_ENFORCE(fwd_var_it != vars.end()); - grad_in_vars.push_back(fwd_var_it->second->var_.get()); + grad_in_vars.push_back(fwd_var_it->second->var_); } else { VarBase* var = vars[var_it->second]; if (!var->grads_->var_->IsInitialized()) { - InitVar(var->var_.get(), var->grads_->var_.get()); + InitVar(var->var_, var->grads_->var_); } - grad_in_vars.push_back(var->grads_->var_.get()); + grad_in_vars.push_back(var->grads_->var_); } } } @@ -146,9 +146,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, PADDLE_ENFORCE(var_it != grad_to_var->end()); VarBase* var = vars[var_it->second]; if (!var->grads_->var_->IsInitialized()) { - InitVar(var->var_.get(), var->grads_->var_.get()); + InitVar(var->var_, var->grads_->var_); } - grad_out_vars.push_back(var->grads_->var_.get()); + grad_out_vars.push_back(var->grads_->var_); } } } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index aee530036..d97e9e87a 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -133,10 +133,9 @@ PYBIND11_MODULE(core, m) { .def("_grad_name", &imperative::VarBase::GradName) .def("_grad_value", &imperative::VarBase::GradValue) .def("_grad_ivar", - [](const imperative::VarBase &self) { return self.grads_.get(); }, + [](const imperative::VarBase &self) { return self.grads_; }, py::return_value_policy::reference) - .def("value", - [](const imperative::VarBase &self) { return self.var_.get(); }, + .def("value", [](const imperative::VarBase &self) { return self.var_; }, py::return_value_policy::reference) .def_property( "desc", -- GitLab