提交 687171d2 编写于 作者: M minqiyang

Move from shared_ptr to raw pointer

test=develop
上级 80197fac
......@@ -105,7 +105,15 @@ class VarBase {
grads_(stop_gradient ? nullptr : new VarBase(true)),
stop_gradient_(stop_gradient) {}
virtual ~VarBase() {}
virtual ~VarBase() {
if (var_) {
delete var_;
}
if (grads_) {
delete grads_;
}
}
void RunBackward();
......@@ -124,8 +132,8 @@ class VarBase {
framework::VarDesc* var_desc_;
std::shared_ptr<framework::Variable> var_;
std::shared_ptr<VarBase> grads_;
framework::Variable* var_;
VarBase* grads_;
bool stop_gradient_;
};
......
......@@ -58,10 +58,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first];
for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr",
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_.get());
invars.push_back(inp->var_);
vars[inp->var_desc_->Name()] = inp;
if (inp->pre_op_) {
op->pre_ops_[it.first].push_back(inp->pre_op_);
......@@ -80,7 +80,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const std::vector<VarBase*>& outputs = it.second;
for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i];
outvars.push_back(out->var_.get());
outvars.push_back(out->var_);
vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
......@@ -127,13 +127,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end());
grad_in_vars.push_back(fwd_var_it->second->var_.get());
grad_in_vars.push_back(fwd_var_it->second->var_);
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_.get(), var->grads_->var_.get());
InitVar(var->var_, var->grads_->var_);
}
grad_in_vars.push_back(var->grads_->var_.get());
grad_in_vars.push_back(var->grads_->var_);
}
}
}
......@@ -146,9 +146,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
PADDLE_ENFORCE(var_it != grad_to_var->end());
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_.get(), var->grads_->var_.get());
InitVar(var->var_, var->grads_->var_);
}
grad_out_vars.push_back(var->grads_->var_.get());
grad_out_vars.push_back(var->grads_->var_);
}
}
}
......
......@@ -133,10 +133,9 @@ PYBIND11_MODULE(core, m) {
.def("_grad_name", &imperative::VarBase::GradName)
.def("_grad_value", &imperative::VarBase::GradValue)
.def("_grad_ivar",
[](const imperative::VarBase &self) { return self.grads_.get(); },
[](const imperative::VarBase &self) { return self.grads_; },
py::return_value_policy::reference)
.def("value",
[](const imperative::VarBase &self) { return self.var_.get(); },
.def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference)
.def_property(
"desc",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册