diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc index dc740da070bfb8b46292dec8b83c41a598b05c6e..b47e16a93e1aeb5a22de00ff7fa4097738fce385 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc @@ -148,6 +148,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, .substr(broadcast_ndim - tenosr_ndim, tenosr_ndim); } +// SPMDRuleMap +SPMDRuleMap& SPMDRuleMap::Instance(); + static SPMDRuleMap g_spmd_rule_map; + return g_spmd_rule_map; +} + +// TODO enable default replicated spmd rule for op that are NOT registered +// which all tensors of inputs and outputs will be replicated in all ranks of the mesh. +SPMDRuleBase& SPMDRuleMap::Get(const std::string& op_type) const { + auto rule_ptr = GetNullable(op_type); + PADDLE_ENFORCE_NOT_NULL( + rule_ptr, + platform::errors::NotFound("NO SPMD Rule has been registered for Operator [%s].", op_type)); + return *rule_ptr; +} + +SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const { + auto it = map_.find(op_type); + if (it == map_.end()) { + return nullptr; + } else { + return it->second.get(); + } +} + +void SPMDRuleMap::Insert(const std::string& op_type, std::unique_ptr rule){ + PADDLE_ENFORCE_NE(Has(op_type), + true, + platform::errors::AlreadyExists( + "SPMD Rule for Operator [%s] has been registered.", type)); + map_.insert({op_type, std::move(rule)}); +} + + + } // namespace auto_parallel } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index 41d031106a4c348c278c3589c68baf6dfc142df3..5de98ba51176585e8aeeea0617cf51f58f6e458d 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -109,6 +109,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, const int64_t& broadcast_ndim, const std::string& alphabet); +// The static map that stores and initializes all the registered SPMD rules. +class SPMDRuleMap { + public: + ~SPMDRuleMap() = default; + + // A singleton + static SPMDRuleMap& Instance(); + + // Returns the spmd rule for the given op_type + SPMDRuleBase& Get(const std::string& op_type) const; + + // Returns the spmd by name or nullptr if not registered + SPMDRuleBase* GetNullable(const std::string& op_type) const; + + // Register a spmd for an op_type. + void Insert(const std::string& op_type, std::unique_ptr rule); + + + bool Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); + } + + private: + SPMDRuleMap() = default; + paddle::flat_hash_map> map_; + DISABLE_COPY_AND_ASSIGN(SPMDRuleMap); + +}; + +#define REGISTER_SPMDRULE(op_type, rule_class, ...) \ + SPMDRuleMap::Instance().Insert( \ + op_type, \ + std::make_unique(__VA_ARGS__)) + + } // namespace auto_parallel } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index d819a8b7ae3d2ba588bb9df8fdfb41105a9d5900..0d3337ebbec9945f051685307f1b3a2bdd5303f6 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -204,6 +204,8 @@ std::vector MatmulSPMDRule::InferBackward( return {}; } +REGISTER_SPMDRULE(matmul, MatmulSPMDRule); + } // namespace auto_parallel } // namespace distributed } // namespace paddle