// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/platform/place.h" namespace paddle { namespace imperative { // TODO(zjl): to support py_func layer class OpBase { public: OpBase() = default; OpBase(const OpBase&) = delete; OpBase(OpBase&&) = default; OpBase& operator=(const OpBase&) = delete; OpBase& operator=(OpBase&&) = default; ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); } const std::string& Type() const { return op_->Type(); } const framework::AttributeMap& Attrs() const { return attrs_; } const framework::OpInfo& Info() const { return op_->Info(); } const framework::OperatorBase& InnerOp() const { return *op_; } void ClearBackwardTrace(); NameVarMap* GetMutableOutsMap() { return &outs_; } NameVarMap* GetMutableInsMap() { return &ins_; } const NameVarMap& GetInsMap() const { return ins_; } const NameVarMap& GetOutsMap() const { return outs_; } void SetType(const std::string& type); void CheckAttrs() { auto& info = op_->Info(); if (info.Checker() != nullptr) { info.Checker()->Check(&attrs_, true); } } void SetInput(const std::string& name, VariableWrapperList vars, bool is_grad) { auto& in_vars = ins_[name]; *(in_vars.MutableVarList()) = std::move(vars); in_vars.SetIsGrad(is_grad); } void SetOutput(const std::string& name, VariableWrapperList vars, bool is_grad) { auto& out_vars = outs_[name]; *(out_vars.MutableVarList()) = std::move(vars); out_vars.SetIsGrad(is_grad); } void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } void SetAttr(const std::string& name, const framework::Attribute& v) { attrs_[name] = v; } void SetBlockAttr(const std::string& name, framework::BlockDesc* block) { PADDLE_THROW(platform::errors::PermissionDenied( "SetBlockAttr is not support in dygraph OpBase")); } const framework::AttributeMap& Attrs() { return attrs_; } bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); PADDLE_ENFORCE_NE( it, attrs_.end(), platform::errors::NotFound("can not find attribute [%s]", name)); return it->second; } template inline const T& Attr(const std::string& name) const { return boost::get(GetAttr(name)); } size_t id() const { return id_; } void SetId(size_t id) { id_ = id; } const platform::Place& place() const { return place_; } void SetPlace(const platform::Place& place) { place_ = place; } static size_t GenerateUniqueId() { static std::atomic unique_id{0}; return unique_id.fetch_add(1); } static void Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const platform::Place& place); static void Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const platform::Place& place); private: NameVarMap ins_; NameVarMap outs_; framework::AttributeMap attrs_; std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; std::vector> backward_hooks_; }; class GradOpNode { public: GradOpNode() = default; void reserve(size_t size) { ops_.reserve(size); } size_t size() const { return ops_.size(); } bool empty() const { return ops_.empty(); } void clear() { ops_.clear(); } void pop_back() { ops_.pop_back(); } template OpBase& emplace_back(ARGS&&... args) { // NOLINT ops_.emplace_back(std::forward(args)...); return ops_.back(); } const OpBase& back() const { return ops_.back(); } OpBase& back() { return ops_.back(); } OpBase& operator[](size_t idx) { return ops_[idx]; } const OpBase& operator[](size_t idx) const { return ops_[idx]; } /* Iterator related */ using Iterator = std::vector::iterator; using ConstIterator = std::vector::const_iterator; Iterator begin() { return ops_.begin(); } Iterator end() { return ops_.end(); } ConstIterator begin() const { return ops_.begin(); } ConstIterator end() const { return ops_.end(); } void InsertGradPendingNode(const std::shared_ptr& node) { if (node && std::find(grad_pending_nodes_.begin(), grad_pending_nodes_.end(), node) == grad_pending_nodes_.end()) { grad_pending_nodes_.emplace_back(node); } } const std::vector>& GradPendingNodes() const { return grad_pending_nodes_; } private: DISABLE_COPY_AND_ASSIGN(GradOpNode); private: std::vector ops_; std::vector> grad_pending_nodes_; }; } // namespace imperative } // namespace paddle