Commit c8d1a8e9 authored by minqiyang

Change var_ and grad_ to shared_ptr

Parent 0601f5c4
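
The change below replaces the raw-pointer members `VarBase::var_` and `VarBase::grads_` with `std::shared_ptr`, so every call site that previously passed the pointer directly now passes `.get()`. A minimal sketch of the before/after ownership shape (illustrative stand-in types, not the actual Paddle classes):

#include <memory>

namespace framework {
class Variable {};  // stand-in for paddle::framework::Variable
class VarDesc {};   // stand-in for paddle::framework::VarDesc
}  // namespace framework

class VarBase {
 public:
  // Before: non-owning raw pointers.
  //   framework::Variable* var_;
  //   VarBase* grads_;
  // After: shared ownership; the objects stay alive as long as any holder does.
  std::shared_ptr<framework::Variable> var_;
  std::shared_ptr<VarBase> grads_;

  framework::VarDesc* var_desc_ = nullptr;  // still a non-owning pointer
};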
@@ -114,7 +114,7 @@ class Autograd {
   }
 };
 
-framework::LoDTensor& VarBase::Grad() {
+framework::LoDTensor& VarBase::GradValue() {
   VLOG(3) << "get var grad " << var_desc_->Name();
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
......
@@ -109,7 +109,7 @@ class VarBase {
 
   void RunBackward();
 
-  framework::LoDTensor& Grad();
+  framework::LoDTensor& GradValue();
 
   inline std::string GradName() const {
     PADDLE_ENFORCE(
@@ -123,8 +123,9 @@ class VarBase {
   int pre_op_out_idx_;
 
   framework::VarDesc* var_desc_;
-  framework::Variable* var_;
-  VarBase* grads_;
+
+  std::shared_ptr<framework::Variable> var_;
+  std::shared_ptr<VarBase> grads_;
 
   bool stop_gradient_;
 };
......
@@ -74,10 +74,10 @@ class Tracer {
     for (auto it : op->input_vars_) {
       auto& invars = invars_map[it.first];
       for (VarBase* inp : it.second) {
-        PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
+        PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr",
                                 op->op_desc_->Type(), inp->var_desc_->Name());
-        invars.push_back(inp->var_);
+        invars.push_back(inp->var_.get());
         vars[inp->var_desc_->Name()] = inp;
         if (inp->pre_op_) {
           op->pre_ops_[it.first].push_back(inp->pre_op_);
@@ -96,7 +96,7 @@ class Tracer {
       const std::vector<VarBase*>& outputs = it.second;
       for (size_t i = 0; i < outputs.size(); ++i) {
         VarBase* out = outputs[i];
-        outvars.push_back(out->var_);
+        outvars.push_back(out->var_.get());
         vars[out->var_desc_->Name()] = out;
 
         framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
@@ -143,13 +143,13 @@ class Tracer {
         if (var_it == grad_to_var->end()) {
           auto fwd_var_it = vars.find(grad_invar);
           PADDLE_ENFORCE(fwd_var_it != vars.end());
-          grad_in_vars.push_back(fwd_var_it->second->var_);
+          grad_in_vars.push_back(fwd_var_it->second->var_.get());
         } else {
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
-            InitVar(var->var_, var->grads_->var_);
+            InitVar(var->var_.get(), var->grads_->var_.get());
           }
-          grad_in_vars.push_back(var->grads_->var_);
+          grad_in_vars.push_back(var->grads_->var_.get());
         }
       }
     }
@@ -162,9 +162,9 @@ class Tracer {
         PADDLE_ENFORCE(var_it != grad_to_var->end());
         VarBase* var = vars[var_it->second];
         if (!var->grads_->var_->IsInitialized()) {
-          InitVar(var->var_, var->grads_->var_);
+          InitVar(var->var_.get(), var->grads_->var_.get());
         }
-        grad_out_vars.push_back(var->grads_->var_);
+        grad_out_vars.push_back(var->grads_->var_.get());
       }
     }
   }
......
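
In the tracer hunks above, every API that still expects a raw `framework::Variable*` (the input/output vectors, `InitVar`) now receives a non-owning view via `shared_ptr::get()`, while ownership stays with the `VarBase`. A small hedged sketch of that pattern with generic names, not the Paddle tracer itself:

#include <memory>
#include <vector>

struct Variable {};  // stand-in for framework::Variable

int main() {
  // The owning side: a shared_ptr keeps the Variable alive.
  auto var = std::make_shared<Variable>();

  // Consumers that take raw pointers get .get(); pushing the raw pointer
  // does not touch the reference count, so `var` must outlive `invars`.
  std::vector<Variable*> invars;
  invars.push_back(var.get());
  return 0;
}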
@@ -132,11 +132,12 @@ PYBIND11_MODULE(core, m) {
       .def("_run_backward",
           [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
-      .def("_grad", &imperative::VarBase::Grad)
+      .def("_grad_value", &imperative::VarBase::GradValue)
       .def("_grad_ivar",
-          [](const imperative::VarBase &self) { return self.grads_; },
+          [](const imperative::VarBase &self) { return self.grads_.get(); },
           py::return_value_policy::reference)
-      .def("value", [](const imperative::VarBase &self) { return self.var_; },
+      .def("value",
+          [](const imperative::VarBase &self) { return self.var_.get(); },
           py::return_value_policy::reference)
       .def_property(
           "desc",
......
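
With the members held by `shared_ptr`, the Python bindings return `self.grads_.get()` / `self.var_.get()` together with `py::return_value_policy::reference`, i.e. Python receives a non-owning view and C++ keeps ownership. A hedged sketch of that binding pattern (module and class names here are made up for illustration):

#include <pybind11/pybind11.h>
#include <memory>

namespace py = pybind11;

struct Variable {};

struct Holder {
  std::shared_ptr<Variable> var_ = std::make_shared<Variable>();
};

PYBIND11_MODULE(example, m) {  // hypothetical module, not Paddle's `core`
  py::class_<Variable>(m, "Variable");
  py::class_<Holder>(m, "Holder")
      .def(py::init<>())
      // Return the raw pointer; `reference` tells pybind11 not to take
      // ownership, so the shared_ptr inside Holder keeps the object alive.
      .def("value", [](const Holder &self) { return self.var_.get(); },
           py::return_value_policy::reference);
}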
@@ -379,7 +379,7 @@ class Variable(object):
         self._ivar._run_backward()
 
     def _gradient(self):
-        return np.array(self._ivar._grad())
+        return np.array(self._ivar._grad_value())
 
     def __str__(self):
         return self.to_string(True)
......
@@ -46,7 +46,6 @@ def to_variable(value, block=None):
             shape=value.shape,
             dtype=value.dtype)
         var = py_var._ivar.value()
-        print(type(var))
         tensor = var.get_tensor()
         tensor.set(value, core.CPUPlace())
         return py_var
......