提交 747de082 编写于 作者: J JZ-LIANG

add registry

上级 c1545a4b
...@@ -148,6 +148,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, ...@@ -148,6 +148,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
.substr(broadcast_ndim - tenosr_ndim, tenosr_ndim); .substr(broadcast_ndim - tenosr_ndim, tenosr_ndim);
} }
// SPMDRuleMap
SPMDRuleMap& SPMDRuleMap::Instance();
static SPMDRuleMap g_spmd_rule_map;
return g_spmd_rule_map;
}
// TODO enable default replicated spmd rule for op that are NOT registered
// which all tensors of inputs and outputs will be replicated in all ranks of the mesh.
SPMDRuleBase& SPMDRuleMap::Get(const std::string& op_type) const {
auto rule_ptr = GetNullable(op_type);
PADDLE_ENFORCE_NOT_NULL(
rule_ptr,
platform::errors::NotFound("NO SPMD Rule has been registered for Operator [%s].", op_type));
return *rule_ptr;
}
SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it == map_.end()) {
return nullptr;
} else {
return it->second.get();
}
}
void SPMDRuleMap::Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule){
PADDLE_ENFORCE_NE(Has(op_type),
true,
platform::errors::AlreadyExists(
"SPMD Rule for Operator [%s] has been registered.", type));
map_.insert({op_type, std::move(rule)});
}
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -109,6 +109,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, ...@@ -109,6 +109,41 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
const int64_t& broadcast_ndim, const int64_t& broadcast_ndim,
const std::string& alphabet); const std::string& alphabet);
// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
public:
~SPMDRuleMap() = default;
// A singleton
static SPMDRuleMap& Instance();
// Returns the spmd rule for the given op_type
SPMDRuleBase& Get(const std::string& op_type) const;
// Returns the spmd by name or nullptr if not registered
SPMDRuleBase* GetNullable(const std::string& op_type) const;
// Register a spmd for an op_type.
void Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule);
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
private:
SPMDRuleMap() = default;
paddle::flat_hash_map<std::string, std::unique_ptr<SPMDRuleBase>> map_;
DISABLE_COPY_AND_ASSIGN(SPMDRuleMap);
};
#define REGISTER_SPMDRULE(op_type, rule_class, ...) \
SPMDRuleMap::Instance().Insert( \
op_type, \
std::make_unique<rule_class>(__VA_ARGS__))
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -204,6 +204,8 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward( ...@@ -204,6 +204,8 @@ std::vector<TensorDistAttr> MatmulSPMDRule::InferBackward(
return {}; return {};
} }
REGISTER_SPMDRULE(matmul, MatmulSPMDRule);
} // namespace auto_parallel } // namespace auto_parallel
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册