// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { namespace details { class FuseOptimizerOpPass : public ir::Pass { protected: void ApplyImpl(ir::Graph *graph) const override; protected: virtual void SortParametersAndAuxVars( const std::vector> ¶ms_grads, std::unordered_map> *aux_var_set, std::vector *ops) const; void InserInputAndOutputForOptOps(const std::vector &opt_ops, ir::Node *opt_node) const; private: virtual const std::string GetOpType() const = 0; virtual const std::vector GetAuxiliaryVarNames() const = 0; virtual void FuseOptimizerOps( const std::unordered_map> &vars_set, const std::unordered_map &fused_vars_name, const std::vector &adam_ops, ir::Graph *graph) const = 0; void GetSpecifiedOpsAndVars( const std::string &op_type, const std::vector &aux_vars_name, ir::Node *node, std::vector *ops, std::unordered_map> *aux_args_name) const; void AppendAllocContinuousSpace(const std::vector &args, const std::string &out_arg, bool copy_data, BlockDesc *global_block) const; void InitFusedVarsAndAllocSpaceForVars( const std::vector &places, const std::vector &local_scopes, const std::vector &aux_var_names, const std::unordered_map> &aux_var_set, const std::unordered_map &fused_vars_name) const; }; } // namespace details } // namespace framework } // namespace paddle