Commit a15a3fc3 authored by minqiyang

Polish code

test=develop
Parent 9dc64edf
@@ -163,7 +163,7 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
   return res;
 }
 
-void BlockDesc::ClearBlock() {
+void BlockDesc::Clear() {
   // clear all ops
   ops_.clear();
......
@@ -97,7 +97,7 @@ class BlockDesc {
 
   std::vector<OpDesc *> AllOps() const;
 
-  void ClearBlock();
+  void Clear();
 
   size_t OpSize() const { return ops_.size(); }
......
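The rename only touches the method name; the behaviour of clearing the block's contents is unchanged. Below is a minimal sketch of how framework code might call the new name. It assumes a PaddlePaddle source build; ProgramDesc, MutableBlock, and AppendOp are existing framework APIs that are not part of this commit, so treat the snippet as illustrative only.

// Illustrative sketch; assumes the Paddle framework headers are on the include path.
#include "paddle/fluid/framework/program_desc.h"

void ClearDemo() {
  paddle::framework::ProgramDesc program;
  paddle::framework::BlockDesc *block = program.MutableBlock(0);

  block->AppendOp();                    // add a throwaway op descriptor
  size_t ops_before = block->OpSize();  // OpSize() is shown in the hunk above

  block->Clear();                       // new name introduced by this commit
  size_t ops_after = block->OpSize();   // expected to drop back to 0

  (void)ops_before;
  (void)ops_after;
}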
@@ -205,33 +205,6 @@ framework::LoDTensor& VarBase::GradValue() {
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
 
-void VarBase::ClearGradient() {
-  VLOG(1) << "clear gradient of " << var_desc_->Name();
-  if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
-    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
-    operators::math::set_constant(
-        *(platform::DeviceContextPool::Instance().Get(
-            grads_->var_->Get<framework::LoDTensor>().place())),
-        grads_t, 0.0);
-  }
-}
-
-void VarBase::RunBackward() {
-  if (!pre_op_) return;
-
-  VLOG(3) << "start backward";
-  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
-  operators::math::set_constant(
-      *(platform::DeviceContextPool::Instance().Get(
-          var_->GetMutable<framework::LoDTensor>()->place())),
-      grads_t, 1.0);
-
-  PADDLE_ENFORCE(
-      grads_ ==
-      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
-
-  Autograd().RunBackward(this);
-}
-
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   if (grad_op_descs_.empty() && backward_id_ <= 0) {
     VLOG(3) << "op with no grad: " << op_desc_->Type();
......
@@ -150,9 +150,32 @@ class VarBase {
     }
   }
 
-  void RunBackward();
+  void RunBackward() {
+    if (!pre_op_) return;
+
+    VLOG(3) << "start backward";
+    auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+    operators::math::set_constant(
+        *(platform::DeviceContextPool::Instance().Get(
+            var_->GetMutable<framework::LoDTensor>()->place())),
+        grads_t, 1.0);
+
+    PADDLE_ENFORCE(
+        grads_ ==
+        pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
+
+    Autograd().RunBackward(this);
+  }
 
-  void ClearGradient();
+  void ClearGradient() {
+    VLOG(1) << "clear gradient of " << var_desc_->Name();
+    if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
+      auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
+      operators::math::set_constant(
+          *(platform::DeviceContextPool::Instance().Get(
+              grads_->var_->Get<framework::LoDTensor>().place())),
+          grads_t, 0.0);
+    }
+  }
 
   framework::LoDTensor& GradValue();
......
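Both bodies are moved verbatim from the .cc file into the header; the semantics stay the same: RunBackward() seeds the variable's own gradient buffer with 1.0 before delegating to Autograd, and ClearGradient() zero-fills a gradient that is already initialized. The standalone sketch below mirrors that seed/zero pattern with plain std::vector; the ToyVar type is invented for illustration and is not Paddle code.

#include <algorithm>
#include <vector>

// Toy stand-in for a gradient buffer; illustrates the same fill semantics.
struct ToyVar {
  std::vector<float> grad;

  void SeedBackward() {                       // analogous to set_constant(..., 1.0)
    std::fill(grad.begin(), grad.end(), 1.0f);
    // A real implementation would now walk the op graph (what Autograd does).
  }

  void ClearGradient() {                      // analogous to set_constant(..., 0.0)
    if (!grad.empty()) std::fill(grad.begin(), grad.end(), 0.0f);
  }
};

int main() {
  ToyVar v{std::vector<float>(4, 0.5f)};
  v.SeedBackward();   // grad == {1, 1, 1, 1}
  v.ClearGradient();  // grad == {0, 0, 0, 0}
  return 0;
}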
@@ -145,7 +145,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   prepared_op.func(framework::ExecutionContext(
       prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
-  std::set<std::string> grad_deps_var;
+  std::set<std::string> vars_saved_for_backward;
 
   if (!stop_gradient) {
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
@@ -166,7 +166,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         PADDLE_ENFORCE(fwd_var_it != vars.end());
         // Forward inputs or outputs.
         grad_in_vars.push_back(fwd_var_it->second->var_);
-        grad_deps_var.insert(it.first);
+        vars_saved_for_backward.insert(it.first);
       } else {
         VarBase* var = vars[var_it->second];
         if (!var->grads_->var_->IsInitialized()) {
@@ -200,7 +200,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   }
 
   op->block_ = block;
-  return grad_deps_var;
+  return vars_saved_for_backward;
 }
 
 std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
......
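The only change in this file is the rename from grad_deps_var to vars_saved_for_backward, which better describes what the returned set holds: the names of forward variables that the backward pass will need to read again. A toy, self-contained sketch of that bookkeeping idea follows; TraceToyOp is an invented name and does not reflect the real Tracer interface.

#include <set>
#include <string>
#include <vector>

// Toy illustration of the renamed return value: a forward "trace" reports
// which of its variables the backward pass will read again.
// Invented for illustration; this is not the Paddle Tracer API.
std::set<std::string> TraceToyOp(const std::vector<std::string>& inputs,
                                 bool stop_gradient) {
  std::set<std::string> vars_saved_for_backward;
  if (!stop_gradient) {
    for (const auto& name : inputs) {
      vars_saved_for_backward.insert(name);  // grad op reads the forward input
    }
  }
  return vars_saved_for_backward;
}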
@@ -189,8 +189,7 @@ void BindBlockDesc(pybind11::module *m) {
             return self.HasVar(name);
           },
           pybind11::return_value_policy::reference)
-      .def("_clear_block",
-           [](pd::BlockDesc &self) { return self.ClearBlock(); },
+      .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
           pybind11::return_value_policy::reference)
       .def("_rename_var",
            [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
......
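On the Python side the method keeps its _clear_block name; only the C++ call inside the lambda switches from ClearBlock() to Clear(). For reference, here is a self-contained pybind11 sketch of the same lambda-binding pattern, using a toy class instead of Paddle's BlockDesc (the module and class names are made up).

#include <pybind11/pybind11.h>
#include <vector>

// Toy class standing in for BlockDesc; only for illustrating the binding style.
struct ToyBlock {
  std::vector<int> ops;
  void Clear() { ops.clear(); }
};

PYBIND11_MODULE(toy_module, m) {
  pybind11::class_<ToyBlock>(m, "ToyBlock")
      .def(pybind11::init<>())
      // Same pattern as the commit: bind a method through a lambda.
      .def("_clear_block", [](ToyBlock &self) { return self.Clear(); },
           pybind11::return_value_policy::reference)
      .def("op_size", [](const ToyBlock &self) { return self.ops.size(); });
}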