// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/tracer.h"
#include <set>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace imperative {

// Orders OpBase pointers by descending trace id, i.e. the most recently
// traced op comes first.
struct OpBaseCmp {
  bool operator()(OpBase* first, OpBase* second) {
    return first->id() > second->id();
  }
};

// Creates the grad OpBases for a forward op, or returns an empty vector if
// the op has no dygraph grad op maker registered.
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
    const OpBase* fw_op_base, const NameVarBaseMap& in,
    const NameVarBaseMap& out) {
  if (fw_op_base->Info().dygraph_grad_op_maker_) {
    return fw_op_base->Info().dygraph_grad_op_maker_(fw_op_base, in, out);
  } else {
    return {};
  }
}

static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
  for (const auto& name_pair : outs) {
    for (const auto& vb : name_pair.second) {
      VLOG(6) << "Set output: " << vb->Name() << "'s OverridedStopGradient as "
              << generate_grad;
      vb->InnerSetOverridedStopGradient(generate_grad);
    }
  }
}

void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                     const NameVarBaseMap& outs, framework::AttributeMap attrs,
                     const platform::Place& place, bool trace_backward) {
  platform::RecordEvent event(type);
  VLOG(1) << "Trace Op: " << type;
  size_t op_id = GenerateUniqueId();
  auto op = OpBase::Create(op_id, type, ins, outs, std::move(attrs), place);
  op->Run(ins, outs);

  if (enable_program_desc_tracing_) {
    VLOG(5) << "Trace op " << type << " into ProgramDesc";
    program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
  }

  if (ComputeRequiredGrad(ins, outs, trace_backward)) {
    TraceBackward(op, ins, outs);
  } else {
    VLOG(3) << "No Grad to track for Op: " << type;
  }
}

bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
                                 const NameVarBaseMap& outs,
                                 bool trace_backward) {
  if (!trace_backward) return false;

  // Grad is required if any input does not stop gradient; in that case all
  // outputs are marked as generating grad as well.
  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      if (!var_base->OverridedStopGradient()) {
        VLOG(6) << "Find out input: " << var_base->Name()
                << "'s GeneratedGrad is True";
        PassStopGradient(outs, var_base->OverridedStopGradient());
        return true;
      }
    }
  }
  return false;
}
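
// Illustrative example of the stop-gradient propagation implemented by
// ComputeRequiredGrad/PassStopGradient above. This is a sketch in comments,
// not code from this file; the variable names x and y are hypothetical, the
// semantics follow the code:
//
//   // ins  = {"X": {x}},  x->OverridedStopGradient() == false
//   // outs = {"Out": {y}}
//   // ComputeRequiredGrad(ins, outs, /*trace_backward=*/true) finds x,
//   // calls PassStopGradient(outs, false), so y's overrided stop-gradient
//   // flag becomes false (y will generate grad), and returns true, which
//   // makes TraceOp call TraceBackward below.
//   // If every input had OverridedStopGradient() == true, no output flags
//   // would be touched and the function would return false.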

void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
                           const NameVarBaseMap& ins,
                           const NameVarBaseMap& outs) {
  // grad_to_var is a map of framework::GradVarName(in_var_name/out_var_name)
  // -> in_var_name/out_var_name
  std::unordered_map<std::string, std::string> grad_to_var;

  // Create the grad OpBases from the forward op
  std::vector<std::unique_ptr<OpBase>> grad_op_bases_ =
      CreateGradOpBases(fwd_op.get(), ins, outs);

  size_t grad_op_num = grad_op_bases_.size();

  // Collect all forward inputs and outputs, so that grad vars newly created
  // by the grad op can be told apart from forward vars below.
  std::set<VarBase*> set_input_vars;
  for (auto& fwd_in_it : ins) {
    for (auto& var_base_it : fwd_in_it.second) {
      set_input_vars.insert(var_base_it.get());
    }
  }

  for (auto& fwd_out_it : outs) {
    for (auto& var_base_it : fwd_out_it.second) {
      set_input_vars.insert(var_base_it.get());
    }
  }

  for (size_t i = 0; i < grad_op_num; ++i) {
    size_t trace_id = fwd_op->id();

    std::shared_ptr<OpBase> grad_op = std::move(grad_op_bases_[i]);
    grad_op->SetId(trace_id);
    grad_op->SetPlace(fwd_op->place());
    grad_op->CreateOperatorBase();

    auto& grad_in = *(grad_op->GetMutableInsMap());
    auto& grad_out = *(grad_op->GetMutableOutsMap());
    for (auto& grad_in_it : grad_in) {
      for (auto& var_base_it : grad_in_it.second) {
        if (set_input_vars.count(var_base_it.get()) == 0) {
          var_base_it->AddGradOps(grad_op);
          engine_->InsertGradVar(var_base_it.get());
        }
      }
    }

    // Collect the grad ops that consume this grad_op's output vars; they are
    // recorded as pending ops to be run after grad_op in the backward pass
    // (OpBaseCmp orders them by descending id).
    std::set<OpBase*, OpBaseCmp> visited_preceding_ops;
    for (auto& grad_out_it : grad_out) {
      bool flag_clear_list = false;
      for (auto& var_base_it : grad_out_it.second) {
        if ((!var_base_it->OverridedStopGradient()) ||
            (grad_out_it.second.size() > 1)) {
          auto preceding_ops = var_base_it->GradOps();
          if (!preceding_ops.empty()) {
            for (const auto& op : preceding_ops) {
              visited_preceding_ops.insert(op);
            }
          }
        } else {
          flag_clear_list = true;
        }
      }
      if (flag_clear_list) {
        grad_out_it.second.clear();
      }
    }
    std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
                                           visited_preceding_ops.end());

    grad_op->SetGradPendingOps(std::move(vec_preceding_ops));

    // this OpBase* is just used to manage op's life time
    engine_->InsertOp(grad_op.get(), grad_op);
  }
}

}  // namespace imperative
}  // namespace paddle
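
// Usage sketch (illustrative only, not part of this translation unit): how a
// caller might drive one traced op. TraceOp and NameVarBaseMap are the
// interfaces defined above / in tracer.h; the VarBase handles x, y and out,
// and the chosen op/attrs, are assumptions standing in for the real
// dygraph-layer setup.
//
//   paddle::imperative::Tracer tracer;
//   paddle::imperative::NameVarBaseMap ins = {{"X", {x}}, {"Y", {y}}};
//   paddle::imperative::NameVarBaseMap outs = {{"Out", {out}}};
//   paddle::framework::AttributeMap attrs = {{"axis", -1}};
//   tracer.TraceOp("elementwise_add", ins, outs, std::move(attrs),
//                  paddle::platform::CPUPlace(), /*trace_backward=*/true);
//   // The forward kernel runs eagerly inside TraceOp; if any input requires
//   // grad, the matching grad op is created in TraceBackward and registered
//   // with the engine for a later backward pass.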