From 747de082e0fc99bc78236355c8acc7f7067b0bda Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 6 Jun 2023 15:26:27 +0800 Subject: [PATCH] add registry --- .../auto_parallel/spmd_rules/common.cc | 35 +++++++++++++++++++ .../auto_parallel/spmd_rules/common.h | 35 +++++++++++++++++++ .../spmd_rules/matmul_spmd_rule.cc | 2 ++ 3 files changed, 72 insertions(+) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc index dc740da070b..b47e16a93e1 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc @@ -148,6 +148,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, .substr(broadcast_ndim - tenosr_ndim, tenosr_ndim); } +// SPMDRuleMap +SPMDRuleMap& SPMDRuleMap::Instance(); + static SPMDRuleMap g_spmd_rule_map; + return g_spmd_rule_map; +} + +// TODO enable default replicated spmd rule for op that are NOT registered +// which all tensors of inputs and outputs will be replicated in all ranks of the mesh. +SPMDRuleBase& SPMDRuleMap::Get(const std::string& op_type) const { + auto rule_ptr = GetNullable(op_type); + PADDLE_ENFORCE_NOT_NULL( + rule_ptr, + platform::errors::NotFound("NO SPMD Rule has been registered for Operator [%s].", op_type)); + return *rule_ptr; +} + +SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const { + auto it = map_.find(op_type); + if (it == map_.end()) { + return nullptr; + } else { + return it->second.get(); + } +} + +void SPMDRuleMap::Insert(const std::string& op_type, std::unique_ptr rule){ + PADDLE_ENFORCE_NE(Has(op_type), + true, + platform::errors::AlreadyExists( + "SPMD Rule for Operator [%s] has been registered.", type)); + map_.insert({op_type, std::move(rule)}); +} + + + } // namespace auto_parallel } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h index 41d031106a4..5de98ba5117 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h @@ -109,6 +109,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, const int64_t& broadcast_ndim, const std::string& alphabet); +// The static map that stores and initializes all the registered SPMD rules. +class SPMDRuleMap { + public: + ~SPMDRuleMap() = default; + + // A singleton + static SPMDRuleMap& Instance(); + + // Returns the spmd rule for the given op_type + SPMDRuleBase& Get(const std::string& op_type) const; + + // Returns the spmd by name or nullptr if not registered + SPMDRuleBase* GetNullable(const std::string& op_type) const; + + // Register a spmd for an op_type. + void Insert(const std::string& op_type, std::unique_ptr rule); + + + bool Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); + } + + private: + SPMDRuleMap() = default; + paddle::flat_hash_map> map_; + DISABLE_COPY_AND_ASSIGN(SPMDRuleMap); + +}; + +#define REGISTER_SPMDRULE(op_type, rule_class, ...) \ + SPMDRuleMap::Instance().Insert( \ + op_type, \ + std::make_unique(__VA_ARGS__)) + + } // namespace auto_parallel } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc index d819a8b7ae3..0d3337ebbec 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc @@ -204,6 +204,8 @@ std::vector MatmulSPMDRule::InferBackward( return {}; } +REGISTER_SPMDRULE(matmul, MatmulSPMDRule); + } // namespace auto_parallel } // namespace distributed } // namespace paddle -- GitLab