/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <cstddef>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"  // for UNUSED

namespace paddle {
namespace operators {
namespace jit {

// Stand-in for std::make_unique, which only exists since C++14.
// Supports the single-object form only; array types are rejected at compile
// time because std::make_unique<T[]> has different semantics.
template <typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
  static_assert(!std::is_array<T>::value, "T must not be array");
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

// Compile-time loop that registers every type in KernelImpls... into Pool.
//   IsEnd: true once all implementations have been registered (stops the loop).
//   I:     index into KernelImpls... of the implementation to register next.
// Only the two partial specializations below (IsEnd == true / false) are
// defined; the primary template is intentionally left incomplete.
template <typename Pool, typename PlaceType, bool IsEnd, std::size_t I,
          typename... KernelImpls>
struct JitKernelRegistrarFunctor;

T
tensor-tang 已提交
40 41
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
T
tensor-tang 已提交
42 43 44
  void operator()(KernelType kt) const {}
};

T
tensor-tang 已提交
45 46
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
T
tensor-tang 已提交
47 48 49 50 51
  using KERNEL_IMPL_TYPE =
      typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;

  void operator()(KernelType kt) const {
    KernelKey kkey(kt, PlaceType());
52 53
    Pool::Instance().Insert(kkey,
                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
T
tensor-tang 已提交
54
    constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
T
tensor-tang 已提交
55 56
    JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                              KernelImpls...>
T
tensor-tang 已提交
57 58 59 60 61
        func;
    func(kt);
  }
};

T
tensor-tang 已提交
62
template <typename Pool, typename PlaceType, typename... KernelImpls>
T
tensor-tang 已提交
63 64 65
class JitKernelRegistrar {
 public:
  explicit JitKernelRegistrar(KernelType kt) {
T
tensor-tang 已提交
66
    JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
T
tensor-tang 已提交
67 68
    func(kt);
  }
T
tensor-tang 已提交
69
  void Touch() {}
T
tensor-tang 已提交
70 71 72 73 74 75 76 77
};

// Compile-time check that the enclosing macro is expanded at global scope.
// It declares an empty struct with a name pasted from `uniq_name` in the
// current scope, then asserts that the `::`-qualified name denotes the same
// type. Outside the global namespace the qualified lookup either finds nothing
// (compile error) or finds a different type (static_assert fails with `msg`).
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg)              \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

// Registers the reference implementations of `kernel_type` into
// ReferKernelPool. Reference kernels are always registered on CPUPlace.
// Also defines TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), which
// REGISTER_KERNEL_MORE, REGISTER_JITKERNEL_GEN and USE_JITKERNEL_REFER call to
// depend on this registration. Must be expanded at global scope.
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)                             \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                    \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace,                          \
      "REGISTER_JITKERNEL_REFER must be called in global namespace");          \
  static ::paddle::operators::jit::JitKernelRegistrar<                         \
      ::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
      __VA_ARGS__>                                                             \
      __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(                  \
          ::paddle::operators::jit::KernelType::kernel_type);                  \
  int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {                    \
    __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();            \
    return 0;                                                                  \
  }

// Registers additional implementations of a kernel into KernelPool.
//   kernel_type: should be in paddle::operators::jit::KernelType
//   impl_type:   a tag for this group of implementations; only used to build
//                unique identifiers
//   place_type:  should be one of CPUPlace and GPUPlace in paddle::platform
//   ...:         the implementation classes to register
// The reference kernel must also be registered with REGISTER_JITKERNEL_REFER:
// this macro calls the extern TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(),
// so linking fails without it. Defines
// TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_() for
// USE_KERNEL_MORE. Must be expanded at global scope.
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...)         \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                                   \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type,             \
      "REGISTER_KERNEL_MORE must be called in global namespace");             \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();             \
  static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
      UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();           \
  static ::paddle::operators::jit::JitKernelRegistrar<                        \
      ::paddle::operators::jit::KernelPool, ::paddle::platform::place_type,   \
      __VA_ARGS__>                                                            \
      __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_(   \
          ::paddle::operators::jit::KernelType::kernel_type);                 \
  int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() {     \
    __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_      \
        .Touch();                                                             \
    return 0;                                                                 \
  }

// Shorthands for REGISTER_KERNEL_MORE with a fixed place type.
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)

#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
  REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)

// Registers the JIT code creators of `kernel_type` into JitCodeCreatorPool on
// CPUPlace. Like REGISTER_KERNEL_MORE, it calls the extern
// TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), so the reference kernel
// must be registered too. Defines TouchJitKernelReg_gen_<kernel_type>_CPUPlace_()
// for USE_JITKERNEL_GEN. Must be expanded at global scope.
//
// The global-namespace tag omits the trailing underscore that
// USE_JITKERNEL_GEN uses (the same convention as the REFER/MORE registrars).
// Both macros can therefore be expanded in one translation unit, e.g. in a
// unity build, without redefining the tag struct.
#define REGISTER_JITKERNEL_GEN(kernel_type, ...)                    \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace,                 \
      "REGISTER_JITKERNEL_GEN must be called in global namespace"); \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int __assert_gen_##kernel_type##_has_refer_ UNUSED =       \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();          \
  static ::paddle::operators::jit::JitKernelRegistrar<              \
      ::paddle::operators::jit::JitCodeCreatorPool,                 \
      ::paddle::platform::CPUPlace, __VA_ARGS__>                    \
      __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_(         \
          ::paddle::operators::jit::KernelType::kernel_type);       \
  int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() {           \
    __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch();   \
    return 0;                                                       \
  }

// Declares a dependency on REGISTER_JITKERNEL_GEN(kernel_type, ...). It
// initializes a file-scope static int from the extern
// TouchJitKernelReg_gen_<kernel_type>_CPUPlace_(), so the program does not
// link unless that registration is compiled in. UNUSED comes from
// platform/variant.h (presumably suppresses unused-variable warnings).
// Must be expanded at global scope.
#define USE_JITKERNEL_GEN(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                       \
      __reg_jitkernel_gen_##kernel_type##_CPUPlace_,              \
      "USE_JITKERNEL_GEN must be called in global namespace");    \
  extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_();   \
  static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
      TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()

// Declares a dependency on REGISTER_JITKERNEL_REFER(kernel_type, ...). It
// initializes a file-scope static int from the extern
// TouchJitKernelReg_<kernel_type>_refer_CPUPlace_(), so the program does not
// link unless that registration is compiled in. Must be expanded at global
// scope.
#define USE_JITKERNEL_REFER(kernel_type)                            \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                         \
      __reg_jitkernel_##kernel_type##_refer_CPUPlace_,              \
      "USE_JITKERNEL_REFER must be called in global namespace");    \
  extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_();   \
  static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
      TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()

// Declares a dependency on REGISTER_KERNEL_MORE(kernel_type, impl_type,
// place_type, ...). It initializes a file-scope static int from the extern
// TouchJitKernelReg_<kernel_type>_<impl_type>_<place_type>_(), so the program
// does not link unless that registration is compiled in. Must be expanded at
// global scope.
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type)              \
  STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(                              \
      __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_,     \
      "USE_JITKERNEL_MORE must be called in global namespace");          \
  extern int                                                             \
      TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
  static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
      UNUSED =                                                           \
          TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()

// CPUPlace shorthand for USE_KERNEL_MORE; pairs with REGISTER_JITKERNEL_MORE.
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
  USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)

}  // namespace jit
}  // namespace operators
}  // namespace paddle