registry.h 8.4 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <memory>
#include <tuple>
#include <type_traits>
20
#include <utility>  // for std::move
21

T
tensor-tang 已提交
22 23
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
T
tensor-tang 已提交
24
#include "paddle/fluid/platform/place.h"
T
tensor-tang 已提交
25
#include "paddle/fluid/platform/variant.h"  // for UNUSED
T
tensor-tang 已提交
26 27 28

namespace paddle {
namespace operators {
T
tensor-tang 已提交
29
namespace jit {
T
tensor-tang 已提交
30

T
tensor-tang 已提交
31
// C++11-compatible replacement for std::make_unique, which only exists from
// C++14 on. Supports the single-object form only: array types are rejected at
// compile time. Constructs a T from the perfectly-forwarded `args` and returns
// sole ownership of it.
template <typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
  static_assert(!std::is_array<T>::value, "T must not be array");
  T* const fresh = new T(std::forward<Args>(args)...);
  // unique_ptr's pointer constructor cannot throw, so `fresh` is owned at once.
  return std::unique_ptr<T>(fresh);
}

38 39 40 41
template <typename Pool,
          typename PlaceType,
          bool IsEnd,
          size_t I,
T
tensor-tang 已提交
42
          typename... KernelImpls>
T
tensor-tang 已提交
43 44
struct JitKernelRegistrarFunctor;

T
tensor-tang 已提交
45 46
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
T
tensor-tang 已提交
47 48 49
  void operator()(KernelType kt) const {}
};

T
tensor-tang 已提交
50 51
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
T
tensor-tang 已提交
52 53 54 55 56
  using KERNEL_IMPL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;

  void operator()(KernelType kt) const {
    KernelKey kkey(kt, PlaceType());
57 58
    Pool::Instance().Insert(kkey,
                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
T
tensor-tang 已提交
59
    constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
60 61 62 63
    JitKernelRegistrarFunctor<Pool,
                              PlaceType,
                              I + 1 == size,
                              I + 1,
T
tensor-tang 已提交
64
                              KernelImpls...>
T
tensor-tang 已提交
65 66 67 68 69
        func;
    func(kt);
  }
};

T
tensor-tang 已提交
70
template <typename Pool, typename PlaceType, typename... KernelImpls>
T
tensor-tang 已提交
71 72 73
class JitKernelRegistrar {
 public:
  explicit JitKernelRegistrar(KernelType kt) {
T
tensor-tang 已提交
74
    JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
T
tensor-tang 已提交
75 76
    func(kt);
  }
T
tensor-tang 已提交
77
  void Touch() {}
T
tensor-tang 已提交
78 79 80 81 82 83 84 85
};

// Compile-time guard used by the REGISTER_* / USE_* macros below.
//
// Declares an empty struct __test_global_namespace_<uniq_name>__ at the
// expansion site and static_asserts that it is the same type as
// ::__test_global_namespace_<uniq_name>__. The ::-qualified name only names
// the freshly declared struct when the expansion is at global scope. Inside a
// namespace, either the qualified lookup fails to compile or, if a global
// struct of that name exists, the static_assert fires with `msg`.
// `uniq_name` must be unique within the translation unit because a struct is
// defined. The caller supplies the trailing ';'.
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg)              \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

T
tensor-tang 已提交
86
// Registers the reference ("refer") implementation(s) of `kernel_type` in
// ReferKernelPool. Refer kernels are always registered on CPUPlace.
//
//   kernel_type - enumerator name in paddle::operators::jit::KernelType.
//   ...         - one or more kernel implementation types to register.
//
// Must be expanded at global scope. Besides a file-scope static registrar, it
// defines the non-static function
// TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(). USE_JITKERNEL_REFER,
// REGISTER_KERNEL_MORE and REGISTER_JITKERNEL_GEN declare that function extern
// and call it.
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                    \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                           \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                 \
      "REGISTER_JITKERNEL_REFER must be called in global namespace"); \
  static ::paddle::operators::jit::JitKernelRegistrar<                \
      ::paddle::operators::jit::ReferKernelPool,                      \
      ::paddle::platform::CPUPlace,                                   \
      __VA_ARGS__>                                                    \
      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(         \
          ::paddle::operators::jit::KernelType::kernel_type);         \
  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {           \
    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();   \
    return 0;                                                         \
  }

T
tensor-tang 已提交
102
// Registers one or more additional ("more") implementations of `kernel_type`
// in KernelPool for the given place.
//
//   kernel_type - enumerator name in paddle::operators::jit::KernelType.
//   impl_type   - identifier for this implementation family; used only to
//                 build unique symbol names.
//   place_type  - one of CPUPlace and GPUPlace in paddle::platform.
//   ...         - kernel implementation types to register.
//
// Must be expanded at global scope. While this file's statics are
// initialized, it calls TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(),
// which REGISTER_JITKERNEL_REFER defines. A refer kernel for `kernel_type`
// must therefore be registered somewhere in the program, or linking fails.
// It also defines TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_()
// for USE_KERNEL_MORE.
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)         \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                   \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,             \
      "REGISTER_KERNEL_MORE must be called in global namespace");             \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();             \
  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();           \
  static ::paddle::operators::jit::JitKernelRegistrar<                        \
      ::paddle::operators::jit::KernelPool,                                   \
      ::paddle::platform::place_type,                                         \
      __VA_ARGS__>                                                            \
      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(   \
          ::paddle::operators::jit::KernelType::kernel_type);                 \
  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {     \
    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_      \
        .Touch();                                                             \
    return 0;                                                                 \
  }
T
tensor-tang 已提交
122 123 124 125 126 127 128

// REGISTER_KERNEL_MORE with place_type fixed to CPUPlace.
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)

// REGISTER_KERNEL_MORE with place_type fixed to GPUPlace.
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)

T
tensor-tang 已提交
129 130 131 132 133 134 135 136 137
// Registers the JIT code-generation implementation(s) of `kernel_type` in
// JitCodeCreatorPool, always on CPUPlace.
//
//   kernel_type - enumerator name in paddle::operators::jit::KernelType.
//   ...         - one or more types to register (presumably JIT code creator
//                 classes, judging by the pool name; confirm at call sites).
//
// Must be expanded at global scope. Like REGISTER_KERNEL_MORE, it calls
// TouchJitKernelReg_<kernel_type>_refer_CPUPlace_() while this file's statics
// are initialized. A refer kernel for `kernel_type` must therefore also be
// registered (REGISTER_JITKERNEL_REFER), or linking fails. It defines
// TouchJitKernelReg_gen_<kernel_type>_CPUPlace_() for USE_JITKERNEL_GEN.
//
// The namespace-guard name deliberately has no trailing underscore. That keeps
// it distinct from the one USE_JITKERNEL_GEN declares, mirroring the
// REGISTER_JITKERNEL_REFER / USE_JITKERNEL_REFER pair, so both macros can be
// expanded in the same translation unit.
#define REGISTER_JITKERNEL_GEN(kernel_type, ...)                    \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace,                 \
      "REGISTER_JITKERNEL_GEN must be called in global namespace"); \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int __assert_gen_##kernel_type##_has_refer_ UNUSED =       \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();          \
  static ::paddle::operators::jit::JitKernelRegistrar<              \
      ::paddle::operators::jit::JitCodeCreatorPool,                 \
      ::paddle::platform::CPUPlace,                                 \
      __VA_ARGS__>                                                  \
      __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_(         \
          ::paddle::operators::jit::KernelType::kernel_type);       \
  int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() {           \
    __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch();   \
    return 0;                                                       \
  }

// References the JIT code-generation registration made by
// REGISTER_JITKERNEL_GEN(kernel_type, ...) from this translation unit. It
// declares TouchJitKernelReg_gen_<kernel_type>_CPUPlace_() extern and calls it
// while initializing a file-scope static. Presumably this keeps the
// registering object file (and its static registrar) from being dropped at
// link time; confirm against the build setup.
// Must be expanded at global scope; the caller supplies the trailing ';'.
#define USE_JITKERNEL_GEN(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                       \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace_,              \
      "USE_JITKERNEL_GEN must be called in global namespace");    \
  extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_();   \
  static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
      TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
T
tensor-tang 已提交
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174

// References the refer-kernel registration made by
// REGISTER_JITKERNEL_REFER(kernel_type, ...) from this translation unit. It
// declares TouchJitKernelReg_<kernel_type>_refer_CPUPlace_() extern and calls
// it while initializing a file-scope static. Presumably this keeps the
// registering object file linked in; confirm against the build setup.
// Must be expanded at global scope; the caller supplies the trailing ';'.
#define USE_JITKERNEL_REFER(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
      "USE_JITKERNEL_REFER must be called in global namespace");    \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()

// References the registration made by
// REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) from this
// translation unit. It declares
// TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_() extern and calls
// it while initializing a file-scope static. Presumably this keeps the
// registering object file linked in; confirm against the build setup.
// Must be expanded at global scope; the caller supplies the trailing ';'.
// The namespace-guard message names USE_JITKERNEL_MORE, the CPU wrapper
// through which this macro is normally used.
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)              \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                              \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,     \
      "USE_JITKERNEL_MORE must be called in global namespace");          \
  extern int                                                             \
      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
      UNUSED =                                                           \
          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()

// USE_KERNEL_MORE with place_type fixed to CPUPlace; pairs with
// REGISTER_JITKERNEL_MORE.
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
T
tensor-tang 已提交
175

T
tensor-tang 已提交
176
}  // namespace jit
T
tensor-tang 已提交
177 178
}  // namespace operators
}  // namespace paddle