diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index b37d8619e7e680f368ee87d3f386e6b332a3a50b..f97ab4f4e05313f52a37b3c483741214192bd67b 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -38,7 +38,20 @@ namespace imperative { void BasicEngine::Init(VarBase* var, bool retain_graph) { retain_graph_ = retain_graph; init_node_ = var->GradVarBase()->GradNode(); - var->GradVarBase()->ClearGradNode(); + PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false, + platform::errors::Unavailable( + "%s trying to backward through the same graph a second " + "time, but this graph have already been freed. Please " + "specify Tensor.backward(retain_graph=True) when " + "calling backward at the first time.", + var->Name())); + + if (!retain_graph) { + VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name() + << " because of retain_graph=False when calling backward"; + var->GradVarBase()->SetGraphIsFreed(true); + var->GradVarBase()->ClearGradNode(); + } if (init_node_ == nullptr || var->OverridedStopGradient()) { VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " @@ -47,7 +60,7 @@ void BasicEngine::Init(VarBase* var, bool retain_graph) { return; } - VLOG(3) << "start backward"; + VLOG(3) << "Init first node of backward"; PADDLE_ENFORCE_EQ( var->HasGradVar(), true, @@ -114,6 +127,10 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) { accumulator->IncreaseRefCnt(); + VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" + << var.get() << ") with reference count " + << accumulator->RefCnt(); + if (var->HasLeafHooks()) { VLOG(3) << "Grad variable wrapper (" << var->Name() << ") has leaf grad hooks."; @@ -123,10 +140,6 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) { "Gradientaccumulator.")); accumulator->SetPostHooks(var->GetLeafHooks()); } - - VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" - << var.get() << ") with reference count " - << accumulator->RefCnt(); } } } @@ -190,13 +203,14 @@ void BasicEngine::Execute() { // CheckBackWardInput CheckBackwardInputs(cur_op); - // Step 1: Run Backward + // Step 1: Run Backward OP auto& bwd_ins = cur_op.GetInsMap(); auto& bwd_outs = cur_op.GetOutsMap(); NameVarMap<VariableWrapper> tmp_outs(bwd_outs); - // 1. construct the output map 2. replace the element in the map - // A var may be coresponding to several grad var in one op + // 1. construct the temp output map, avoid to disrupt graph + // 2. replace the element in the map by temp var, because a + // var may be coresponding to several grad var in one op for (auto& pair : tmp_outs) { if (!pair.second.IsGrad()) { continue; @@ -213,15 +227,23 @@ void BasicEngine::Execute() { platform::errors::NotFound("Cannot find gradient of variable %s", var->Name())); - if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) { - no_need_run_accumulators_.emplace_back(iter->second.get()); - continue; + // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor + if (var->IsLeafGrad()) { + leaf_accumulators_.insert(iter->second.get()); + + if (iter->second->HasInnerVar()) { + var = iter->second->InnerVar(); + } } - auto tmp_var = std::make_shared<VariableWrapper>(var->Name()); - tmp_var->SetType(var->Type()); - var = tmp_var; - need_accu_var_list_.emplace_back(iter->second.get(), var); + if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) { + auto tmp_var = std::make_shared<VariableWrapper>(var->Name()); + tmp_var->SetType(var->Type()); + var = tmp_var; + need_accu_var_list_.emplace_back(iter->second.get(), var); + VLOG(10) << "create temporary var of " << var->Name() + << " for sum gradient within this graph!"; + } } } @@ -256,22 +278,32 @@ void BasicEngine::Execute() { cur_op.place()); } - // Step 2: Sum Gradient & Call Accumulator Hooks - for (auto* accumulator : no_need_run_accumulators_) { + // Step 2: Sum Gradient of This graph + for (auto& pair : need_accu_var_list_) { + pair.first->SumGrad(std::move(pair.second), cur_op.id()); + } + + // Step 3: Call Hooks && Sum Gradient with Pre-Graph && Call BackwardHooks + for (auto* accumulator : leaf_accumulators_) { + if (!accumulator->SumGradCompleted()) { + continue; + } + // 1. Call Hooks for **inner_var_** + + // 2. Sum Gradient with Previous Graph + accumulator->AccumulateGrad(); + + // 3. Call backward Hooks for **var_** if (accumulator->HasPostHooks()) { accumulator->CallBackwardPostHooks(); } } - for (auto& pair : need_accu_var_list_) { - pair.first->Add(std::move(pair.second), cur_op.id()); - } - need_accu_var_list_.clear(); - no_need_run_accumulators_.clear(); + leaf_accumulators_.clear(); - VLOG(3) << "Remove op after op " << cur_op.Type() << " runs"; if (!retain_graph_) { + VLOG(3) << "Remove op after op " << cur_op.Type() << " runs"; cur_op.ClearBackwardTrace(); } } @@ -301,7 +333,7 @@ void BasicEngine::Clear() { node_deps_.clear(); accumulators_.clear(); need_accu_var_list_.clear(); - no_need_run_accumulators_.clear(); + leaf_accumulators_.clear(); } } // namespace imperative diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h index 92e7fe7eb8cd792a9bf644ea166f8ac189163141..d7ac7594ef027cc98196d894ea551b21ff92081a 100644 --- a/paddle/fluid/imperative/basic_engine.h +++ b/paddle/fluid/imperative/basic_engine.h @@ -16,6 +16,7 @@ #include <memory> #include <unordered_map> +#include <unordered_set> #include <utility> #include <vector> #include "paddle/fluid/imperative/engine.h" @@ -49,9 +50,9 @@ class BasicEngine : public Engine { accumulators_; std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>> need_accu_var_list_; - // Accumulators that does not need to perform accumulation operations, - // the ref_cnt_=1, corresponding to need_accu_var_list_ - std::vector<GradientAccumulator*> no_need_run_accumulators_; + // leaf_accumulators_ is only for leaf tensor(hooks/accumulate grad) + std::unordered_set<GradientAccumulator*> leaf_accumulators_; + bool retain_graph_; }; diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 0d81221c43306ce35f8dc038456af0d04830e365..d650452ad9a384ed79acee302eaac7f4a1ec2b0b 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -219,6 +219,7 @@ class TracedGradOp { if (kRole == TracedVarRole::kBackward) { for (auto& var : vars) { if (var && !var->OverridedStopGradient()) { + var->SetGraphIsFreed(false); var->SetGradNode(node_); } } diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 00fd18e5e2564ca408f895baa1c868855916d8f5..66c4d1c5f55ab99e56763876f4730ea948388baf 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -35,11 +35,12 @@ namespace imperative { static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src, bool force_copy) { if (!force_copy) { + VLOG(6) << "Just Move Variable when sum gradients within this graph"; *dst = std::move(*src); return; } - VLOG(10) << "Copy occurs when accumulating gradients"; + VLOG(6) << "Copy occurs when sum gradients within this graph"; if (src->IsType<framework::LoDTensor>()) { auto& src_tensor = src->Get<framework::LoDTensor>(); if (!dst->IsType<framework::LoDTensor>()) { @@ -61,7 +62,7 @@ static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src, dst_selected_rows->set_height(src_selected_rows.height()); } else { PADDLE_THROW(platform::errors::PermissionDenied( - "Only support LoDTensor and SelectedRows for gradient accumulation")); + "Only support LoDTensor and SelectedRows for sum gradient")); } } @@ -313,9 +314,9 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge( } void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var, - VariableWrapper* var_, bool unchange_input) { + VariableWrapper* dst_var, bool unchange_input) { auto& src = var->Var(); - auto* dst = var_->MutableVar(); + auto* dst = dst_var->MutableVar(); if (dst->IsType<framework::LoDTensor>()) { if (src.IsType<framework::LoDTensor>()) { TensorAdd(src, dst); @@ -362,8 +363,57 @@ static platform::Place GetPlaceOfVar( return place; } -void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, - size_t trace_id, bool unchange_input) { +void GradientAccumulator::AccumulateGrad() { + /** + * If the gradient has been calculated by previous graph, + * it should be added to the previous graph result. + */ + if (!var_->IsLeafGrad() || !SumGradCompleted() || !HasInnerVar()) { + return; + } + PADDLE_ENFORCE_EQ(HasInnerVar(), true, + platform::errors::InvalidArgument( + "Leaf tensor should have inner var to store results of " + "this auto-grad")); + PADDLE_ENFORCE_EQ(inner_var_->Var().IsInitialized(), true, + platform::errors::InvalidArgument( + "Interior var of Leaf tensor should be initialized.")); + auto* src = inner_var_->MutableVar(); + auto* dst = var_->MutableVar(); + if (!var_->IsEmpty()) { + VLOG(6) << "Leaf Gradient Var(" << var_->Name() + << ") has been calculated by previous graph, will accumulate on " + "previous graph."; + if (dst->IsType<framework::LoDTensor>()) { + if (src->IsType<framework::LoDTensor>()) { + TensorAdd(*src, dst); + } else if (src->IsType<framework::SelectedRows>()) { + SelectedRowsAddToTensor(*src, dst); + } + } else if (dst->IsType<framework::SelectedRows>()) { + if (src->IsType<framework::LoDTensor>()) { + SelectedRowsAddToTensor(*dst, src); + *dst = std::move(*src); + } else if (src->IsType<framework::SelectedRows>()) { + auto temp = SelectedRowsMerge(*src, *dst); + *dst = std::move(*(temp->MutableVar())); + } + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only support LoDTensor and SelectedRows for gradient var")); + } + } else { + VLOG(6) << "Leaf Gradient Var(" << var_->Name() + << ") has not been initialized, not accumulate. Just move"; + *(dst) = std::move(*src); + var_->SetType(inner_var_->Type()); + var_->SetDataType(inner_var_->DataType()); + } + inner_var_.reset(); +} + +void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, + size_t trace_id, bool unchange_input) { /** * If var has grad node, it indicates that this var would be an input * of a grad op. Therefore, it should not be changed. @@ -372,53 +422,57 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, unchange_input = true; } - auto* dst_var = var_->MutableVar(); + auto* dst_var = Var(); platform::Place place = GetPlaceOfVar(var); - if (!var_->OverridedStopGradient()) { - VLOG(3) << "Sum Gradient for: " << var_->Name(); - if (cur_cnt_ == 0) { - MoveOrCopyVar(dst_var, var->MutableVar(), unchange_input); + if (!dst_var->OverridedStopGradient()) { + if (CurCnt() == 0) { + MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input); } else { - VariableWrapperAdd(var, var_, unchange_input); + VLOG(6) << "Sum Gradient for: " << dst_var->Name() + << " within this graph."; + VariableWrapperAdd(var, dst_var, unchange_input); } } else { - if (!var_->Var().IsInitialized() || - !var_->Var().Get<framework::LoDTensor>().IsInitialized()) { - VLOG(6) << "Set StopGradient Grad: " << var_->Name() << " as zero "; - + if (!dst_var->Var().IsInitialized() || + !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) { + VLOG(6) << "Set StopGradient Grad: " << dst_var->Name() << " as zero "; auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - if (!var_->Var().IsInitialized()) { - auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>(); - VLOG(6) << "Dims of " << var_->Name() << " is set as: " + if (!dst_var->Var().IsInitialized()) { + auto* tensor = + dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); + VLOG(6) << "Dims of " << dst_var->Name() << " is set as: " << var->Var().Get<framework::LoDTensor>().dims(); tensor->Resize(var->Var().Get<framework::LoDTensor>().dims()); tensor->mutable_data(place, var->DataType()); operators::math::set_constant(*dev_ctx, tensor, 0.0); } else { - auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>(); + auto* tensor = + dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); tensor->mutable_data(place, var->DataType()); operators::math::set_constant(*dev_ctx, tensor, 0.0); } } } - if (var_->Var().IsType<framework::LoDTensor>()) { - var_->SetType(framework::proto::VarType::LOD_TENSOR); - } else if (var_->Var().IsType<framework::SelectedRows>()) { - var_->SetType(framework::proto::VarType::SELECTED_ROWS); + // Type may be changed after OP run, such as VarTypeInference + // so synchronous VariableWrapper with Variable. + if (dst_var->Var().IsType<framework::LoDTensor>()) { + dst_var->SetType(framework::proto::VarType::LOD_TENSOR); + } else if (dst_var->Var().IsType<framework::SelectedRows>()) { + dst_var->SetType(framework::proto::VarType::SELECTED_ROWS); } - // Increase count & call post hooks + // Increase curent count IncreaseCurCnt(); } -void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, - size_t trace_id, bool unchange_input) { - auto* dst_var = var_->MutableVar(); +void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, + size_t trace_id, bool unchange_input) { + auto* dst_var = Var(); platform::Place place = GetPlaceOfVar(var); - if (!var_->OverridedStopGradient()) { + if (!dst_var->OverridedStopGradient()) { if (ref_cnt_ == 1) { - MoveOrCopyVar(dst_var, var->MutableVar(), + MoveOrCopyVar(dst_var->MutableVar(), var->MutableVar(), unchange_input || var->HasGradNode()); } else { if (tmp_grad_vars_.empty()) { @@ -431,6 +485,8 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, return; } + VLOG(6) << "Sum Gradient for: " << dst_var->Name() + << " within this graph."; std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(), [](const SavedVarInfo& info1, const SavedVarInfo& info2) { return info1.trace_id > info2.trace_id; @@ -444,22 +500,22 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, #ifdef PADDLE_WITH_CUDA if (paddle::platform::is_gpu_place(place)) { - bool dst_varbase_is_initialized = false; - // accumulate selected rows firstly + // sum selected rows firstly for (auto& var_info : tmp_grad_vars_) { if (!var_info.var->Var().IsType<framework::SelectedRows>()) { continue; } - if (!dst_varbase_is_initialized) { - dst_varbase_is_initialized = true; - MoveOrCopyVar(dst_var, var_info.var->MutableVar(), + if (CurCnt() == 0) { + MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(), var_info.unchange_input); } else { - VariableWrapperAdd(var_info.var, var_, var_info.unchange_input); + VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input); } var_info.var = nullptr; + // Increase count + IncreaseCurCnt(); } for (auto& var_info : tmp_grad_vars_) { @@ -470,25 +526,38 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(), true, platform::errors::PermissionDenied( "Gradient var must be LoDTensor")); - - if (!dst_varbase_is_initialized) { - dst_varbase_is_initialized = true; - MoveOrCopyVar(dst_var, var_info.var->MutableVar(), + if (CurCnt() == 0) { + MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(), var_info.unchange_input); } else { - VariableWrapperAdd(var_info.var, var_, var_info.unchange_input); + VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input); } var_info.var = nullptr; + // Increase count + IncreaseCurCnt(); } } else { #endif - MoveOrCopyVar(dst_var, tmp_grad_vars_[0].var->MutableVar(), - tmp_grad_vars_[0].unchange_input); - for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) { - VariableWrapperAdd(tmp_grad_vars_[i].var, var_, - tmp_grad_vars_[i].unchange_input); - tmp_grad_vars_[i].var = nullptr; + for (auto& var_info : tmp_grad_vars_) { + if (!var_info.var) { + continue; + } + PADDLE_ENFORCE_EQ( + var_info.var->Var().IsType<framework::LoDTensor>() || + var_info.var->Var().IsType<framework::SelectedRows>(), + true, platform::errors::PermissionDenied("The type of Gradient " + "var must be LoDTensor " + "or SelectedRows")); + if (CurCnt() == 0) { + MoveOrCopyVar(dst_var->MutableVar(), var_info.var->MutableVar(), + var_info.unchange_input); + } else { + VariableWrapperAdd(var_info.var, dst_var, var_info.unchange_input); + } + var_info.var = nullptr; + // Increase count + IncreaseCurCnt(); } #ifdef PADDLE_WITH_CUDA } @@ -496,19 +565,21 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, tmp_grad_vars_.clear(); } } else { - if (!var_->Var().IsInitialized() || - !var_->Var().Get<framework::LoDTensor>().IsInitialized()) { + if (!dst_var->Var().IsInitialized() || + !dst_var->Var().Get<framework::LoDTensor>().IsInitialized()) { VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero"; auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - if (!var_->Var().IsInitialized()) { - auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>(); - VLOG(6) << "Dims of " << var_->Name() << " is set as: " + if (!dst_var->Var().IsInitialized()) { + auto* tensor = + dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); + VLOG(6) << "Dims of " << dst_var->Name() << " is set as: " << var->Var().Get<framework::LoDTensor>().dims(); tensor->Resize(var->Var().Get<framework::LoDTensor>().dims()); tensor->mutable_data(place, var->DataType()); operators::math::set_constant(*dev_ctx, tensor, 0.0); } else { - auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>(); + auto* tensor = + dst_var->MutableVar()->GetMutable<framework::LoDTensor>(); tensor->mutable_data(place, var->DataType()); operators::math::set_constant(*dev_ctx, tensor, 0.0); } @@ -517,15 +588,10 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var, tmp_grad_vars_.clear(); } - if (var_->Var().IsType<framework::LoDTensor>()) { - var_->SetType(framework::proto::VarType::LOD_TENSOR); - } else if (var_->Var().IsType<framework::SelectedRows>()) { - var_->SetType(framework::proto::VarType::SELECTED_ROWS); - } - - // call post hooks - if (HasPostHooks()) { - CallBackwardPostHooks(); + if (dst_var->Var().IsType<framework::LoDTensor>()) { + dst_var->SetType(framework::proto::VarType::LOD_TENSOR); + } else if (dst_var->Var().IsType<framework::SelectedRows>()) { + dst_var->SetType(framework::proto::VarType::SELECTED_ROWS); } } diff --git a/paddle/fluid/imperative/gradient_accumulator.h b/paddle/fluid/imperative/gradient_accumulator.h index 2d0cc6e892159083d019037d596783ca5a9964e5..ab5ec52fb2adab5c34931d842b4aa526a493ac2d 100644 --- a/paddle/fluid/imperative/gradient_accumulator.h +++ b/paddle/fluid/imperative/gradient_accumulator.h @@ -26,17 +26,72 @@ namespace imperative { class GradientAccumulator { public: - explicit GradientAccumulator(VariableWrapper* var) : var_(var) {} + explicit GradientAccumulator(VariableWrapper* var) { + // var may be initialized, so Synchronous VariableWrapper with Variable + if (var && var->Var().IsInitialized()) { + if (var->Var().IsType<framework::LoDTensor>()) { + var->SetType(framework::proto::VarType::LOD_TENSOR); + } else if (var->Var().IsType<framework::SelectedRows>()) { + var->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only support LoDTensor and SelectedRows for gradient var")); + } + } + + // inner_var_ record the grad of this auto-grad. + // Only need to generate inner var for non-empty leaf-tensor. + if (var->IsLeafGrad() && !var->IsEmpty()) { + inner_var_ = std::make_shared<VariableWrapper>(var->Name()); + inner_var_->SetType(var->Type()); + inner_var_->SetDataType(var->DataType()); + inner_var_->InnerSetOverridedStopGradient( + var->InnerOverridedStopGradient()); + VLOG(6) << " Create inner grad var for (" << var->Name() + << ") to store result of this Graph"; + } + + // TODO(zhouwei): fix Tensor.clear_gradient() bug, remove this hard flag + var->SetIsEmpty(false); - virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id, - bool unchange_input = false) = 0; + // var_ is the final grad, processed by hooks and grad accumulation + var_ = var; + } + + // function that Sum Gradient with this Graph + virtual void SumGrad(std::shared_ptr<VariableWrapper> var, size_t trace_id, + bool unchange_input = false) = 0; virtual ~GradientAccumulator() = default; - inline void IncreaseRefCnt() { ++ref_cnt_; } + inline void IncreaseRefCnt() { + ++ref_cnt_; + VLOG(6) << var_->Name() << " Increase total count to " << ref_cnt_; + } + + inline void IncreaseCurCnt() { + ++cur_cnt_; + VLOG(6) << var_->Name() << " Increase current count to " << cur_cnt_ + << ", total count: " << ref_cnt_; + } + + inline size_t CurCnt() const { return cur_cnt_; } inline size_t RefCnt() const { return ref_cnt_; } + inline bool SumGradCompleted() const { + return cur_cnt_ == ref_cnt_ || ref_cnt_ == 1; + } + + std::shared_ptr<VariableWrapper>& InnerVar() { return inner_var_; } + + // return the var that will be calculated in this graph + VariableWrapper* Var() { + return inner_var_ != nullptr ? inner_var_.get() : var_; + } + + inline bool HasInnerVar() const { return inner_var_ != nullptr; } + /* Hook related methods */ inline bool HasPostHooks() const { return !post_hooks_.expired(); } @@ -54,6 +109,11 @@ class GradientAccumulator { post_hooks_ = hooks; } } + // void CallHooks(){} + // ** inner_var_ ** + + // function that Sum Gradient with Previous Graph + void AccumulateGrad(); // call backward post hooks, such as reduce hook void CallBackwardPostHooks() { @@ -71,8 +131,11 @@ class GradientAccumulator { protected: VariableWrapper* var_; + // NOTE: only gradient accumulater of leaf tensor should hold + // inner_var_, So not hold it by other shared pointer. + std::shared_ptr<VariableWrapper> inner_var_; size_t ref_cnt_{0}; - + size_t cur_cnt_{0}; std::weak_ptr<LeafVarHookPipeline> post_hooks_; }; @@ -80,32 +143,16 @@ class EagerGradientAccumulator : public GradientAccumulator { public: using GradientAccumulator::GradientAccumulator; - void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id, - bool unchange_input) override; - - private: - inline bool AccumulateCompleted() const { return cur_cnt_ == ref_cnt_; } - - void IncreaseCurCnt() { - ++cur_cnt_; - VLOG(3) << "IncreaseCurCnt: cur_cnt " << cur_cnt_ << ", ref_cnt " - << ref_cnt_; - // After all tmp gradient being accumulated to grad var, run hooks - if (AccumulateCompleted() && HasPostHooks()) { - CallBackwardPostHooks(); - } - } - - private: - size_t cur_cnt_{0}; + void SumGrad(std::shared_ptr<VariableWrapper> var, size_t trace_id, + bool unchange_input) override; }; class SortedGradientAccumulator : public GradientAccumulator { public: using GradientAccumulator::GradientAccumulator; - void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id, - bool unchange_input) override; + void SumGrad(std::shared_ptr<VariableWrapper> var, size_t trace_id, + bool unchange_input) override; private: struct SavedVarInfo { diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index eaf9986b200af8d6b1bd7a2da2c957415838abe0..6f490c3c2bed8fd8fb97cdb4040d424f9afaf660 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -215,6 +215,10 @@ void VarBase::ClearGradient() { #endif } } + // TODO(zhouwei): It's better to free memory of grad by grad_t->claer. + // But will have some bug on mac CPU of yolov3 model, why? + // After fix this bug, function SetIsEmpty() isn't need + grad_var_->SharedVar()->SetIsEmpty(true); } } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 9a587fd6d6c43bc9ae1ad4c3005c00b0d7f3eee8..1a974ab346ea10762291a11c1851c8e453b7738e 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -146,6 +146,8 @@ class VarBase { bool OverridedStopGradient() const { return var_->OverridedStopGradient(); } + bool IsLeaf() const { return var_->IsLeaf(); } + void InnerSetOverridedStopGradient(bool stop_gradient) { if (var_->InnerOverridedStopGradient() == -1) { var_->InnerSetOverridedStopGradient(stop_gradient); @@ -182,6 +184,10 @@ class VarBase { std::string GradVarName() { return framework::GradVarName(Name()); } + void SetGraphIsFreed(bool free) { graph_is_free_ = free; } + + const bool& GraphIsFreed() const { return graph_is_free_; } + void SetType(framework::proto::VarType::Type type) { var_->SetType(type); } framework::proto::VarType::Type Type() const { return var_->Type(); } @@ -220,6 +226,8 @@ class VarBase { */ std::shared_ptr<GradOpNode> grad_node_; + bool graph_is_free_ = false; + mutable size_t copied_counter_ = 0; static ThreadSafeNameSet name_set_; diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 5c717835e5cc2042a7a3fdd8c51aa6eeff1fc523..d8f828ede25ff23ecbb9d8395329f4991117f40d 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -367,7 +367,7 @@ class GradientAccumulationInfo { "Reference count overflows, this may be a bug")); *is_finished = (cur_ref_cnt_ == total_ref_cnt_); - accumulator_->Add(grad_var_partial, trace_id, unchange_input); + accumulator_->SumGrad(grad_var_partial, trace_id, unchange_input); if (create_graph_) { VLOG(10) << "Store partial grad grad for double grad " diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index a8de1e6b0392685579a31138be2412871ab1ba5d..782f6dad58d46eb539186e331108c74ddb728ac9 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -7,7 +7,7 @@ else() endif(WIN32) -cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator) +cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows selected_rows_functor gradient_accumulator math_function) cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy) cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) diff --git a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc index 49bc24edbad60fe04f6691ed95c55ea907d8d739..c394ce07df3c3938087e9b8afe4d31bceec53a38 100644 --- a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc +++ b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/gradient_accumulator.h" #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/math_function.h" namespace imperative = paddle::imperative; namespace platform = paddle::platform; @@ -263,6 +264,9 @@ static void TestGradientAccumulatorTestUnchangeInput( for (auto use_tensor1 : use_tensors) { for (auto use_tensor2 : use_tensors) { + /** g_accum1 && g_accum2: has not been initialized + * test accumulate on this graph + */ auto g_var1 = std::make_shared<VariableWrapper>("g_var1"); g_var1->SetOverridedStopGradient(false); auto g_accum1 = CreateAccumulator(g_var1, sort_gradient); @@ -278,8 +282,14 @@ static void TestGradientAccumulatorTestUnchangeInput( auto var1 = create_var(use_tensor1); auto var_wrapper1_1 = std::make_shared<VariableWrapper>("tmp1_1"); auto var_wrapper2_1 = std::make_shared<VariableWrapper>("tmp2_1"); + + ASSERT_EQ(var_wrapper1_1->IsEmpty(), true); CopyVar(var1, var_wrapper1_1->MutableVar()); + ASSERT_EQ(var_wrapper1_1->IsEmpty(), false); + + ASSERT_EQ(var_wrapper2_1->IsEmpty(), true); CopyVar(var1, var_wrapper2_1->MutableVar()); + ASSERT_EQ(var_wrapper2_1->IsEmpty(), false); auto var2 = create_var(use_tensor2); auto var_wrapper1_2 = std::make_shared<VariableWrapper>("tmp1_2"); @@ -287,15 +297,59 @@ static void TestGradientAccumulatorTestUnchangeInput( CopyVar(var2, var_wrapper1_2->MutableVar()); CopyVar(var2, var_wrapper2_2->MutableVar()); - g_accum1->Add(var_wrapper1_1, 0, false); - g_accum1->Add(var_wrapper1_2, 1, false); - - g_accum2->Add(var_wrapper2_1, 0, true); - g_accum2->Add(var_wrapper2_2, 1, true); + // g_accum1: inner_var_ = var1 + var2 + g_accum1->SumGrad(var_wrapper1_1, 0, false); + g_accum1->SumGrad(var_wrapper1_2, 1, false); + ASSERT_EQ(g_accum1->CurCnt(), g_accum1->RefCnt()); + ASSERT_TRUE(g_accum1->SumGradCompleted()); + // g_accum1: inner_var_ -> var_ + g_accum1->AccumulateGrad(); + + // g_accum2: inner_var_ = var1 + var2 + g_accum2->SumGrad(var_wrapper2_1, 0, true); + g_accum2->SumGrad(var_wrapper2_2, 1, true); + ASSERT_EQ(g_accum2->CurCnt(), g_accum2->RefCnt()); + ASSERT_TRUE(g_accum2->SumGradCompleted()); + // g_accum2: inner_var_ -> var_ + g_accum2->AccumulateGrad(); ASSERT_TRUE(IsEqualVar(var_wrapper2_1->Var(), var1)); ASSERT_TRUE(IsEqualVar(var_wrapper2_2->Var(), var2)); ASSERT_TRUE(IsEqualVar(g_var1->Var(), g_var2->Var())); + + /** g_accum3 && g_accum4: has been initialized + * test accumulate on previous graph + */ + auto var3 = create_var(use_tensor1); + auto var_wrapper3_3 = std::make_shared<VariableWrapper>("tmp1_3"); + auto var_wrapper4_3 = std::make_shared<VariableWrapper>("tmp2_3"); + var_wrapper3_3->SetOverridedStopGradient(false); + var_wrapper4_3->SetOverridedStopGradient(false); + CopyVar(var3, var_wrapper3_3->MutableVar()); + CopyVar(var3, var_wrapper4_3->MutableVar()); + + auto g_accum3 = CreateAccumulator(var_wrapper3_3, sort_gradient); + g_accum3->IncreaseRefCnt(); + auto g_accum4 = CreateAccumulator(var_wrapper4_3, sort_gradient); + g_accum4->IncreaseRefCnt(); + + auto var4 = create_var(use_tensor2); + auto var_wrapper3_4 = std::make_shared<VariableWrapper>("tmp1_4"); + auto var_wrapper4_4 = std::make_shared<VariableWrapper>("tmp2_4"); + CopyVar(var4, var_wrapper3_4->MutableVar()); + CopyVar(var4, var_wrapper4_4->MutableVar()); + + g_accum3->SumGrad(var_wrapper3_4, 0, false); + ASSERT_TRUE(g_accum3->SumGradCompleted()); + // g_accum4: var_(var_wrapper3_3) + inner_var_ -> var_ + g_accum3->AccumulateGrad(); + + g_accum4->SumGrad(var_wrapper4_4, 0, false); + ASSERT_TRUE(g_accum4->SumGradCompleted()); + // g_accum4: var_(var_wrapper4_3) + inner_var_ -> var_ + g_accum4->AccumulateGrad(); + + ASSERT_TRUE(IsEqualVar(var_wrapper3_3->Var(), var_wrapper4_3->Var())); } } } diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index df972035ae377af3dd64f10d6181ebba749df710..fec12f2da13c1ef724d3b94dbab86e4bae1e6cba 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -68,10 +68,50 @@ class VariableWrapper { } } + bool IsLeaf() const { + if (OverridedStopGradient()) { + return true; + } + if (HasGradVar() && !GetGradVar()->HasGradNode()) { + return true; + } + return false; + } + + bool IsLeafGrad() const { + if (!HasGradVar() && !HasGradNode() && !OverridedStopGradient()) { + return true; + } + return false; + } + void SetPersistable(bool persistable) { persistable_ = persistable; } bool Persistable() const { return persistable_; } + bool IsEmpty() const { + bool is_empty = true; + if (var_.IsInitialized()) { + const framework::Tensor* tensor = nullptr; + if (var_.IsType<framework::LoDTensor>()) { + tensor = &(var_.Get<framework::LoDTensor>()); + } else if (var_.IsType<framework::SelectedRows>()) { + tensor = &(var_.Get<framework::SelectedRows>().value()); + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only support LoDTensor and SelectedRows for gradient var")); + } + if (tensor && tensor->IsInitialized()) { + is_empty = false; + } + } + return is_empty || is_empty_; + } + + // TODO(zhouwei): fix Tensor.clear_gradient() bug, function SetIsEmpty() isn't + // need + void SetIsEmpty(bool is_empty) { is_empty_ = is_empty; } + const std::string& Name() const { return name_; } void SetName(const std::string& name) { name_ = name; } @@ -96,6 +136,8 @@ class VariableWrapper { bool HasGradNode() const { return !grad_node_.expired(); } + bool HasGradVar() const { return !grad_var_.expired(); } + framework::proto::VarType::Type DataType() const { const framework::Tensor* tensor = nullptr; if (var_.IsInitialized()) { @@ -265,6 +307,10 @@ class VariableWrapper { std::weak_ptr<VariableWrapper> grad_var_; std::weak_ptr<GradOpNode> grad_node_; + // TODO(zhouwei): fix bug of Tensor.clear_gradient(), function SetIsEmpty() + // isn't need + bool is_empty_{false}; + // NOTE: only grad var can hold hooks now // only interior var can hold interior hooks std::shared_ptr<InteriorVarHookPipeline> interior_hooks_; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index d675782a483d1465881c8579e647a80586322fcc..3510c9d152c83415922297b3def94ec25151c54e 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -670,7 +670,6 @@ void BindImperative(py::module *m_ptr) { return TensorToPyArray(tensor, true); }, R"DOC( - Returns a numpy array shows the value of current Tensor. Returns: @@ -689,7 +688,6 @@ void BindImperative(py::module *m_ptr) { data = paddle.to_tensor(data) x = linear(data) print(x.numpy()) - )DOC") .def("detach", [](const imperative::VarBase @@ -1080,6 +1078,35 @@ void BindImperative(py::module *m_ptr) { return std::vector<int>(); } }) + .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, + R"DOC( + Whether a Tensor is leaf Tensor. + + For the Tensor whose stop_gradient is ``True`` , it will be leaf Tensor. + + For the Tensor whose stop_gradient is ``False`` , it will be leaf Tensor too if it is created by user. + + Returns: + bool: Whether a Tensor is leaf Tensor. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor(1.) + print(x.is_leaf) # True + + x = paddle.to_tensor(1., stop_gradient=True) + y = x + 1 + print(x.is_leaf) # True + print(y.is_leaf) # True + + x = paddle.to_tensor(1., stop_gradient=False) + y = x + 1 + print(x.is_leaf) # True + print(y.is_leaf) # False + )DOC") .def_property_readonly( "place", [](imperative::VarBase &self) { return self.Place(); }, py::return_value_policy::copy) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index ab5135645a01b70ff509e8175f95ba42f59a0745..6a59e33285c4a8041a74d3d33a5b6495875bcfc9 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -133,11 +133,12 @@ def monkey_patch_varbase(): @framework.dygraph_only def backward(self, retain_graph=False): """ - **Notes**: - **This API is ONLY available in Dygraph mode** - Run backward of current Graph which starts from current Tensor. + The new gradient will accumulat on previous gradient. + + You can clear gradient by ``Tensor.clear_grad()`` . + Args: retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter @@ -150,21 +151,20 @@ def monkey_patch_varbase(): Examples: .. code-block:: python - import numpy as np - import paddle - paddle.disable_static() - - x = np.ones([2, 2], np.float32) - inputs = [] - for _ in range(10): - tmp = paddle.to_tensor(x) - # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since - # there is no one need gradient on it. - tmp.stop_gradient=False - inputs.append(tmp) - ret = paddle.add_n(inputs) - loss = paddle.sum(ret) - loss.backward() + x = paddle.to_tensor(5., stop_gradient=False) + for i in range(5): + y = paddle.pow(x, 4.0) + y.backward() + print("{}: {}".format(i, x.grad)) + # 0: [500.] + # 1: [1000.] + # 2: [1500.] + # 3: [2000.] + # 4: [2500.] + + x.clear_grad() + print("{}".format(x.grad)) + # 0. """ if framework.in_dygraph_mode(): @@ -181,31 +181,21 @@ def monkey_patch_varbase(): @framework.dygraph_only def gradient(self): """ - **Notes**: - **This API is ONLY available in Dygraph mode** - - Get the Gradient of Current Variable + Get the Gradient of Current Tensor. Returns: - ndarray: Numpy value of the gradient of current Variable + ndarray: Numpy value of the gradient of current Tensor Examples: .. code-block:: python - import paddle.fluid as fluid - import numpy as np + import paddle - x = np.ones([2, 2], np.float32) - with fluid.dygraph.guard(): - inputs2 = [] - for _ in range(10): - tmp = fluid.dygraph.base.to_variable(x) - tmp.stop_gradient=False - inputs2.append(tmp) - ret2 = fluid.layers.sums(inputs2) - loss2 = fluid.layers.reduce_sum(ret2) - loss2.backward() - print(loss2.gradient()) + x = paddle.to_tensor(5., stop_gradient=False) + y = paddle.pow(x, 4.0) + y.backward() + print("grad of x: {}".format(x.grad)) + # [500.] """ if self._grad_ivar() is None: @@ -226,6 +216,12 @@ def monkey_patch_varbase(): return self.gradient() + def clear_grad(self): + """ + The alias of clear_gradient(). + """ + self.clear_gradient() + @property def inplace_version(self): """ @@ -284,10 +280,10 @@ def monkey_patch_varbase(): for method_name, method in ( ("__bool__", __bool__), ("__nonzero__", __nonzero__), ("_to_static_var", _to_static_var), ("set_value", set_value), - ("block", block), ("backward", backward), ("grad", grad), - ("inplace_version", inplace_version), ("gradient", gradient), - ("__str__", __str__), ("__repr__", __str__), ("__module__", "paddle"), - ("__name__", "Tensor")): + ("block", block), ("backward", backward), ("clear_grad", clear_grad), + ("inplace_version", inplace_version), ("grad", grad), + ("gradient", gradient), ("__str__", __str__), ("__repr__", __str__), + ("__module__", "paddle"), ("__name__", "Tensor")): setattr(core.VarBase, method_name, method) # patch math methods for varbase diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index f3c4984e29e7839145e4074c0214cd717cc634af..d4468f0193b7de1aa706b07cbfcc49cfc5b4478a 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -874,6 +874,8 @@ class Optimizer(object): def clear_gradients(self): """ Clear the gradients of all optimized parameters for model. + + If not, new gradient will accumulat on previous gradient. Returns: None diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 514154f1dd7014d7b757ea8dd2513e744fe2bfb6..d2f143d7ad44035ff2241bd6d830b88a6a3b548a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -478,6 +478,114 @@ class TestImperative(unittest.TestCase): self.assertEqual(mlp._linear2, sublayers[1]) self.assertEqual(len(sublayers), 2) + def test_gradient_accumulation(self): + def test_single_api(sort_sum_gradient): + fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) + x = paddle.to_tensor(5., stop_gradient=False) + for i in range(10): + y = paddle.pow(x, 4.0) + y.backward() + print(x.grad) + self.assertEqual(x.grad, (i + 1) * 500) + x.clear_gradient() + self.assertEqual(x.grad, 0.) + for i in range(5): + y = paddle.pow(x, 4.0) + y.backward() + print(x.grad) + self.assertEqual(x.grad, (i + 1) * 500) + + def test_simple_net(sort_sum_gradient): + fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) + x = paddle.to_tensor(5., stop_gradient=False) + y = paddle.to_tensor(2., stop_gradient=False) + z = paddle.to_tensor(3., stop_gradient=False) + + def fun(x, y, z): + loss1 = x * x * y + loss2 = x * z + dx = paddle.grad([loss1], x, create_graph=True)[0] + # loss = x*x*y + x*z + 2*x*y + loss = loss1 + loss2 + dx + return loss + + loss = fun(x, y, z) + loss.backward(retain_graph=True) + # x.grad = 2*x*y + z + 2*y = 27 + self.assertTrue(np.array_equal(x.grad, [27])) + + loss.backward(retain_graph=True) + self.assertTrue(np.array_equal(x.grad, [54])) + + loss.backward() + self.assertTrue(np.array_equal(x.grad, [81])) + + with self.assertRaises(RuntimeError): + loss.backward() + + loss1 = x * x * y + loss2 = x * z + dx = paddle.grad([loss1], x, create_graph=True)[0] + loss = loss1 + loss2 + dx + loss.backward() + self.assertTrue(np.array_equal(dx.grad, [1])) + self.assertTrue(np.array_equal(x.grad, [108])) + + def test_mlp(sort_sum_gradient): + fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) + input_size = 5 + paddle.seed(1) + mlp1 = MLP(input_size=input_size) + # generate the gradient of each step + mlp2 = MLP(input_size=input_size) + + expected_weight1_grad = np.zeros(mlp2._linear1.weight.shape) + expected_bias1_grad = np.zeros(mlp2._linear1.bias.shape) + expected_weight2_grad = np.zeros(mlp2._linear2.weight.shape) + expected_bias2_grad = np.zeros(mlp2._linear2.bias.shape) + + for batch_id in range(24): + x = paddle.uniform([10, input_size]) + detach_x = x.detach() + clear_loss = mlp2(detach_x) + clear_loss.backward() + expected_weight1_grad = expected_weight1_grad + mlp2._linear1.weight.grad + expected_bias1_grad = expected_bias1_grad + mlp2._linear1.bias.grad + expected_weight2_grad = expected_weight2_grad + mlp2._linear2.weight.grad + expected_bias2_grad = expected_bias2_grad + mlp2._linear2.bias.grad + + loss = mlp1(x) + loss.backward() + + self.assertTrue(np.array_equal(loss.grad, [1])) + self.assertTrue( + np.allclose(mlp1._linear1.weight.grad, + expected_weight1_grad)) + self.assertTrue( + np.allclose(mlp1._linear1.bias.grad, expected_bias1_grad)) + self.assertTrue( + np.allclose(mlp1._linear2.weight.grad, + expected_weight2_grad)) + self.assertTrue( + np.allclose(mlp1._linear2.bias.grad, expected_bias2_grad)) + + mlp2.clear_gradients() + self.assertTrue(np.array_equal(clear_loss.grad, [1])) + if ((batch_id + 1) % 8) == 0: + mlp1.clear_gradients() + expected_weight1_grad = np.zeros(mlp2._linear1.weight.shape) + expected_bias1_grad = np.zeros(mlp2._linear1.bias.shape) + expected_weight2_grad = np.zeros(mlp2._linear2.weight.shape) + expected_bias2_grad = np.zeros(mlp2._linear2.bias.shape) + + with fluid.dygraph.guard(): + test_single_api(False) + test_single_api(True) + test_simple_net(False) + test_simple_net(True) + test_mlp(False) + test_mlp(True) + def test_dygraph_vs_static(self): np_inp1 = np.random.rand(4, 3, 3) np_inp2 = np.random.rand(4, 3, 3) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index 8f3116f65351447457ca13a3dfa978a481ff56f3..e41960f6b47c29dccdb0709ce37f5f26f90e7fbd 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -214,7 +214,7 @@ class TestDygraphDoubleGrad(TestCase): self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + loss.backward(retain_graph=True) x_grad_actual = x.gradient() x_grad_expected = (2.0 / float(numel) * @@ -222,6 +222,16 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2 / float(numel))).astype('float32') self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + for i in range(5): + loss.backward(retain_graph=True) + x_grad_actual = x.gradient() + x_grad_expected = (i + 2) * (2.0 / float(numel) * ( + x_np + dx_expected * + (x_np > 0) * 2 / float(numel))).astype('float32') + print(x_grad_actual) + print(x_grad_expected) + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + @dygraph_guard def test_example_with_gradient_accumulation_and_no_grad_vars(self): x = random_var(self.shape) diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 6ee7940e174ae4208daee3e93895f072e364f2df..40a1c8def5d64009c0b63b90cefa04e20174d254 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -457,6 +457,7 @@ class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase): loss = paddle.mean(out) loss.backward() momentum.minimize(loss) + linear.clear_gradients() def __test_vs(self, place=fluid.CPUPlace()): paddle.disable_static(place=place) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 1f101a17da986f8082bedd751b7cdb1f23685368..86ba5a96b8d39b05a593f26b4fbad726a77f24cf 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -198,6 +198,32 @@ class TestVarBase(unittest.TestCase): var = fluid.dygraph.to_variable(t) self.assertTrue(np.array_equal(t, var.numpy())) + def test_leaf_tensor(self): + with fluid.dygraph.guard(): + x = paddle.to_tensor(np.random.uniform(-1, 1, size=[10, 10])) + self.assertTrue(x.is_leaf) + y = x + 1 + self.assertTrue(y.is_leaf) + + x = paddle.to_tensor( + np.random.uniform( + -1, 1, size=[10, 10]), stop_gradient=False) + self.assertTrue(x.is_leaf) + y = x + 1 + self.assertFalse(y.is_leaf) + + linear = paddle.nn.Linear(10, 10) + input = paddle.to_tensor( + np.random.uniform( + -1, 1, size=[10, 10]).astype('float32'), + stop_gradient=False) + self.assertTrue(input.is_leaf) + + out = linear(input) + self.assertTrue(linear.weight.is_leaf) + self.assertTrue(linear.bias.is_leaf) + self.assertFalse(out.is_leaf) + def test_detach(self): with fluid.dygraph.guard(): x = paddle.to_tensor(1.0, dtype="float64", stop_gradient=False) diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 295821a93cd3f8587b95cf4e0b4d6bb157d5cdef..1cfc0b66e7b671429b4dbbb17a3e586931a6a6e7 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -793,6 +793,8 @@ class Optimizer(object): def clear_grad(self): """ Clear the gradients of all optimized parameters for model. + + If not, new gradient will accumulat on previous gradient. Returns: None