提交 c8d1a8e9 编写于 作者: M minqiyang

Change var_ and grad_ to shared_ptr

上级 0601f5c4
...@@ -114,7 +114,7 @@ class Autograd { ...@@ -114,7 +114,7 @@ class Autograd {
} }
}; };
framework::LoDTensor& VarBase::Grad() { framework::LoDTensor& VarBase::GradValue() {
VLOG(3) << "get var grad " << var_desc_->Name(); VLOG(3) << "get var grad " << var_desc_->Name();
return *(grads_->var_->GetMutable<framework::LoDTensor>()); return *(grads_->var_->GetMutable<framework::LoDTensor>());
} }
......
...@@ -109,7 +109,7 @@ class VarBase { ...@@ -109,7 +109,7 @@ class VarBase {
void RunBackward(); void RunBackward();
framework::LoDTensor& Grad(); framework::LoDTensor& GradValue();
inline std::string GradName() const { inline std::string GradName() const {
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -123,8 +123,9 @@ class VarBase { ...@@ -123,8 +123,9 @@ class VarBase {
int pre_op_out_idx_; int pre_op_out_idx_;
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
framework::Variable* var_;
VarBase* grads_; std::shared_ptr<framework::Variable> var_;
std::shared_ptr<VarBase> grads_;
bool stop_gradient_; bool stop_gradient_;
}; };
......
...@@ -74,10 +74,10 @@ class Tracer { ...@@ -74,10 +74,10 @@ class Tracer {
for (auto it : op->input_vars_) { for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first]; auto& invars = invars_map[it.first];
for (VarBase* inp : it.second) { for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr",
op->op_desc_->Type(), inp->var_desc_->Name()); op->op_desc_->Type(), inp->var_desc_->Name());
invars.push_back(inp->var_); invars.push_back(inp->var_.get());
vars[inp->var_desc_->Name()] = inp; vars[inp->var_desc_->Name()] = inp;
if (inp->pre_op_) { if (inp->pre_op_) {
op->pre_ops_[it.first].push_back(inp->pre_op_); op->pre_ops_[it.first].push_back(inp->pre_op_);
...@@ -96,7 +96,7 @@ class Tracer { ...@@ -96,7 +96,7 @@ class Tracer {
const std::vector<VarBase*>& outputs = it.second; const std::vector<VarBase*>& outputs = it.second;
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i]; VarBase* out = outputs[i];
outvars.push_back(out->var_); outvars.push_back(out->var_.get());
vars[out->var_desc_->Name()] = out; vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name()); framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
...@@ -143,13 +143,13 @@ class Tracer { ...@@ -143,13 +143,13 @@ class Tracer {
if (var_it == grad_to_var->end()) { if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar); auto fwd_var_it = vars.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end()); PADDLE_ENFORCE(fwd_var_it != vars.end());
grad_in_vars.push_back(fwd_var_it->second->var_); grad_in_vars.push_back(fwd_var_it->second->var_.get());
} else { } else {
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) { if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_); InitVar(var->var_.get(), var->grads_->var_.get());
} }
grad_in_vars.push_back(var->grads_->var_); grad_in_vars.push_back(var->grads_->var_.get());
} }
} }
} }
...@@ -162,9 +162,9 @@ class Tracer { ...@@ -162,9 +162,9 @@ class Tracer {
PADDLE_ENFORCE(var_it != grad_to_var->end()); PADDLE_ENFORCE(var_it != grad_to_var->end());
VarBase* var = vars[var_it->second]; VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) { if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_); InitVar(var->var_.get(), var->grads_->var_.get());
} }
grad_out_vars.push_back(var->grads_->var_); grad_out_vars.push_back(var->grads_->var_.get());
} }
} }
} }
......
...@@ -132,11 +132,12 @@ PYBIND11_MODULE(core, m) { ...@@ -132,11 +132,12 @@ PYBIND11_MODULE(core, m) {
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self) { self.RunBackward(); }) [](imperative::VarBase &self) { self.RunBackward(); })
.def("_grad_name", &imperative::VarBase::GradName) .def("_grad_name", &imperative::VarBase::GradName)
.def("_grad", &imperative::VarBase::Grad) .def("_grad_value", &imperative::VarBase::GradValue)
.def("_grad_ivar", .def("_grad_ivar",
[](const imperative::VarBase &self) { return self.grads_; }, [](const imperative::VarBase &self) { return self.grads_.get(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("value", [](const imperative::VarBase &self) { return self.var_; }, .def("value",
[](const imperative::VarBase &self) { return self.var_.get(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property( .def_property(
"desc", "desc",
......
...@@ -379,7 +379,7 @@ class Variable(object): ...@@ -379,7 +379,7 @@ class Variable(object):
self._ivar._run_backward() self._ivar._run_backward()
def _gradient(self): def _gradient(self):
return np.array(self._ivar._grad()) return np.array(self._ivar._grad_value())
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
......
...@@ -46,7 +46,6 @@ def to_variable(value, block=None): ...@@ -46,7 +46,6 @@ def to_variable(value, block=None):
shape=value.shape, shape=value.shape,
dtype=value.dtype) dtype=value.dtype)
var = py_var._ivar.value() var = py_var._ivar.value()
print(type(var))
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set(value, core.CPUPlace()) tensor.set(value, core.CPUPlace())
return py_var return py_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册