registry.h 8.2 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <memory>
#include <tuple>
#include <type_traits>
20
#include <utility>  // for std::move
21

T
tensor-tang 已提交
22 23
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
T
tensor-tang 已提交
24
#include "paddle/fluid/platform/place.h"
T
tensor-tang 已提交
25
#include "paddle/fluid/platform/variant.h"  // for UNUSED
T
tensor-tang 已提交
26 27 28

namespace paddle {
namespace operators {
T
tensor-tang 已提交
29
namespace jit {
T
tensor-tang 已提交
30

T
tensor-tang 已提交
31
// make_unique is supported since c++14
T
tensor-tang 已提交
32 33 34 35 36 37
// Builds a single heap-allocated T from the forwarded constructor arguments
// and returns it wrapped in a std::unique_ptr<T>. Array types are rejected at
// compile time.
template <typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
  static_assert(!std::is_array<T>::value, "T must not be array");
  std::unique_ptr<T> owner(new T(std::forward<Args>(args)...));
  return owner;
}

T
tensor-tang 已提交
38 39
// Compile-time loop that registers each type in KernelImpls... into Pool.
//   Pool        - pool type; the non-terminal specialization calls
//                 Pool::Instance().Insert(...) on it.
//   PlaceType   - default-constructed to build the KernelKey.
//   IsEnd       - true selects the terminating, no-op specialization.
//   I           - index (into KernelImpls...) handled by the current step.
//   KernelImpls - implementation types to register.
template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
          typename... KernelImpls>
struct JitKernelRegistrarFunctor;

T
tensor-tang 已提交
42 43
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
T
tensor-tang 已提交
44 45 46
  void operator()(KernelType kt) const {}
};

T
tensor-tang 已提交
47 48
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
T
tensor-tang 已提交
49 50 51 52 53
  using KERNEL_IMPL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;

  void operator()(KernelType kt) const {
    KernelKey kkey(kt, PlaceType());
54 55
    Pool::Instance().Insert(kkey,
                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
T
tensor-tang 已提交
56
    constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
T
tensor-tang 已提交
57 58
    JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                              KernelImpls...>
T
tensor-tang 已提交
59 60 61 62 63
        func;
    func(kt);
  }
};

T
tensor-tang 已提交
64
template <typename Pool, typename PlaceType, typename... KernelImpls>
T
tensor-tang 已提交
65 66 67
class JitKernelRegistrar {
 public:
  explicit JitKernelRegistrar(KernelType kt) {
T
tensor-tang 已提交
68
    JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
T
tensor-tang 已提交
69 70
    func(kt);
  }
T
tensor-tang 已提交
71
  void Touch() {}
T
tensor-tang 已提交
72 73 74 75 76 77 78 79
};

// Guard used by the registration/usage macros to require expansion at global
// namespace scope.
//   uniq_name - pasted into the helper struct name to keep it unique.
//   msg       - message passed to static_assert.
// It declares an empty struct __test_global_namespace_<uniq_name>__ in the
// current scope and asserts that the globally qualified name (leading `::`)
// refers to that same type. The expansion has no trailing semicolon; the user
// of the macro supplies it.
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg)              \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

T
tensor-tang 已提交
80
// Reference kernels are always registered on CPUPlace.
T
tensor-tang 已提交
81 82 83 84 85 86 87 88 89 90 91 92
// Registers the reference implementation(s) of `kernel_type`.
//   kernel_type - enumerator name in ::paddle::operators::jit::KernelType.
//   ...         - implementation types handed to JitKernelRegistrar.
// The expansion
//   1. checks that it is used at global namespace scope,
//   2. defines a static JitKernelRegistrar over ReferKernelPool and CPUPlace,
//   3. defines TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), which calls
//      Touch() on that registrar and returns 0. The other REGISTER_* and
//      USE_JITKERNEL_REFER macros declare and call this function.
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                             \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                    \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                          \
      "REGISTER_JITKERNEL_REFER must be called in global namespace");          \
  static ::paddle::operators::jit::JitKernelRegistrar<                         \
      ::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
      __VA_ARGS__>                                                             \
      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(                  \
          ::paddle::operators::jit::KernelType::kernel_type);                  \
  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {                    \
    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();            \
    return 0;                                                                  \
  }

T
tensor-tang 已提交
95
// kernel_type: should be in paddle::operators::jit::KernelType
T
tensor-tang 已提交
96
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
T
tensor-tang 已提交
97 98 99 100 101 102 103
// Registers additional implementation(s) of `kernel_type` for `place_type`.
//   kernel_type - enumerator name in ::paddle::operators::jit::KernelType.
//   impl_type   - tag pasted into the generated identifiers.
//   place_type  - type name inside ::paddle::platform.
//   ...         - implementation types handed to JitKernelRegistrar.
// The expansion
//   1. checks that it is used at global namespace scope,
//   2. declares TouchJitKernelReg_<kernel_type>_refer_CPUPlace_() and calls it
//      from a static initializer; that function is defined by
//      REGISTER_JITKERNEL_REFER, so a reference registration must exist,
//   3. defines a static JitKernelRegistrar over KernelPool and place_type,
//   4. defines TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_(),
//      which calls Touch() on that registrar and returns 0.
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)         \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                   \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,             \
      "REGISTER_KERNEL_MORE must be called in global namespace");             \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();             \
  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();           \
  static ::paddle::operators::jit::JitKernelRegistrar<                        \
      ::paddle::operators::jit::KernelPool, ::paddle::platform::place_type,   \
      __VA_ARGS__>                                                            \
      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(   \
          ::paddle::operators::jit::KernelType::kernel_type);                 \
  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {     \
    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_      \
        .Touch();                                                             \
    return 0;                                                                 \
  }
T
tensor-tang 已提交
114 115 116 117 118 119 120

// Shorthand for REGISTER_KERNEL_MORE with place_type fixed to CPUPlace; the
// remaining arguments are forwarded unchanged.
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)

// Shorthand for REGISTER_KERNEL_MORE with place_type fixed to GPUPlace; the
// remaining arguments are forwarded unchanged.
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)

T
tensor-tang 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
// Registers the type(s) given in `...` for `kernel_type` into
// JitCodeCreatorPool on CPUPlace.
//   kernel_type - enumerator name in ::paddle::operators::jit::KernelType.
//   ...         - types handed to JitKernelRegistrar.
// The expansion
//   1. checks that it is used at global namespace scope,
//   2. declares TouchJitKernelReg_<kernel_type>_refer_CPUPlace_() and calls it
//      from a static initializer; that function is defined by
//      REGISTER_JITKERNEL_REFER, so a reference registration must exist,
//   3. defines a static JitKernelRegistrar over JitCodeCreatorPool,
//   4. defines TouchJitKernelReg_gen_<kernel_type>_CPUPlace_(), which calls
//      Touch() on that registrar and returns 0.
#define REGISTER_JITKERNEL_GEN(kernel_type, ...)                    \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace_,                \
      "REGISTER_JITKERNEL_GEN must be called in global namespace"); \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int __assert_gen_##kernel_type##_has_refer_ UNUSED =       \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();          \
  static ::paddle::operators::jit::JitKernelRegistrar<              \
      ::paddle::operators::jit::JitCodeCreatorPool,                 \
      ::paddle::platform::CPUPlace, __VA_ARGS__>                    \
      __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_(         \
          ::paddle::operators::jit::KernelType::kernel_type);       \
  int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() {           \
    __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch();   \
    return 0;                                                       \
  }

// Declares TouchJitKernelReg_gen_<kernel_type>_CPUPlace_() (defined by
// REGISTER_JITKERNEL_GEN) and calls it to initialize a static int, after
// checking that the macro is used at global namespace scope.
// NOTE(review): presumably this exists to make the translation unit that
// holds the registration get linked in - confirm against the build setup.
// The expansion has no trailing semicolon; the user of the macro supplies it.
#define USE_JITKERNEL_GEN(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                       \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace_,              \
      "USE_JITKERNEL_GEN must be called in global namespace");    \
  extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_();   \
  static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
      TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
T
tensor-tang 已提交
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165

// Declares TouchJitKernelReg_<kernel_type>_refer_CPUPlace_() (defined by
// REGISTER_JITKERNEL_REFER) and calls it to initialize a static int, after
// checking that the macro is used at global namespace scope.
// NOTE(review): presumably this exists to make the translation unit that
// holds the registration get linked in - confirm against the build setup.
// The expansion has no trailing semicolon; the user of the macro supplies it.
#define USE_JITKERNEL_REFER(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
      "USE_JITKERNEL_REFER must be called in global namespace");    \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()

// Declares TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_()
// (defined by REGISTER_KERNEL_MORE) and calls it to initialize a static int,
// after checking that the macro is used at global namespace scope.
//   kernel_type, impl_type, place_type - must match the arguments given to
//                                        the corresponding REGISTER_KERNEL_MORE
//                                        so the pasted names agree.
// The diagnostic names this macro, mirroring REGISTER_KERNEL_MORE.
// The expansion has no trailing semicolon; the user of the macro supplies it.
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)              \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                              \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,     \
      "USE_KERNEL_MORE must be called in global namespace");             \
  extern int                                                             \
      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
      UNUSED =                                                           \
          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()

// Shorthand for USE_KERNEL_MORE with place_type fixed to CPUPlace.
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
T
tensor-tang 已提交
166

T
tensor-tang 已提交
167
}  // namespace jit
T
tensor-tang 已提交
168 169
}  // namespace operators
}  // namespace paddle