// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/imperative/basic_engine.h" #include #include #include #include #include #include #include #include #include #include "paddle/fluid/imperative/gradient_accumulator.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/profiler.h" DECLARE_bool(sort_sum_gradient); namespace paddle { namespace imperative { void BasicEngine::Init(VarBase* var, bool retain_graph) { retain_graph_ = retain_graph; init_node_ = var->GradVarBase()->GradNode(); PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false, platform::errors::Unavailable( "%s trying to backward through the same graph a second " "time, but this graph have already been freed. Please " "specify Tensor.backward(retain_graph=True) when " "calling backward at the first time.", var->Name())); if (!retain_graph) { VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name() << " because of retain_graph=False when calling backward"; var->GradVarBase()->SetGraphIsFreed(true); var->GradVarBase()->ClearGradNode(); } if (init_node_ == nullptr || var->OverridedStopGradient()) { VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " "stop_gradient=True: " << var->Name(); return; } VLOG(3) << "Init first node of backward"; PADDLE_ENFORCE_EQ( var->HasGradVar(), true, platform::errors::NotFound("Grad variable not exist for variable %s", var->Name())); auto& fwd_var = var->Var().Get(); auto* grad_var = var->GradVarBase()->MutableVar()->GetMutable(); VLOG(6) << "init loss grad:" << var->GradVarBase()->Name() << " as stop_gradient false"; var->GradVarBase()->InnerSetOverridedStopGradient(false); auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place()); grad_var->Resize(fwd_var.dims()); grad_var->mutable_data(fwd_var.place(), fwd_var.type()); operators::math::set_constant(*dev_ctx, grad_var, 1.0); } void BasicEngine::CheckBackwardInputs(const OpBase& op) { for (auto& pair : op.GetInsMap()) { if (!pair.second.IsGrad()) { continue; } for (auto& var : pair.second) { if (!var) { continue; } auto* inner_var = var->MutableVar(); framework::Tensor* tensor = nullptr; if (!inner_var->IsInitialized() || inner_var->IsType()) { tensor = inner_var->GetMutable(); } if (tensor && !tensor->IsInitialized()) { auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place()); // NOTE(zhiqiu): since grad variable is ungenerated, so the dtype is not // correct. var->DataType() returns the default dtype, which is float32. // Here, we use the type of the corresponding forward datatype. tensor->mutable_data(op.place(), var->ForwardDataType()); VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero with dtype " << framework::DataTypeToString(var->ForwardDataType()); operators::math::set_constant(*dev_ctx, tensor, 0.0); } } } } void BasicEngine::PrepareGradAccumulators( const OpBase& op, const std::vector>& grad_pending_nodes) { for (const auto& pair : op.GetOutsMap()) { if (!pair.second.IsGrad()) { continue; } for (const auto& var : pair.second) { if (!var) continue; if (!var->HasGradNode()) { auto& accumulator = accumulators_[var.get()]; if (!accumulator) { if (FLAGS_sort_sum_gradient) { accumulator.reset(new SortedGradientAccumulator(var.get())); } else { accumulator.reset(new EagerGradientAccumulator(var.get())); } } accumulator->IncreaseRefCnt(); VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" << var.get() << ") that don't have grad node with reference count " << accumulator->RefCnt(); if (var->HasLeafHooks()) { VLOG(3) << "Grad variable wrapper (" << var->Name() << ") has leaf grad hooks."; PADDLE_ENFORCE_NE( var->HasGradNode(), true, platform::errors::PermissionDenied( "Only leaf Tensor's gradient can append hook to " "Gradientaccumulator.")); accumulator->SetPostHooks(var->GetLeafHooks()); } } else { // Because Inplace op overwrites the grad_node of the input grad_var. So // only the information of grad_pending_node can be used to find the // grad_node of grad_var. bool find_grad_node_of_var = false; for (auto& grad_pending_node : grad_pending_nodes) { PADDLE_ENFORCE_NOT_NULL( grad_pending_node, platform::errors::NotFound("Grad pending node is nullptr.")); for (auto& grad_pending_op : *grad_pending_node) { VLOG(6) << "Determine whether var (" << var->Name() << ") is the input var of grad_pending_op (" << grad_pending_op.Type() << ")."; grad_pending_op.EnforceHasInOut(); for (const auto& grad_pending_op_ins_pair : grad_pending_op.GetInsMap()) { if (!grad_pending_op_ins_pair.second.IsGrad()) { continue; } for (const auto& pending_in_var : grad_pending_op_ins_pair.second) { if (var == pending_in_var) { VLOG(6) << "Var (" << var->Name() << ") is the input var of grad_pending_op (" << grad_pending_op.Type() << ")."; find_grad_node_of_var = true; break; } } if (find_grad_node_of_var) { break; } } } if (find_grad_node_of_var) { auto& accumulator = accumulators_with_grad_node_[grad_pending_node][var.get()]; if (!accumulator) { if (FLAGS_sort_sum_gradient) { accumulator.reset(new SortedGradientAccumulator(var.get())); } else { accumulator.reset(new EagerGradientAccumulator(var.get())); } } accumulator->IncreaseRefCnt(); VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" << var.get() << ") that has grad node with reference count " << accumulator->RefCnt(); break; } } PADDLE_ENFORCE_EQ( find_grad_node_of_var, true, platform::errors::NotFound( "No grad node corresponding to grad Tensor (%s) was found.", var->Name())); } } } } void BasicEngine::PrepareDeps() { PADDLE_ENFORCE_EQ( node_deps_.empty(), true, platform::errors::AlreadyExists("Op deps must be initialized.")); PADDLE_ENFORCE_EQ( accumulators_.empty(), true, platform::errors::AlreadyExists("Accumulators must be initialized.")); PADDLE_ENFORCE_EQ( accumulators_with_grad_node_.empty(), true, platform::errors::AlreadyExists("Accumulators must be initialized.")); std::queue q; std::unordered_set visited; q.push(init_node_.get()); visited.insert(init_node_.get()); while (!q.empty()) { auto* cur_node = q.front(); q.pop(); const auto& grad_pending_nodes = cur_node->GradPendingNodes(); for (auto& cur_op : *cur_node) { cur_op.EnforceHasInOut(); PrepareGradAccumulators(cur_op, grad_pending_nodes); } for (auto& grad_pending_node : grad_pending_nodes) { PADDLE_ENFORCE_NOT_NULL( grad_pending_node, platform::errors::NotFound("Grad pending node is nullptr.")); ++node_deps_[grad_pending_node.get()]; if (visited.count(grad_pending_node.get()) == 0) { visited.insert(grad_pending_node.get()); q.push(grad_pending_node.get()); } } } } void BasicEngine::Execute() { if (init_node_ == nullptr) { return; } PrepareDeps(); // Start execute Computation graph std::queue> q; q.push(std::move(init_node_)); size_t op_num = 0; while (!q.empty()) { auto shared_cur_node = std::move(q.front()); q.pop(); auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap(); for (auto& cur_op : *shared_cur_node) { ++op_num; // CheckBackWardInput CheckBackwardInputs(cur_op); // Step 1: Run Backward OP auto& bwd_ins = cur_op.GetInsMap(); auto& bwd_outs = cur_op.GetOutsMap(); NameVarMap tmp_outs(bwd_outs); // 1. construct the temp output map, avoid to disrupt graph // 2. replace the element in the map by temp var, because a // var may be coresponding to several grad var in one op for (auto& pair : tmp_outs) { if (!pair.second.IsGrad()) { continue; } for (auto& var : pair.second) { if (!var) { continue; } std::unordered_map>::iterator iter; if (!var->HasGradNode()) { VLOG(10) << "Find gradient of var (" << var->Name() << ") with no grad_node."; iter = accumulators_.find(var.get()); PADDLE_ENFORCE_EQ( iter != accumulators_.end(), true, platform::errors::NotFound( "Cannot find gradient of variable %s", var->Name())); } else { bool flag_find_grad = false; VLOG(10) << "Find gradient of var (" << var->Name() << ") with grad_node."; for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) { const auto& iter_grad_node = accumulators_with_grad_node_.find(grad_pending_node); if (iter_grad_node != accumulators_with_grad_node_.end()) { iter = iter_grad_node->second.find(var.get()); if (iter != iter_grad_node->second.end()) { flag_find_grad = true; break; } } } PADDLE_ENFORCE_EQ( flag_find_grad, true, platform::errors::NotFound( "Cannot find gradient of variable %s", var->Name())); } // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor if (var->IsLeafGrad()) { leaf_accumulators_.insert(iter->second.get()); if (iter->second->HasInnerVar()) { var = iter->second->InnerVar(); } } if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) { auto tmp_var = std::make_shared(var->Name()); tmp_var->SetType(var->Type()); tmp_var->SetForwardDataType(var->ForwardDataType()); var = tmp_var; need_accu_var_list_.emplace_back(iter->second.get(), var); VLOG(10) << "create temporary var of " << var->Name() << " for sum gradient within this graph!"; } else if (!inplace_grad_name_map.empty() && inplace_grad_name_map.count(pair.first)) { // When calculate Inplace grad op, create a new output var. // If a tmp var has been created, there is no need to create it // again. for (auto& in_var : bwd_ins.at(inplace_grad_name_map.at(pair.first))) { if (in_var == var) { auto tmp_var = std::make_shared(var->Name()); tmp_var->SetType(var->Type()); tmp_var->SetForwardDataType(var->ForwardDataType()); inplace_output_grad_var_list_.emplace_back(var, tmp_var); var = tmp_var; VLOG(10) << "Inplace grad op does not use the Inplace " "strategy, a temporary output var (" << var->Name() << ") will be created."; break; } } } } } VLOG(4) << "Check whether there is any inplace operation affecting " "gradient calculation."; for (auto& pair : bwd_ins) { for (auto& var_wrapper : pair.second) { auto wrapper_version_snapshot = var_wrapper->InplaceVersionSnapshot(); auto tensor_version = var_wrapper->MutableVar()->CurrentInplaceVersion(); PADDLE_ENFORCE_EQ( tensor_version, wrapper_version_snapshot, platform::errors::PermissionDenied( "Tensor '%s' used in gradient computation in grad op '%s' " "has been " "modified by an inplace operation. " "Its version is %s but the expected version is %s. " "Please fix your code to void calling an inplace operator " "after using the Tensor which will used in gradient " "computation.", var_wrapper->Name(), cur_op.Type(), tensor_version, wrapper_version_snapshot)); VLOG(6) << " The version of Tensor '" << var_wrapper->Name() << "' is [ " << wrapper_version_snapshot << " ]"; } } { VLOG(3) << "Start to execute grad op " << cur_op.Type(); OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), cur_op.place()); } for (auto& pair : inplace_output_grad_var_list_) { *pair.first = std::move(*pair.second); } // Step 2: Sum Gradient of This graph for (auto& pair : need_accu_var_list_) { pair.first->SumGrad(std::move(pair.second), cur_op.id()); } // Step 3: Call Hooks && Sum Gradient with Pre-Graph && Call BackwardHooks for (auto* accumulator : leaf_accumulators_) { if (!accumulator->SumGradCompleted()) { continue; } // 1. Call Hooks for **inner_var_** // 2. Sum Gradient with Previous Graph accumulator->AccumulateGrad(); // 3. Call backward Hooks for **var_** if (accumulator->HasPostHooks()) { accumulator->CallBackwardPostHooks(); } } need_accu_var_list_.clear(); inplace_output_grad_var_list_.clear(); leaf_accumulators_.clear(); if (!retain_graph_) { VLOG(3) << "Remove op after op " << cur_op.Type() << " runs"; cur_op.ClearBackwardTrace(); } } // Step 3: Collect ready ops for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) { PADDLE_ENFORCE_NOT_NULL( grad_pending_node, platform::errors::NotFound("Grad pending node is nullptr.")); auto iter = node_deps_.find(grad_pending_node.get()); if (iter == node_deps_.end()) { continue; } if (--(iter->second) == 0) { q.push(grad_pending_node); } } } Clear(); VLOG(1) << "Backward op number: " << op_num; } void BasicEngine::Clear() { init_node_.reset(); node_deps_.clear(); accumulators_.clear(); accumulators_with_grad_node_.clear(); need_accu_var_list_.clear(); leaf_accumulators_.clear(); } } // namespace imperative } // namespace paddle