op_registry.h 7.2 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

S
superjomn 已提交
15
#pragma once
S
superjomn 已提交
16 17 18
#include <memory>
#include <string>
#include <unordered_map>
S
superjomn 已提交
19 20 21
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
S
superjomn 已提交
22 23 24 25 26 27 28 29
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {

// A type-erased callable that takes no arguments and returns nothing.
using KernelFunc = std::function<void()>;
// A factory callable that returns a uniquely-owned KernelFunc.
using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;

30
// Registry of operators: a Factory specialization whose product type is
// OpLite, handed out as std::shared_ptr<OpLite>.
class LiteOpRegistry final : public Factory<OpLite, std::shared_ptr<OpLite>> {
 public:
  // Returns the shared registry instance. The object is allocated with `new`
  // the first time this function runs and is never deleted here.
  static LiteOpRegistry &Global() {
    static auto *x = new LiteOpRegistry;
    return *x;
  }

 private:
  // Private so that Global() is the only place an instance is created.
  LiteOpRegistry() = default;
};

// Registration helper for operators: constructing an OpLiteRegistor<OpClass>
// hands Registor a callback that adds a creator for OpClass to the global
// LiteOpRegistry under `op_type`.
template <typename OpClass>
class OpLiteRegistor : public Registor<OpClass> {
 public:
  // op_type: name under which the creator is registered; it is also passed to
  // the OpClass constructor every time the creator runs.
  OpLiteRegistor(const std::string &op_type)
      // NOTE(review): the outer lambda refers to op_type by reference, so this
      // presumably relies on Registor running the callback before this
      // constructor returns -- confirm against Registor.
      : Registor<OpClass>([&] {
          LiteOpRegistry::Global().Register(
              // The creator keeps its own copy of op_type.
              op_type, [op_type]() -> std::unique_ptr<OpLite> {
                return std::unique_ptr<OpLite>(new OpClass(op_type));
              });
        }) {}
};

template <TargetType Target, PrecisionType Precision>
54 55 56
using KernelRegistryForTarget =
    Factory<OpKernel<Target, Precision>,
            std::unique_ptr<OpKernel<Target, Precision>>>;
S
superjomn 已提交
57 58 59

class KernelRegistry final {
 public:
S
superjomn 已提交
60 61 62 63 64 65 66
  using any_kernel_registor_t =
      variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)> *,  //
              KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8)> *,   //
              KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat)> *,   //
              KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8)> *,    //
              KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> *   //
              >;
S
superjomn 已提交
67

S
superjomn 已提交
68
  KernelRegistry();
S
superjomn 已提交
69

S
superjomn 已提交
70
  static KernelRegistry &Global();
S
superjomn 已提交
71 72 73 74 75 76

  template <TargetType Target, PrecisionType Precision>
  void Register(const std::string &name,
                typename KernelRegistryForTarget<Target, Precision>::creator_t
                    &&creator) {
    using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
S
superjomn 已提交
77 78
    registries_[GetKernelOffset<Target, Precision>()]
        .template get<kernel_registor_t *>()
S
superjomn 已提交
79 80 81
        ->Register(name, std::move(creator));
  }

S
update  
superjomn 已提交
82 83 84 85 86 87 88 89 90 91
  template <TargetType Target, PrecisionType Precision>
  std::unique_ptr<KernelBase> Create(const std::string &op_type) {
    using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
    return registries_[GetKernelOffset<Target, Precision>()]
        .template get<kernel_registor_t *>()
        ->Create(op_type);
  }

  std::unique_ptr<KernelBase> Create(const std::string &op_type,
                                     TargetType target,
S
superjomn 已提交
92
                                     PrecisionType precision);
S
update  
superjomn 已提交
93

S
superjomn 已提交
94 95 96 97 98 99
  // Get a kernel registry offset in all the registries.
  template <TargetType Target, PrecisionType Precision>
  static constexpr int GetKernelOffset() {
    return kNumTargets * static_cast<int>(Target) + static_cast<int>(Precision);
  }

S
superjomn 已提交
100 101 102 103 104 105 106 107 108 109 110 111
  std::string DebugString() const {
    std::stringstream ss;

    ss << "KernelCreator<host, float>:" << std::endl;
    ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat)>()]
              .get<
                  KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> *>()
              ->DebugString();
    ss << std::endl;
    return ss.str();
  }

S
superjomn 已提交
112
 private:
S
superjomn 已提交
113 114
  mutable std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions>
      registries_;
S
superjomn 已提交
115 116 117 118 119 120 121
};

// Registration helper for kernels: constructing a
// KernelRegistor<target, precision, KernelType> hands Registor a callback that
// logs the registration and adds a KernelType creator to the global
// KernelRegistry under `op_type`.
template <TargetType target, PrecisionType precision, typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> {
 public:
  // op_type: name under which the creator is registered; every kernel the
  // creator builds has it applied through set_op_type().
  KernelRegistor(const std::string op_type)
      // NOTE(review): the outer lambda refers to op_type by reference, so this
      // presumably relies on Registor running the callback before this
      // constructor returns -- confirm against Registor.
      : Registor<KernelType>([&] {
          LOG(INFO) << "Register kernel " << op_type << " for "
                    << TargetToStr(target) << " " << PrecisionToStr(precision);
          KernelRegistry::Global().Register<target, precision>(
              // The creator is handed to the registry, so it captures only its
              // own copy of op_type and nothing by reference.
              op_type, [op_type]() -> std::unique_ptr<KernelType> {
                std::unique_ptr<KernelType> x(new KernelType);
                x->set_op_type(op_type);
                return x;
              });
        }) {}
};

}  // namespace lite
}  // namespace paddle

// Operator registry
// Builds the identifier <op_type__>__registry__instance__, used by
// REGISTER_LITE_OP as the name of its static OpLiteRegistor object.
#define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__
// Builds the identifier <op_type__>__registry__, used by USE_LITE_OP as the
// name of its int variable.
#define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__
// Registers OpClass under the name #op_type__.
// Expands to:
//   * a static paddle::lite::OpLiteRegistor<OpClass> object constructed with
//     the stringified op type, and
//   * a function touch_op_<op_type__>() that returns the result of calling
//     Touch() on that object (USE_LITE_OP declares and calls this function).
#define REGISTER_LITE_OP(op_type__, OpClass)                              \
  static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
      op_type__)(#op_type__);                                             \
  int touch_op_##op_type__() {                                            \
    return LITE_OP_REGISTER_INSTANCE(op_type__).Touch();                  \
  }
S
superjomn 已提交
145

S
superjomn 已提交
146 147 148 149
// Declares touch_op_<op_type__>() and defines an int variable initialized by
// calling it, which creates a reference to the function of that name emitted
// by REGISTER_LITE_OP.
// NOTE(review): presumably meant to keep the registering object file from
// being dropped at link time -- confirm.
#define USE_LITE_OP(op_type__)                                   \
  extern int touch_op_##op_type__();                             \
  int LITE_OP_REGISTER_FAKE(op_type__) __attribute__((unused)) = \
      touch_op_##op_type__();
S
superjomn 已提交
150 151 152

// Kernel registry
// Identifier builders for the kernel registration macros.
// <op>__<target>__<precision>__registor__
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
  op_type__##__##target__##__##precision__##__registor__
// <op>__<target>__<precision>__registor__instance__ : name of the static
// KernelRegistor object defined by REGISTER_LITE_KERNEL.
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
  op_type__##__##target__##__##precision__##__registor__instance__
// Name of the int variable defined by USE_LITE_KERNEL. It expands to the same
// identifier as LITE_KERNEL_REGISTER_INSTANCE.
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)

S
superjomn 已提交
159 160 161 162 163 164 165 166 167 168 169 170 171
// Registers KernelClass as the kernel of op #op_type__ for
// TARGET(target__) / PRECISION(precision__).
// Expands to:
//   * a static paddle::lite::KernelRegistor object constructed with the
//     stringified op type,
//   * a static KernelClass object named by LITE_KERNEL_INSTANCE,
//   * a function touch_<op_type__><target__><precision__>() that calls
//     Touch() on that KernelClass object and returns 0, and
//   * the start of a static bool definition initialized from
//     paddle::lite::ParamTypeRegistry::NewInstance<...>(#op_type__).
// The expansion deliberately has no trailing semicolon: the text written
// after the macro invocation completes the initializer expression.
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass)   \
  static paddle::lite::KernelRegistor<TARGET(target__),                       \
                                      PRECISION(precision__), KernelClass>    \
      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__,                      \
                                    precision__)(#op_type__);                 \
  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__);  \
  int touch_##op_type__##target__##precision__() {                            \
    LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch();           \
    return 0;                                                                 \
  }                                                                           \
  static bool op_type__##target__##precision__##param_register                \
      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
          TARGET(target__), PRECISION(precision__)>(#op_type__)
S
superjomn 已提交
172

S
superjomn 已提交
173 174 175 176
// Declares touch_<op_type__><target__><precision__>() and defines an int
// variable initialized by calling it, which creates a reference to the
// function of that name emitted by REGISTER_LITE_KERNEL.
// NOTE(review): presumably meant to keep the registering object file from
// being dropped at link time -- confirm.
#define USE_LITE_KERNEL(op_type__, target__, precision__)         \
  extern int touch_##op_type__##target__##precision__();          \
  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
      __attribute__((unused)) = touch_##op_type__##target__##precision__();
S
superjomn 已提交
177

S
superjomn 已提交
178 179
// Builds the identifier <op_type__><target__><precision__>, used by
// REGISTER_LITE_KERNEL as the name of its static KernelClass object.
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
  op_type__##target__##precision__