未验证 提交 0d10c60e 编写于 作者: H haozech 提交者: GitHub

[Framework][Op_regitry]reconstruct op_registry (#3699)

* reconstruct op_registry. test=develop

* remove function GetKernelOffset. test=develop
上级 975e8b8e
......@@ -165,9 +165,7 @@ std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
// 该文件第2处
// 找到文件中的下面的函数
KernelRegistry::KernelRegistry()
: registries_(static_cast<int>(TARGET(NUM)) *
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)))
: registries_() {
// 在该函数中加入新增Layout的下面内容
INIT_FOR(kOpenCL, kFP16, kNCHW);
......
......@@ -124,19 +124,16 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
return std::list<std::unique_ptr<KernelBase>>();
}
KernelRegistry::KernelRegistry()
: registries_(static_cast<int>(TARGET(NUM)) *
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM))) {
#define INIT_FOR(target__, precision__, layout__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
KernelRegistry::KernelRegistry() : registries_() {
#define INIT_FOR(target__, precision__, layout__) \
registries_[std::make_tuple(TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__))] \
.set<KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets.
INIT_FOR(kCUDA, kFloat, kNCHW);
......
......@@ -332,7 +332,7 @@ class KernelRegistry final {
&&creator) {
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
auto &varient = registries_[std::make_tuple(Target, Precision, Layout)];
auto *reg = varient.template get<kernel_registor_t *>();
CHECK(reg) << "Can not be empty of " << name;
reg->Register(name, std::move(creator));
......@@ -349,10 +349,12 @@ class KernelRegistry final {
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
std::list<std::unique_ptr<KernelBase>> kernel_list;
if (registries_[GetKernelOffset<Target, Precision, Layout>()].valid()) {
kernel_list = registries_[GetKernelOffset<Target, Precision, Layout>()]
.template get<kernel_registor_t *>()
->Creates(op_type);
std::tuple<TargetType, PrecisionType, DataLayoutType> temp_tuple(
Target, Precision, Layout);
if (registries_[temp_tuple].valid()) {
kernel_list =
registries_[temp_tuple].template get<kernel_registor_t *>()->Creates(
op_type);
}
return kernel_list;
}
......@@ -362,18 +364,6 @@ class KernelRegistry final {
PrecisionType precision,
DataLayoutType layout);
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
static int GetKernelOffset() {
CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Layout);
}
std::string DebugString() const {
#ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
return "No more debug info";
......@@ -404,7 +394,9 @@ class KernelRegistry final {
}
private:
mutable std::vector<any_kernel_registor_t> registries_;
mutable std::map<std::tuple<TargetType, PrecisionType, DataLayoutType>,
any_kernel_registor_t>
registries_;
#ifndef LITE_ON_TINY_PUBLISH
mutable std::map<
std::string,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册