// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <functional>
#include <list>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/api/paddle_lite_factory_helper.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h"

using LiteType = paddle::lite::Type;

namespace paddle {
namespace lite {

using KernelFunc = std::function<void()>;
using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;
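// Global operator registry: a singleton Factory mapping an op type name to a
// functor that creates the corresponding OpLite instance.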
class LiteOpRegistry final : public Factory<OpLite, std::shared_ptr<OpLite>> {
 public:
  static LiteOpRegistry &Global() {
    static auto *x = new LiteOpRegistry;
    return *x;
  }

 private:
  LiteOpRegistry() = default;
};

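// Registers OpClass into LiteOpRegistry during static initialization; normally
// instantiated through the REGISTER_LITE_OP macro at the end of this file.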
template <typename OpClass>
class OpLiteRegistor : public Registor<OpClass> {
 public:
  explicit OpLiteRegistor(const std::string &op_type)
      : Registor<OpClass>([&] {
          LiteOpRegistry::Global().Register(
              op_type, [op_type]() -> std::unique_ptr<OpLite> {
                return std::unique_ptr<OpLite>(new OpClass(op_type));
              });
        }) {}
};

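// A kernel factory specialized for one (target, precision, layout) triple; it
// maps an op type name to creators of the matching KernelLite specialization.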
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
using KernelRegistryForTarget =
    Factory<KernelLite<Target, Precision, Layout>, std::unique_ptr<KernelBase>>;

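// Central kernel registry. It owns one KernelRegistryForTarget factory per
// supported (target, precision, layout) combination, stored as a variant in
// registries_ and indexed by GetKernelOffset().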
class KernelRegistry final {
 public:
  using any_kernel_registor_t =
      variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kHost), PRECISION(kAny),
                                      DATALAYOUT(kAny)> *,  //
              KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kAny),
                                      DATALAYOUT(kAny)> *,  //
              KernelRegistryForTarget<TARGET(kARM), PRECISION(kAny),
                                      DATALAYOUT(kAny)> *,  //
              KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kARM), PRECISION(kInt8),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFloat),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kInt8),
                                      DATALAYOUT(kNCHW)> *  //
              >;

  KernelRegistry();

  static KernelRegistry &Global();

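  // Register a kernel creator under `name` in the factory selected by the
  // statically known (Target, Precision, Layout) triple.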
  template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
  void Register(const std::string &name,
                typename KernelRegistryForTarget<Target, Precision,
                                                 Layout>::creator_t &&creator) {
    VLOG(3) << "register for " << TargetToStr(Target) << ":"
            << PrecisionToStr(Precision) << "//"
            << GetKernelOffset<Target, Precision, Layout>();
    using kernel_registor_t =
        KernelRegistryForTarget<Target, Precision, Layout>;
    auto &variant = registries_[GetKernelOffset<Target, Precision, Layout>()];
    auto *reg = variant.template get<kernel_registor_t *>();
    CHECK(reg) << "No kernel registry found for " << name;
    reg->Register(name, std::move(creator));
  }

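  // Create every kernel registered under `op_type` for the given place. The
  // template overload resolves the place at compile time; the overload below
  // dispatches on runtime (target, precision, layout) values.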
  template <TargetType Target, PrecisionType Precision = PRECISION(kFloat),
            DataLayoutType Layout = DATALAYOUT(kNCHW)>
  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
    using kernel_registor_t =
        KernelRegistryForTarget<Target, Precision, Layout>;
    return registries_[GetKernelOffset<Target, Precision, Layout>()]
        .template get<kernel_registor_t *>()
        ->Creates(op_type);
  }

  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
                                                TargetType target,
                                                PrecisionType precision,
                                                DataLayoutType layout);

  // Compute the offset of the (Target, Precision, Layout) kernel registry
  // within registries_.
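  // The triple is flattened row-major:
  //   offset = Target * PRECISION(NUM) * DATALAYOUT(NUM) +
  //            Precision * DATALAYOUT(NUM) + Layout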
  template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
  static int GetKernelOffset() {
    CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
    CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
    CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
    return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
               static_cast<int>(DATALAYOUT(NUM)) +                            //
           static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) +  //
           static_cast<int>(Layout);
  }

  std::string DebugString() const {
    std::stringstream ss;
    ss << "KernelCreator<host, float>:" << std::endl;
    constexpr TargetType tgt = TARGET(kHost);
    constexpr PrecisionType dt = PRECISION(kFloat);
    constexpr DataLayoutType lt = DATALAYOUT(kNCHW);
    using kernel_registor_t = KernelRegistryForTarget<tgt, dt, lt>;
    auto *reg = registries_[GetKernelOffset<tgt, dt, lt>()]
                    .template get<kernel_registor_t *>();
    ss << reg->DebugString() << std::endl;
    return ss.str();
  }

 private:
  mutable std::vector<any_kernel_registor_t> registries_;
};

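// Registers a concrete KernelType for one (target, precision, layout) place
// into KernelRegistry during static initialization; normally instantiated
// through the REGISTER_LITE_KERNEL macro below.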
template <TargetType target, PrecisionType precision, DataLayoutType layout,
          typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> {
 public:
  KernelRegistor(const std::string &op_type, const std::string &alias)
      : Registor<KernelType>([=] {
          VLOG(3) << "Register kernel " << op_type << " for "
                  << TargetToStr(target) << " " << PrecisionToStr(precision)
                  << " " << DataLayoutToStr(layout) << " alias " << alias;
          KernelRegistry::Global().Register<target, precision, layout>(
              op_type, [=]() -> std::unique_ptr<KernelType> {
                std::unique_ptr<KernelType> x(new KernelType);
                x->set_op_type(op_type);
                x->set_alias(alias);
                return x;
              });
        }) {}
};

}  // namespace lite
}  // namespace paddle

// Operator registry
#define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__
#define REGISTER_LITE_OP(op_type__, OpClass)                              \
  static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
      op_type__)(#op_type__);                                             \
  int touch_op_##op_type__() {                                            \
    return LITE_OP_REGISTER_INSTANCE(op_type__).Touch();                  \
  }
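// Illustrative usage (hypothetical names; it assumes an OpLite subclass
// MulOpLite defined elsewhere):
//
//   REGISTER_LITE_OP(mul, paddle::lite::operators::MulOpLite);
//
// This defines a static OpLiteRegistor instance plus a touch_op_mul() hook
// that forces the registration object to be linked into the final binary.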

// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
  op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
                                      layout__, alias__)                \
  op_type__##__##target__##__##precision__##__registor__instance__##alias__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, layout__, \
                                  alias__)                                     \
  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, layout__,   \
                                alias__)

#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__,      \
                             KernelClass, alias__)                            \
  static paddle::lite::KernelRegistor<TARGET(target__),                       \
                                      PRECISION(precision__),                 \
                                      DATALAYOUT(layout__), KernelClass>      \
      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__,         \
                                    layout__, alias__)(#op_type__, #alias__); \
  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__,   \
                                          layout__, alias__);                 \
  int touch_##op_type__##target__##precision__##layout__##alias__() {         \
    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
        .Touch();                                                             \
    return 0;                                                                 \
  }                                                                           \
  static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__,    \
                                         layout__, alias__)                   \
      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
          TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>(    \
          #op_type__ "/" #alias__)

#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \
                             alias__)                                    \
  op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \
                                   alias__)                                    \
  op_type__##target__##precision__##layout__##alias__##param_register
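// Illustrative usage (hypothetical names; it assumes a KernelLite subclass
// MulCompute and that ParamTypeRegistry::NewInstance, declared elsewhere,
// returns a builder exposing BindInput/BindOutput/Finalize):
//
//   REGISTER_LITE_KERNEL(mul, kX86, kFloat, kNCHW, MulCompute, def)
//       .BindInput("X", {/* expected input tensor type */})
//       .BindOutput("Out", {/* expected output tensor type */})
//       .Finalize();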