未验证 提交 0d10c60e 编写于 作者: H haozech 提交者: GitHub

[Framework][Op_regitry]reconstruct op_registry (#3699)

* reconstruct op_registry. test=develop

* remove function GetKernelOffset. test=develop
上级 975e8b8e
...@@ -165,9 +165,7 @@ std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) { ...@@ -165,9 +165,7 @@ std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
// 该文件第2处 // 该文件第2处
// 找到文件中的下面的函数 // 找到文件中的下面的函数
KernelRegistry::KernelRegistry() KernelRegistry::KernelRegistry()
: registries_(static_cast<int>(TARGET(NUM)) * : registries_() {
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)))
// 在该函数中加入新增Layout的下面内容 // 在该函数中加入新增Layout的下面内容
INIT_FOR(kOpenCL, kFP16, kNCHW); INIT_FOR(kOpenCL, kFP16, kNCHW);
......
...@@ -124,14 +124,11 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -124,14 +124,11 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
return std::list<std::unique_ptr<KernelBase>>(); return std::list<std::unique_ptr<KernelBase>>();
} }
KernelRegistry::KernelRegistry() KernelRegistry::KernelRegistry() : registries_() {
: registries_(static_cast<int>(TARGET(NUM)) *
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM))) {
#define INIT_FOR(target__, precision__, layout__) \ #define INIT_FOR(target__, precision__, layout__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \ registries_[std::make_tuple(TARGET(target__), \
PRECISION(precision__), \ PRECISION(precision__), \
DATALAYOUT(layout__)>()] \ DATALAYOUT(layout__))] \
.set<KernelRegistryForTarget<TARGET(target__), \ .set<KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \ PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \ DATALAYOUT(layout__)> *>( \
......
...@@ -332,7 +332,7 @@ class KernelRegistry final { ...@@ -332,7 +332,7 @@ class KernelRegistry final {
&&creator) { &&creator) {
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()]; auto &varient = registries_[std::make_tuple(Target, Precision, Layout)];
auto *reg = varient.template get<kernel_registor_t *>(); auto *reg = varient.template get<kernel_registor_t *>();
CHECK(reg) << "Can not be empty of " << name; CHECK(reg) << "Can not be empty of " << name;
reg->Register(name, std::move(creator)); reg->Register(name, std::move(creator));
...@@ -349,10 +349,12 @@ class KernelRegistry final { ...@@ -349,10 +349,12 @@ class KernelRegistry final {
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
std::list<std::unique_ptr<KernelBase>> kernel_list; std::list<std::unique_ptr<KernelBase>> kernel_list;
if (registries_[GetKernelOffset<Target, Precision, Layout>()].valid()) { std::tuple<TargetType, PrecisionType, DataLayoutType> temp_tuple(
kernel_list = registries_[GetKernelOffset<Target, Precision, Layout>()] Target, Precision, Layout);
.template get<kernel_registor_t *>() if (registries_[temp_tuple].valid()) {
->Creates(op_type); kernel_list =
registries_[temp_tuple].template get<kernel_registor_t *>()->Creates(
op_type);
} }
return kernel_list; return kernel_list;
} }
...@@ -362,18 +364,6 @@ class KernelRegistry final { ...@@ -362,18 +364,6 @@ class KernelRegistry final {
PrecisionType precision, PrecisionType precision,
DataLayoutType layout); DataLayoutType layout);
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
static int GetKernelOffset() {
CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Layout);
}
std::string DebugString() const { std::string DebugString() const {
#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL #ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
return "No more debug info"; return "No more debug info";
...@@ -404,7 +394,9 @@ class KernelRegistry final { ...@@ -404,7 +394,9 @@ class KernelRegistry final {
} }
private: private:
mutable std::vector<any_kernel_registor_t> registries_; mutable std::map<std::tuple<TargetType, PrecisionType, DataLayoutType>,
any_kernel_registor_t>
registries_;
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
mutable std::map< mutable std::map<
std::string, std::string,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册