Commit a15a3fc3 authored by: M minqiyang

Polish code

test=develop

Parent 9dc64edf
@@ -163,7 +163,7 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
   return res;
 }
 
-void BlockDesc::ClearBlock() {
+void BlockDesc::Clear() {
   // clear all ops
   ops_.clear();
...
@@ -97,7 +97,7 @@ class BlockDesc {
   std::vector<OpDesc *> AllOps() const;
 
-  void ClearBlock();
+  void Clear();
 
   size_t OpSize() const { return ops_.size(); }
...
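The rename from ClearBlock() to Clear() drops the redundant class name from the method. A minimal sketch of a call site, assuming a ProgramDesc is already in scope (the include path and the helper name ResetMainBlock are illustrative; only Clear() and OpSize() come from the hunks above):

```cpp
#include "paddle/fluid/framework/program_desc.h"  // assumed include path

// Drop every op recorded in block 0, e.g. before re-building the program.
void ResetMainBlock(paddle::framework::ProgramDesc* program) {
  paddle::framework::BlockDesc* main_block = program->MutableBlock(0);
  main_block->Clear();  // was main_block->ClearBlock() before this commit
  PADDLE_ENFORCE(main_block->OpSize() == 0,
                 "the block should hold no ops after Clear()");
}
```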
...@@ -205,33 +205,6 @@ framework::LoDTensor& VarBase::GradValue() { ...@@ -205,33 +205,6 @@ framework::LoDTensor& VarBase::GradValue() {
return *(grads_->var_->GetMutable<framework::LoDTensor>()); return *(grads_->var_->GetMutable<framework::LoDTensor>());
} }
void VarBase::ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
grads_->var_->Get<framework::LoDTensor>().place())),
grads_t, 0.0);
}
}
void VarBase::RunBackward() {
if (!pre_op_) return;
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==
pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
Autograd().RunBackward(this);
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (grad_op_descs_.empty() && backward_id_ <= 0) { if (grad_op_descs_.empty() && backward_id_ <= 0) {
VLOG(3) << "op with no grad: " << op_desc_->Type(); VLOG(3) << "op with no grad: " << op_desc_->Type();
......
@@ -150,9 +150,32 @@ class VarBase {
     }
   }
 
-  void RunBackward();
+  void RunBackward() {
+    if (!pre_op_) return;
+
+    VLOG(3) << "start backward";
+    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+    operators::math::set_constant(
+        *(platform::DeviceContextPool::Instance().Get(
+            var_->GetMutable<framework::LoDTensor>()->place())),
+        grads_t, 1.0);
+
+    PADDLE_ENFORCE(
+        grads_ ==
+        pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
+
+    Autograd().RunBackward(this);
+  }
 
-  void ClearGradient();
+  void ClearGradient() {
+    VLOG(1) << "clear gradient of " << var_desc_->Name();
+    if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
+      auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+      operators::math::set_constant(
+          *(platform::DeviceContextPool::Instance().Get(
+              grads_->var_->Get<framework::LoDTensor>().place())),
+          grads_t, 0.0);
+    }
+  }
 
   framework::LoDTensor& GradValue();
...
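With this change both methods are defined inline in the VarBase header and their out-of-line definitions above are deleted; behavior is unchanged. A hedged sketch of the intended call pattern in imperative mode (the `loss` and `weight` variables, the helper name, and the include path are illustrative; only RunBackward() and ClearGradient() come from the diff):

```cpp
#include "paddle/fluid/imperative/layer.h"  // assumed include path

// `loss` and `weight` are VarBase instances produced by a traced forward pass.
void StepBackward(paddle::imperative::VarBase* loss,
                  paddle::imperative::VarBase* weight) {
  loss->RunBackward();      // seeds d(loss)/d(loss) with 1.0, then runs Autograd
  // ... an optimizer would consume weight's accumulated gradient here ...
  weight->ClearGradient();  // zero the gradient tensor before the next step
}
```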
@@ -145,7 +145,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   prepared_op.func(framework::ExecutionContext(
       prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
-  std::set<std::string> grad_deps_var;
+  std::set<std::string> vars_saved_for_backward;
 
   if (!stop_gradient) {
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
...
@@ -166,7 +166,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         PADDLE_ENFORCE(fwd_var_it != vars.end());
         // Forward inputs or outputs.
         grad_in_vars.push_back(fwd_var_it->second->var_);
-        grad_deps_var.insert(it.first);
+        vars_saved_for_backward.insert(it.first);
       } else {
         VarBase* var = vars[var_it->second];
         if (!var->grads_->var_->IsInitialized()) {
...
@@ -200,7 +200,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   }
 
   op->block_ = block;
-  return grad_deps_var;
+  return vars_saved_for_backward;
 }
 
 std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
...
...@@ -189,8 +189,7 @@ void BindBlockDesc(pybind11::module *m) { ...@@ -189,8 +189,7 @@ void BindBlockDesc(pybind11::module *m) {
return self.HasVar(name); return self.HasVar(name);
}, },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("_clear_block", .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
[](pd::BlockDesc &self) { return self.ClearBlock(); },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("_rename_var", .def("_rename_var",
[](pd::BlockDesc &self, const pybind11::bytes &byte_name, [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
......
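The Python-facing name `_clear_block` is kept; only the C++ call behind it changes. As a side note (a sketch only, not part of this commit), the lambda does nothing but forward to the member function, so the binding could also point at it directly; the return_value_policy has no effect on a void return in either form:

```cpp
// Hypothetical alternative binding with the same behavior.
.def("_clear_block", &pd::BlockDesc::Clear)
```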