Unverified commit e5fef8f3, authored by Zeng Jinle, committed by GitHub

[Dygraph double grad] Code polish (#23121)

* fix dygraph double grad, test=develop

* fix unpack constructor, test=develop

Parent 9258e960
@@ -210,15 +210,10 @@ std::string LayerDebugString(const std::string& op_type,
   return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
 }
 
-VarBase::VarBase(bool has_grad, const std::shared_ptr<VariableWrapper>& var)
+VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var)
     : var_(var), grad_node_(var->GetGradNode()) {
-  if (has_grad) {
-    if (auto grad_var = var_->GetGradVar()) {
-      grad_var_ = std::make_shared<VarBase>(false, grad_var);
-    } else {
-      grad_var_ = std::make_shared<VarBase>(false, GradVarName());
-      var_->SetGradVar(grad_var_->var_);
-    }
-  }
+  if (auto grad_var = var_->GetGradVar()) {
+    grad_var_ = std::make_shared<VarBase>(grad_var);
+  }
 
   if (IsDebugEnabled()) {
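The hunk above is the heart of the "unpack constructor" fix: the `has_grad` flag is gone, and the constructor now only mirrors a gradient `VarBase` when the underlying `VariableWrapper` already carries one, instead of ever allocating a fresh gradient variable. A minimal, self-contained sketch of that pattern with hypothetical stand-in types (`Wrapper` and `Var` are illustrations, not the real Paddle classes):

```cpp
#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-ins for VariableWrapper / VarBase.
struct Wrapper {
  std::string name;
  std::shared_ptr<Wrapper> grad;  // may be null
  const std::shared_ptr<Wrapper>& GetGradVar() const { return grad; }
};

struct Var {
  // Unpack: view an existing wrapper; only mirror a grad var that
  // already exists instead of allocating one (the removed
  // has_grad==true path created and attached a new grad wrapper here).
  explicit Var(const std::shared_ptr<Wrapper>& w) : wrapper_(w) {
    if (auto g = w->GetGradVar()) {      // declaration-in-condition guard
      grad_ = std::make_shared<Var>(g);  // stops: the grad's grad is null
    }
  }
  std::shared_ptr<Wrapper> wrapper_;
  std::shared_ptr<Var> grad_;
};

int main() {
  auto w = std::make_shared<Wrapper>();
  w->name = "x";
  w->grad = std::make_shared<Wrapper>();
  w->grad->name = "x@GRAD";
  Var v(w);
  std::cout << v.grad_->wrapper_->name << "\n";  // prints x@GRAD
}
```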
@@ -417,10 +412,10 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
   auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs);
   if (grad_node && !grad_node->empty()) {
-    for (auto& op : *grad_node) {
-      op.SetId(OpBase::GenerateUniqueId());
-      op.SetPlace(place);
-      ClearNoNeedBufferInputs(&op);
+    for (auto& grad_op : *grad_node) {
+      grad_op.SetId(OpBase::GenerateUniqueId());
+      grad_op.SetPlace(place);
+      ClearNoNeedBufferInputs(&grad_op);
     }
     return grad_node;
   } else {
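Nothing changes behaviorally in this hunk, but the rename is worthwhile: the loop variable `op` shadowed the `op` argument visible in the enclosing `CreateGradOpNode`, so every unqualified `op` inside the loop silently meant the gradient op, and the forward op was unreachable. A tiny illustration of the hazard the rename removes (hypothetical `Op` and `Process` names):

```cpp
#include <cstdio>
#include <vector>

struct Op { int id; };

void Process(const Op& op, std::vector<Op>& grad_ops) {
  for (auto& op : grad_ops) {  // shadows the parameter `op`
    // Every `op` here is the loop element; the forward op cannot be
    // referenced at all until one of the two names changes.
    op.id = 42;
  }
  std::printf("%d\n", op.id);  // the parameter is untouched
}

int main() {
  Op fwd{1};
  std::vector<Op> grads{{2}, {3}};
  Process(fwd, grads);  // prints 1
}
```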
@@ -76,7 +76,8 @@ class VarBase {
   explicit VarBase(const std::string& name) : VarBase(true, name) {}
 
   // NOTE(zengjinle): be careful when you use this constructor!!!
-  explicit VarBase(bool has_grad, const std::shared_ptr<VariableWrapper>& var);
+  // Unpack VarBase from VariableWrapper.
+  explicit VarBase(const std::shared_ptr<VariableWrapper>& var);
 
   ~VarBase() {
     VLOG(10) << "Destruct VarBase: " << Name();
@@ -100,7 +101,7 @@ class VarBase {
   const std::shared_ptr<VarBase>& MutableGradVarBase() {
     if (grad_var_ == nullptr) {
       if (auto grad_var_wrapper = var_->GetGradVar()) {
-        grad_var_ = std::make_shared<VarBase>(false, grad_var_wrapper);
+        grad_var_ = std::make_shared<VarBase>(grad_var_wrapper);
       } else {
         grad_var_ = std::make_shared<VarBase>(false, GradVarName());
         var_->SetGradVar(grad_var_->var_);
@@ -719,7 +719,7 @@ PartialGradTask::PartialGradTask(
       auto grad_accumulator_iter = grad_accumulators_.find(mapped_out_grad_var);
       if (grad_accumulator_iter == grad_accumulators_.end()) {
         ready_grad_vars_.Set(mapped_out_grad_var,
-                             std::make_shared<VarBase>(false, out_grad_var));
+                             std::make_shared<VarBase>(out_grad_var));
         VLOG(10) << "Fill 1.0f or user-provided gradient as ready var "
                  << out_grad_var->Name();
       } else {
@@ -783,7 +783,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     if (!input_pair.second.IsGrad()) {
       for (auto &fwd_var : input_pair.second) {
         if (fwd_var) {
-          new_inputs.emplace_back(new VarBase(true, fwd_var));
+          new_inputs.emplace_back(new VarBase(fwd_var));
           VLOG(10) << "Unpacked forward var " << fwd_var->Name()
                    << ", grad ops: " << GradOpTypes(*new_inputs.back());
         } else {
......@@ -813,7 +813,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
for (auto &fwd_var : output_pair.second) {
// unpack forward var
if (fwd_var) {
new_outputs.emplace_back(new VarBase(true, fwd_var));
new_outputs.emplace_back(new VarBase(fwd_var));
VLOG(10) << "Unpacked forward var " << fwd_var->Name();
} else {
new_outputs.emplace_back();
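These two hunks update the unpack call sites in `RunEachOp`. Because the single-argument constructor is declared `explicit` in the header, each site must spell out `new VarBase(fwd_var)` rather than relying on an implicit conversion from the wrapper. A compact reminder of what `explicit` buys here (hypothetical `W`/`V` types, not the Paddle ones):

```cpp
#include <memory>

struct W {};
struct V {
  // explicit, like the new VarBase unpack constructor
  explicit V(const std::shared_ptr<W>& w) : w_(w) {}
  std::shared_ptr<W> w_;
};

void Take(const V&) {}

int main() {
  auto w = std::make_shared<W>();
  // Take(w);  // would not compile: no implicit W -> V conversion
  Take(V(w));  // fine: the conversion is spelled out, as in the diff
}
```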
@@ -878,18 +878,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     auto partial_grad_grads = accumulator_info->SumGradient(
         std::move(grad_var), op->id(), &is_finished);
 
-    if (is_finished) {
-      VLOG(10) << "Sum has finished for "
-               << accumulator_info->MappedGradVar()->Name() << " "
-               << accumulator_info->GradVarBase();
-      ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
-                           accumulator_info->GradVarBase());
-    }
-
-    if (partial_grad_grads.empty()) {
-      continue;
-    }
-
+    if (!partial_grad_grads.empty()) {
       auto sum_grad_var_grad =
           accumulator_info->GradVarBase()->MutableGradVarBase();
       sum_grad_var_grad->SetOverridedStopGradient(false);
@@ -913,11 +902,21 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
           assign_node->InsertGradPendingNode(std::move(grad_pending_node));
         }
       }
-      VLOG(10) << "Pending ops of assign is " << GradPendingOpTypes(*assign_node);
-      grad_accumulators_.erase(accumulator_info->MappedGradVar());
+      VLOG(10) << "Pending ops of assign is "
+               << GradPendingOpTypes(*assign_node);
       double_grad_nodes_.emplace_back(assign_node);
     }
+
+    if (is_finished) {
+      VLOG(10) << "Sum has finished for "
+               << accumulator_info->MappedGradVar()->Name() << " "
+               << accumulator_info->GradVarBase();
+      ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
+                           accumulator_info->GradVarBase());
+      grad_accumulators_.erase(accumulator_info->MappedGradVar());
+    }
   }
 
   grads_to_accumulate_.clear();
 }
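The last two hunks carry the actual behavior polish. The early `continue` on `partial_grad_grads.empty()` becomes a guarded `if (!partial_grad_grads.empty())` block, and the `is_finished` handling moves to the end of the loop body, with `grad_accumulators_.erase(...)` relocated into it, so the accumulator entry appears to be dropped only once its gradient sum is actually complete rather than whenever an assign node happened to be built. A condensed runnable model of why the inversion matters (hypothetical data, not the engine's types):

```cpp
#include <cstdio>
#include <vector>

// Model of the refactor: work that must happen on every iteration
// (the is_finished bookkeeping) is hoisted out of the early-continue
// path by inverting the emptiness test.
int main() {
  std::vector<std::vector<int>> partial_grads = {{1, 2}, {}, {3}};
  std::vector<bool> finished = {false, true, true};

  for (size_t i = 0; i < partial_grads.size(); ++i) {
    if (!partial_grads[i].empty()) {
      std::printf("iter %zu: build double-grad nodes\n", i);
    }
    // Runs even when partial_grads[i] is empty; with the old
    // `if (empty) continue;` shape this had to come first instead.
    if (finished[i]) {
      std::printf("iter %zu: publish summed grad, erase accumulator\n", i);
    }
  }
}
```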