// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace platform {
class NCCLContextMap;
class MultiNCCLContextMap;
}  // namespace platform

namespace framework {
class Scope;

namespace ir {

// Attribute keys that callers (e.g. BuildStrategy) must set on the pass
// before it is applied.
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";

// Base pass that rewrites a single-device graph into a multi-device SSA
// graph. Subclasses customize how collective ops are inserted for each
// (parameter, gradient) pair and which postprocessing ops (e.g. broadcasts)
// are appended at the end.
class MultiDevSSAGraphBuilderBase : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  virtual void Init() const;

  virtual void CheckGraph(const ir::Graph &graph) const;

  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const = 0;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const = 0;

  bool UseGPU() const;

  virtual bool NeedCollectiveForGrad(const std::string &grad_name,
                                     std::vector<ir::Node *> ops) const;

  bool IsScaleLossOp(ir::Node *node) const;

  void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                              size_t num_places) const;

  void CreateScaleLossGradOp(ir::Graph *result,
                             const std::string &loss_grad_name,
                             ir::Node *out_var_node, size_t loss_scale,
                             proto::VarType::Type dtype) const;

  details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
                                     size_t dst_dev_id) const;

  void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                             size_t dev_id) const;

  bool IsSparseGradient(const std::string &og) const;

  void CreateAllReduceOp(ir::Graph *result, const std::string &og,
                         bool is_encoded = false) const;

  void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                         size_t src_dev_id) const;

  void InsertScaleLossGradOp(ir::Graph *result, const ir::Node *node) const;

  void CreateFusedBroadcastOp(
      ir::Graph *result,
      const std::vector<std::unordered_set<std::string>> &bcast_varnames)
      const;

  void SetCommunicationContext(details::OpHandleBase *op_handle,
                               const platform::Place &p) const;

  void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
                         size_t device_id) const;

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
  mutable platform::MultiNCCLContextMap *multi_nccl_ctxs_{nullptr};
#endif

  mutable std::string loss_var_name_;
  mutable std::vector<platform::Place> places_;
  mutable std::vector<Scope *> local_scopes_;

  mutable details::BuildStrategy strategy_;
  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};

// AllReduce mode: every gradient is all-reduced across all devices.
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const {}

  bool IsEncoded(const std::string &p_name) const;
};
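// For orientation, a simplified sketch of the template method that
// MultiDevSSAGraphBuilderBase::ApplyImpl implements in
// multi_devices_graph_pass.cc. This is a paraphrase, not the actual code;
// only the hook names are taken from this header:
//
//   void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
//     Init();  // read kLossVarName, kStrategy, kNRanks, places, scopes
//     for (ir::Node *node : SortOperations(*graph)) {
//       if (DealWithSpecialOp(graph, node)) continue;  // subclass hook
//       if (IsScaleLossOp(node)) InsertScaleLossGradOp(graph, node);
//       CreateComputationalOps(graph, node, places_.size());
//       // For each (param, grad) pair produced by a backward op, subclasses
//       // decide whether and how a collective op is inserted:
//       //   if (NeedCollectiveForGrad(grad, ops))
//       //     InsertCollectiveOp(graph, param, grad);
//     }
//     InsertPostprocessOps(graph);  // e.g. parameter broadcasts in reduce mode
//   }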
// Async mode: no collective ops are inserted; parameter synchronization is
// handled asynchronously by the communicator.
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                          const std::string &g_name) const override {}

  bool NeedCollectiveForGrad(const std::string &grad_name,
                             std::vector<ir::Node *> ops) const override {
    return false;
  }

  bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
    if (node->Op()->Type() == "recv") {
      VLOG(1) << "set recv op do_not_run to true";
      node->Op()->SetAttr("do_not_run", 1);
      node->Op()->Flush();
    } else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
               node->Name() == "hierarchical_sigmoid") {
      // In async_mode, remote prefetch is unnecessary because the
      // communicator performs the asynchronous parameter recv.
      VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
      node->Op()->SetAttr("remote_prefetch", false);
      node->Op()->Flush();
    }
    return false;
  }

  void InsertPostprocessOps(ir::Graph *result) const override {}
};

// Helper base for builders that shard variables across devices, balancing
// the load among them.
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 protected:
  int GetVarDeviceID(const std::string &varname) const;

  int GetOpDeviceID(ir::Node *node) const;

  size_t GetAppropriateDeviceID(
      const std::vector<std::string> &var_names) const;

  virtual void ResetState() const;

  mutable std::unordered_map<std::string, int> sharded_var_device_;
  mutable std::vector<int64_t> balance_vars_;
};

// Reduce mode: each gradient is reduced onto one device, and the updated
// parameter is broadcast back to the others.
class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

  virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

  virtual void ResetState() const;

  int GetOpDeviceID(ir::Node *node,
                    std::unordered_map<std::string, std::vector<ir::Node *>>
                        *delay_ops) const;

  std::vector<ir::Node *> SortForReduceMode(
      const std::vector<ir::Node *> &topo_ops) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
};

// Distributed mode: additionally handles RPC ops (send/recv) and other
// distributed-training-specific ops.
class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 protected:
  virtual void Init() const;

  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

  virtual void InsertPostprocessOps(ir::Graph *result) const;

  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
                                  const std::string &g_name) const;

  virtual void ResetState() const;

  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;

  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;

  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;

  mutable bool need_broadcast_var_{false};
};

// Returns the set of registered multi-devices pass names.
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();

}  // namespace ir
}  // namespace framework
}  // namespace paddle
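// A minimal usage sketch, assuming the PassRegistry API from
// paddle/fluid/framework/ir/pass.h. The pass name string is an assumption
// based on the registrations in multi_devices_graph_pass.cc and may differ
// between releases:
//
//   auto pass = ir::PassRegistry::Instance().Get(
//       "all_reduce_mode_multi_devices_pass");
//   pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
//   pass->SetNotOwned<details::BuildStrategy>(ir::kStrategy, &strategy);
//   pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
//   graph = pass->Apply(graph);  // returns the rewritten multi-device graph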