// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" namespace paddle { namespace imperative { class GradOpBaseMakerBase { public: explicit GradOpBaseMakerBase(const OpBase* fw_op_base, const NameVarBaseMap& var_base_map_in, const NameVarBaseMap& var_base_map_out) : fw_op_base_(fw_op_base), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out) {} virtual ~GradOpBaseMakerBase() = default; virtual std::vector> operator()() const = 0; std::vector> InputGrad( const std::string& name, bool drop_empty_grad = true) const { return GetVarBaseList(name, true, true); } std::vector> OutputGrad( const std::string& name) const { return GetVarBaseList(name, true, false); } std::vector> Input(const std::string name) const { return GetVarBaseList(name, false, true); } std::vector> Output(const std::string& name) const { return GetVarBaseList(name, false, false); } std::vector> Empty() const { return {}; } std::vector InputNames() const { std::vector vec_temp; vec_temp.reserve(var_base_map_in_.size()); for (auto& it : var_base_map_in_) { vec_temp.emplace_back(it.first); } return vec_temp; } std::vector OutputNames() const { std::vector vec_temp; vec_temp.reserve(var_base_map_out_.size()); for (auto& it : var_base_map_out_) { vec_temp.emplace_back(it.first); } return vec_temp; } const std::unordered_map& Attrs() const { return fw_op_base_->Attrs(); } const framework::Attribute& GetAttr(const std::string& name) const { auto& map = fw_op_base_->Attrs(); auto it = map.find(name); PADDLE_ENFORCE(it != map.end(), "Cannot find attribute [%s] in operator [%s]", name, fw_op_base_->Type()); return it->second; } template inline const T& Attr(const std::string& name) const { return boost::get(GetAttr(name)); } std::string ForwardOpType() const { return fw_op_base_->Type(); } protected: bool HasInput(const std::string& name) const { auto it = var_base_map_in_.find(name); return it != var_base_map_in_.end(); } bool HasOutput(const std::string name) const { auto it = var_base_map_out_.find(name); return it != var_base_map_out_.end(); } private: std::vector> GetVarBaseList(const std::string& name, bool is_grad, bool is_input) const { const NameVarBaseMap& data_map = is_input ? var_base_map_in_ : var_base_map_out_; auto iterator = data_map.find(name); std::vector> vec_temp; if (iterator != data_map.end()) { vec_temp.reserve(iterator->second.size()); for (auto& var_base_temp : iterator->second) { if (is_grad) { PADDLE_ENFORCE_NOT_NULL(var_base_temp->GradVarBase(), "VarBase grad of OP [%s] should not be null", fw_op_base_->Type()); auto grad_var_base_tmp = var_base_temp->GradVarBase(); auto* tensor = grad_var_base_tmp->MutableVar() ->GetMutable(); tensor->Resize( var_base_temp->Var().Get().dims()); vec_temp.emplace_back(grad_var_base_tmp); } else { vec_temp.emplace_back(var_base_temp); } } } return vec_temp; } private: const OpBase* fw_op_base_; const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; protected: std::vector grad_block_; }; } // namespace imperative } // namespace paddle