Unverified commit e5fef8f3, authored by Zeng Jinle, committed by GitHub

[Dygraph double grad]Code polish (#23121)

* fix dygraph double grad, test=develop

* fix unpack constructor, test=develop
Parent 9258e960
@@ -210,15 +210,10 @@ std::string LayerDebugString(const std::string& op_type,
   return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
 }
 
-VarBase::VarBase(bool has_grad, const std::shared_ptr<VariableWrapper>& var)
+VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var)
     : var_(var), grad_node_(var->GetGradNode()) {
-  if (has_grad) {
-    if (auto grad_var = var_->GetGradVar()) {
-      grad_var_ = std::make_shared<VarBase>(false, grad_var);
-    } else {
-      grad_var_ = std::make_shared<VarBase>(false, GradVarName());
-      var_->SetGradVar(grad_var_->var_);
-    }
+  if (auto grad_var = var_->GetGradVar()) {
+    grad_var_ = std::make_shared<VarBase>(grad_var);
   }
 
   if (IsDebugEnabled()) {
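The hunk above removes the `has_grad` flag from the "unpack" constructor: a `VarBase` built from a `VariableWrapper` now only attaches a gradient `VarBase` when the wrapper already carries one, and never creates a fresh gradient variable in this path. Below is a minimal standalone sketch of that pattern; `Wrapper` and `Var` are simplified stand-ins for illustration only, not Paddle's actual `VariableWrapper`/`VarBase` API.

```cpp
#include <iostream>
#include <memory>
#include <string>

struct Wrapper {  // stand-in for VariableWrapper
  explicit Wrapper(std::string name) : name_(std::move(name)) {}
  std::string name_;
  std::shared_ptr<Wrapper> grad_;  // may be null
};

struct Var {  // stand-in for VarBase
  // "Unpack" constructor: reuse the wrapper's existing grad, never create one.
  explicit Var(const std::shared_ptr<Wrapper>& w) : wrapped_(w) {
    if (auto g = w->grad_) {
      grad_ = std::make_shared<Var>(g);
    }
  }
  std::shared_ptr<Wrapper> wrapped_;
  std::shared_ptr<Var> grad_;
};

int main() {
  auto x = std::make_shared<Wrapper>("x");
  x->grad_ = std::make_shared<Wrapper>("x@GRAD");
  auto y = std::make_shared<Wrapper>("y");  // no grad attached

  Var vx(x), vy(y);
  std::cout << "x has grad: " << (vx.grad_ != nullptr) << '\n';  // prints 1
  std::cout << "y has grad: " << (vy.grad_ != nullptr) << '\n';  // prints 0
}
```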
@@ -417,10 +412,10 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
   auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs);
   if (grad_node && !grad_node->empty()) {
-    for (auto& op : *grad_node) {
-      op.SetId(OpBase::GenerateUniqueId());
-      op.SetPlace(place);
-      ClearNoNeedBufferInputs(&op);
+    for (auto& grad_op : *grad_node) {
+      grad_op.SetId(OpBase::GenerateUniqueId());
+      grad_op.SetPlace(place);
+      ClearNoNeedBufferInputs(&grad_op);
     }
     return grad_node;
   } else {
......
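The loop-variable rename in `CreateGradOpNode` (`op` to `grad_op`) appears to be a pure readability fix: the enclosing function already has an `op` parameter (used in `op.Type()` above), so the old loop variable shadowed it.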
@@ -76,7 +76,8 @@ class VarBase {
   explicit VarBase(const std::string& name) : VarBase(true, name) {}
 
   // NOTE(zengjinle): be careful when you use this constructor!!!
-  explicit VarBase(bool has_grad, const std::shared_ptr<VariableWrapper>& var);
+  // Unpack VarBase from VariableWrapper.
+  explicit VarBase(const std::shared_ptr<VariableWrapper>& var);
 
   ~VarBase() {
     VLOG(10) << "Destruct VarBase: " << Name();
@@ -100,7 +101,7 @@ class VarBase {
   const std::shared_ptr<VarBase>& MutableGradVarBase() {
     if (grad_var_ == nullptr) {
       if (auto grad_var_wrapper = var_->GetGradVar()) {
-        grad_var_ = std::make_shared<VarBase>(false, grad_var_wrapper);
+        grad_var_ = std::make_shared<VarBase>(grad_var_wrapper);
       } else {
         grad_var_ = std::make_shared<VarBase>(false, GradVarName());
         var_->SetGradVar(grad_var_->var_);
......
@@ -719,7 +719,7 @@ PartialGradTask::PartialGradTask(
     auto grad_accumulator_iter = grad_accumulators_.find(mapped_out_grad_var);
     if (grad_accumulator_iter == grad_accumulators_.end()) {
       ready_grad_vars_.Set(mapped_out_grad_var,
-                           std::make_shared<VarBase>(false, out_grad_var));
+                           std::make_shared<VarBase>(out_grad_var));
       VLOG(10) << "Fill 1.0f or user-provided gradient as ready var "
                << out_grad_var->Name();
     } else {
...@@ -783,7 +783,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) { ...@@ -783,7 +783,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
if (!input_pair.second.IsGrad()) { if (!input_pair.second.IsGrad()) {
for (auto &fwd_var : input_pair.second) { for (auto &fwd_var : input_pair.second) {
if (fwd_var) { if (fwd_var) {
new_inputs.emplace_back(new VarBase(true, fwd_var)); new_inputs.emplace_back(new VarBase(fwd_var));
VLOG(10) << "Unpacked forward var " << fwd_var->Name() VLOG(10) << "Unpacked forward var " << fwd_var->Name()
<< ", grad ops: " << GradOpTypes(*new_inputs.back()); << ", grad ops: " << GradOpTypes(*new_inputs.back());
} else { } else {
@@ -813,7 +813,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     for (auto &fwd_var : output_pair.second) {
       // unpack forward var
       if (fwd_var) {
-        new_outputs.emplace_back(new VarBase(true, fwd_var));
+        new_outputs.emplace_back(new VarBase(fwd_var));
         VLOG(10) << "Unpacked forward var " << fwd_var->Name();
       } else {
         new_outputs.emplace_back();
@@ -878,18 +878,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     auto partial_grad_grads = accumulator_info->SumGradient(
         std::move(grad_var), op->id(), &is_finished);
 
-    if (is_finished) {
-      VLOG(10) << "Sum has finished for "
-               << accumulator_info->MappedGradVar()->Name() << " "
-               << accumulator_info->GradVarBase();
-      ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
-                           accumulator_info->GradVarBase());
-    }
-
-    if (partial_grad_grads.empty()) {
-      continue;
-    }
-
+    if (!partial_grad_grads.empty()) {
     auto sum_grad_var_grad =
         accumulator_info->GradVarBase()->MutableGradVarBase();
     sum_grad_var_grad->SetOverridedStopGradient(false);
@@ -913,11 +902,21 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
           assign_node->InsertGradPendingNode(std::move(grad_pending_node));
         }
       }
 
-      VLOG(10) << "Pending ops of assign is " << GradPendingOpTypes(*assign_node);
-      grad_accumulators_.erase(accumulator_info->MappedGradVar());
+      VLOG(10) << "Pending ops of assign is "
+               << GradPendingOpTypes(*assign_node);
       double_grad_nodes_.emplace_back(assign_node);
     }
+
+    if (is_finished) {
+      VLOG(10) << "Sum has finished for "
+               << accumulator_info->MappedGradVar()->Name() << " "
+               << accumulator_info->GradVarBase();
+      ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
+                           accumulator_info->GradVarBase());
+      grad_accumulators_.erase(accumulator_info->MappedGradVar());
+    }
+  }
 
   grads_to_accumulate_.clear();
 }
......
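The last two hunks reorder the per-accumulator logic in `RunEachOp`: the double-grad "assign" node is now built first whenever partial gradients of gradients exist, and only afterwards, if the accumulator reports the sum as finished, is the summed gradient published as a ready variable and the accumulator entry erased. The following is a minimal standalone sketch of the new control flow; the function names are stand-ins for illustration, not Paddle's real engine API.

```cpp
#include <iostream>
#include <vector>

// Stand-in steps; the real engine builds a GradOpNode of "assign" ops,
// publishes the summed gradient, and drops the accumulator entry.
void BuildAssignNodeForDoubleGrad() { std::cout << "build assign node\n"; }
void PublishReadyGradVar() { std::cout << "publish summed grad as ready\n"; }
void EraseAccumulatorEntry() { std::cout << "erase accumulator entry\n"; }

void ProcessAccumulator(const std::vector<int>& partial_grad_grads,
                        bool is_finished) {
  // New order: handle double-grad bookkeeping first...
  if (!partial_grad_grads.empty()) {
    BuildAssignNodeForDoubleGrad();
  }
  // ...then, once the sum is complete, publish the result and clean up.
  if (is_finished) {
    PublishReadyGradVar();
    EraseAccumulatorEntry();
  }
}

int main() {
  ProcessAccumulator({1, 2}, true);  // both branches run, in the new order
  ProcessAccumulator({}, false);     // nothing to do yet
}
```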