// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" namespace paddle { namespace imperative { enum TracedVarRole { kForward = 0, kBackward = 1 }; template class TracedVarList : public std::vector> { private: using BaseClass = std::vector>; public: using BaseClass::BaseClass; }; class GradOpBaseMakerBase { public: explicit GradOpBaseMakerBase(const std::string& type, const NameVarBaseMap& var_base_map_in, const NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs) : type_(type), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), attrs_(attrs) {} virtual ~GradOpBaseMakerBase() = default; virtual std::shared_ptr operator()() const = 0; TracedVarList InputGrad( const std::string& name, bool drop_empty_grad = true) const { return GetVarBaseList(name, /*is_input=*/true); } TracedVarList OutputGrad( const std::string& name) const { return GetVarBaseList(name, /*is_input=*/false); } TracedVarList Input( const std::string& name) const { return GetVarBaseList(name, /*is_input=*/true); } TracedVarList Output( const std::string& name) const { return GetVarBaseList(name, /*is_input=*/false); } static TracedVarList EmptyInput() { return {}; } static TracedVarList EmptyOutput() { return {}; } static TracedVarList EmptyOutputGrad() { return {}; } static TracedVarList EmptyInputGrad() { return {}; } std::vector InputNames() const { std::vector vec_temp; vec_temp.reserve(var_base_map_in_.size()); for (auto& it : var_base_map_in_) { vec_temp.emplace_back(it.first); } return vec_temp; } std::vector OutputNames() const { std::vector vec_temp; vec_temp.reserve(var_base_map_out_.size()); for (auto& it : var_base_map_out_) { vec_temp.emplace_back(it.first); } return vec_temp; } const framework::AttributeMap& Attrs() const { return attrs_; } const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); PADDLE_ENFORCE_EQ( it != attrs_.end(), true, platform::errors::NotFound( "Cannot find attribute [%s] in operator [%s]", name, type_)); return it->second; } template inline const T& Attr(const std::string& name) const { return boost::get(GetAttr(name)); } const std::string& ForwardOpType() const { return type_; } protected: bool HasInput(const std::string& name) const { return var_base_map_in_.count(name) > 0; } bool HasOutput(const std::string& name) const { return var_base_map_out_.count(name) > 0; } static std::shared_ptr NewGradNode() { return std::make_shared(); } private: template TracedVarList GetVarBaseList(const std::string& name, bool is_input) const { const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_; auto iterator = data_map.find(name); TracedVarList vec_temp; if (iterator != data_map.end()) { vec_temp.reserve(iterator->second.size()); bool is_valid = false; for (auto& var_base_temp : iterator->second) { if (!var_base_temp) { vec_temp.emplace_back(); continue; } if (kRole == TracedVarRole::kBackward) { if (!var_base_temp->HasGradVar()) { VLOG(6) << "GradVarBase of var " << var_base_temp->Name() << " in OP " << type_ << " is null"; var_base_temp->MutableGradVarBase(); } auto grad_var_base_tmp = var_base_temp->GradVarBase(); if (!is_input) { auto* tensor = grad_var_base_tmp->MutableVar() ->GetMutable(); tensor->Resize( var_base_temp->Var().Get().dims()); } vec_temp.emplace_back(grad_var_base_tmp); } else { vec_temp.emplace_back(var_base_temp); } is_valid = true; } if (!is_valid) { vec_temp.clear(); } } return vec_temp; } private: const std::string& type_; const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; }; class TracedGradOp { DISABLE_COPY_AND_ASSIGN(TracedGradOp); public: explicit TracedGradOp(const std::shared_ptr& node) : node_(node), op_(&(node->emplace_back())) {} ~TracedGradOp() { if (UNLIKELY(op_->GetOutsMap().empty())) { node_->pop_back(); } else { op_->CheckAttrs(); } } template void SetInput(const std::string& name, const TracedVarList& vars) { if (vars.empty()) { return; } if (kRole == TracedVarRole::kBackward) { for (auto& var : vars) { if (var && !var->OverridedStopGradient()) { var->SetGradNode(node_); } } } auto var_wrappers = ToVarWrapperList(vars); if (!var_wrappers.empty()) { op_->SetInput(name, std::move(var_wrappers), kRole == TracedVarRole::kBackward); } } template void SetOutput(const std::string& name, const TracedVarList& vars) { if (vars.empty()) { return; } if (kRole == TracedVarRole::kBackward) { if (vars.size() == 1 && vars.front()->OverridedStopGradient()) { return; } else { for (auto& var : vars) { if (var && !var->OverridedStopGradient() && var->GradNode()) { node_->InsertGradPendingNode(var->GradNode()); } } } } auto var_wrappers = ToVarWrapperList(vars); if (!var_wrappers.empty()) { op_->SetOutput(name, std::move(var_wrappers), kRole == TracedVarRole::kBackward); } } void SetType(const std::string& type) { op_->SetType(type); } void SetAttrMap(const framework::AttributeMap& attrs) { return op_->SetAttrMap(attrs); } void SetAttr(const std::string& name, const framework::Attribute& v) { op_->SetAttr(name, v); } bool HasAttr(const std::string& name) const { return op_->HasAttr(name); } const framework::Attribute& GetAttr(const std::string& name) const { return op_->GetAttr(name); } template inline const T& Attr(const std::string& name) const { return op_->Attr(name); } private: template static std::vector> ToVarWrapperList( const std::vector>& vars) { std::vector> result; result.reserve(vars.size()); bool has_valid = false; for (auto& var : vars) { if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward && var->OverridedStopGradient()))) { result.emplace_back(); } else { result.emplace_back(var->SharedVar()); has_valid = true; } } if (!has_valid) { result.clear(); } return result; } private: const std::shared_ptr& node_; OpBase* op_; }; } // namespace imperative } // namespace paddle