提交 687171d2 编写于 作者: M minqiyang

Move from shared_ptr to raw pointer

test=develop
上级 80197fac
...@@ -105,7 +105,15 @@ class VarBase { ...@@ -105,7 +105,15 @@ class VarBase {
grads_(stop_gradient ? nullptr : new VarBase(true)), grads_(stop_gradient ? nullptr : new VarBase(true)),
stop_gradient_(stop_gradient) {} stop_gradient_(stop_gradient) {}
virtual ~VarBase() {} virtual ~VarBase() {
if (var_) {
delete var_;
}
if (grads_) {
delete grads_;
}
}
void RunBackward(); void RunBackward();
...@@ -124,8 +132,8 @@ class VarBase { ...@@ -124,8 +132,8 @@ class VarBase {
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
std::shared_ptr<framework::Variable> var_; framework::Variable* var_;
std::shared_ptr<VarBase> grads_; VarBase* grads_;
bool stop_gradient_; bool stop_gradient_;
}; };
......
...@@ -58,10 +58,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -58,10 +58,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : op->input_vars_) { for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first]; auto& invars = invars_map[it.first];
for (VarBase* inp : it.second) { for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr", PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
op->op_desc_->Type(), inp->var_desc_->Name()); op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_.get()); invars.push_back(inp->var_);
vars[inp->var_desc_->Name()] = inp; vars[inp->var_desc_->Name()] = inp;
if (inp->pre_op_) { if (inp->pre_op_) {
op->pre_ops_[it.first].push_back(inp->pre_op_); op->pre_ops_[it.first].push_back(inp->pre_op_);
...@@ -80,7 +80,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -80,7 +80,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const std::vector<VarBase*>& outputs = it.second; const std::vector<VarBase*>& outputs = it.second;
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i]; VarBase* out = outputs[i];
outvars.push_back(out->var_.get()); outvars.push_back(out->var_);
vars[out->var_desc_->Name()] = out; vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name()); framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
...@@ -127,13 +127,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -127,13 +127,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
if (var_it == grad_to_var->end()) { if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar); auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end()); PADDLE_ENFORCE(fwd_var_it != vars.end());
grad_in_vars.push_back(fwd_var_it->second->var_.get()); grad_in_vars.push_back(fwd_var_it->second->var_);
} else { } else {
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) { if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_.get(), var->grads_->var_.get()); InitVar(var->var_, var->grads_->var_);
} }
grad_in_vars.push_back(var->grads_->var_.get()); grad_in_vars.push_back(var->grads_->var_);
} }
} }
} }
...@@ -146,9 +146,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -146,9 +146,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
PADDLE_ENFORCE(var_it != grad_to_var->end()); PADDLE_ENFORCE(var_it != grad_to_var->end());
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) { if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_.get(), var->grads_->var_.get()); InitVar(var->var_, var->grads_->var_);
} }
grad_out_vars.push_back(var->grads_->var_.get()); grad_out_vars.push_back(var->grads_->var_);
} }
} }
} }
......
...@@ -133,10 +133,9 @@ PYBIND11_MODULE(core, m) { ...@@ -133,10 +133,9 @@ PYBIND11_MODULE(core, m) {
.def("_grad_name", &imperative::VarBase::GradName) .def("_grad_name", &imperative::VarBase::GradName)
.def("_grad_value", &imperative::VarBase::GradValue) .def("_grad_value", &imperative::VarBase::GradValue)
.def("_grad_ivar", .def("_grad_ivar",
[](const imperative::VarBase &self) { return self.grads_.get(); }, [](const imperative::VarBase &self) { return self.grads_; },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("value", .def("value", [](const imperative::VarBase &self) { return self.var_; },
[](const imperative::VarBase &self) { return self.var_.get(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property( .def_property(
"desc", "desc",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册