//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/tensor_array.h"
#include "paddle/phi/core/type_defs.h"

namespace phi {

// PD_KERNEL has been used by custom op api
// Expands to the uniform `Compute(KernelContext*)` entry point that
// KernelImpl generates for the given concrete kernel function.
#define PHI_KERNEL(...) \
  ::phi::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute

// Expands to a type-erased (void*) pointer to the variadic entry point of the
// given concrete kernel function, used when invoking a kernel with already
// materialized arguments instead of a KernelContext.
#define PHI_VARIADIC_KERNEL(...)                                     \
  reinterpret_cast<void*>(&::phi::KernelImpl<decltype(&__VA_ARGS__), \
                                             &__VA_ARGS__>::VariadicCompute)

// Declares a KernelCallHelper specialization that consumes a `const dev_ctx&`
// kernel parameter: it statically checks that the DeviceContext appears
// before all inputs, attributes and outputs, fetches the context from the
// KernelContext, and recurses on the remaining parameter types with
// dev_ctx_idx advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx)           \
  template <typename... Tail>                                                \
  struct KernelCallHelper<const dev_ctx&, Tail...> {                         \
    template <int dev_ctx_idx,                                               \
              int in_idx,                                                    \
              int attr_idx,                                                  \
              int out_idx,                                                   \
              typename... PreviousArgs>                                      \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {        \
      static_assert(in_idx == 0,                                             \
                    "Kernel's DeviceContext should appear before Inputs.");  \
      static_assert(                                                         \
          attr_idx == 0,                                                     \
          "Kernel's DeviceContext should appear before Attributes.");        \
      static_assert(out_idx == 0,                                            \
                    "Kernel's DeviceContext should appear before Outputs."); \
      const dev_ctx& arg = ctx->GetDeviceContext<dev_ctx>();                 \
      KernelCallHelper<Tail...>::                                            \
          template Compute<dev_ctx_idx + 1, in_idx, attr_idx, out_idx>(      \
              ctx, pargs..., arg);                                           \
    }                                                                        \
  }

// Declares a KernelCallHelper specialization for a single `const tensor_type&`
// input: statically checks inputs appear before attributes and outputs, reads
// the tensor stored at this input slot's range, and recurses with in_idx
// advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(tensor_type)           \
  template <typename... Tail>                                           \
  struct KernelCallHelper<const tensor_type&, Tail...> {                \
    template <int dev_ctx_idx,                                          \
              int in_idx,                                               \
              int attr_idx,                                             \
              int out_idx,                                              \
              typename... PreviousArgs>                                 \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {   \
      static_assert(attr_idx == 0,                                      \
                    "Kernel's Input should appear before Attributes."); \
      static_assert(out_idx == 0,                                       \
                    "Kernel's Input should appear before Outputs.");    \
      const std::pair<int, int>& range = ctx->InputRangeAt(in_idx);     \
      const tensor_type& arg = ctx->InputAt<tensor_type>(range.first);  \
      KernelCallHelper<Tail...>::                                       \
          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
              ctx, pargs..., arg);                                      \
    }                                                                   \
  }

// Declares a KernelCallHelper specialization for an optional single input
// (`const paddle::optional<tensor_type>&`): statically checks ordering,
// reads the optional tensor at this input slot, and recurses with in_idx
// advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type)     \
  template <typename... Tail>                                              \
  struct KernelCallHelper<const paddle::optional<tensor_type>&, Tail...> { \
    template <int dev_ctx_idx,                                             \
              int in_idx,                                                  \
              int attr_idx,                                                \
              int out_idx,                                                 \
              typename... PreviousArgs>                                    \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
      static_assert(attr_idx == 0,                                         \
                    "Kernel's Input should appear before Attributes.");    \
      static_assert(out_idx == 0,                                          \
                    "Kernel's Input should appear before Outputs.");       \
      const std::pair<int, int>& range = ctx->InputRangeAt(in_idx);        \
      auto arg = ctx->OptionalInputAt<tensor_type>(range.first);           \
      KernelCallHelper<Tail...>::                                          \
          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(    \
              ctx, pargs..., arg);                                         \
    }                                                                      \
  }

// Declares a KernelCallHelper specialization for a vector-of-tensors input
// (`const std::vector<const tensor_type*>&`): statically checks ordering,
// gathers the pointer list covering this input slot's [first, second) range,
// and recurses with in_idx advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)          \
  template <typename... Tail>                                                \
  struct KernelCallHelper<const std::vector<const tensor_type*>&, Tail...> { \
    template <int dev_ctx_idx,                                               \
              int in_idx,                                                    \
              int attr_idx,                                                  \
              int out_idx,                                                   \
              typename... PreviousArgs>                                      \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {        \
      static_assert(attr_idx == 0,                                           \
                    "Kernel's Input should appear before Attributes.");      \
      static_assert(out_idx == 0,                                            \
                    "Kernel's Input should appear before Outputs.");         \
      const std::pair<int, int>& range = ctx->InputRangeAt(in_idx);          \
      std::vector<const tensor_type*> arg = std::move(                       \
          ctx->InputsBetween<tensor_type>(range.first, range.second));       \
      KernelCallHelper<Tail...>::                                            \
          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(      \
              ctx, pargs..., arg);                                           \
    }                                                                        \
  }

// Declares a KernelCallHelper specialization for an optional vector-of-tensors
// input (`const paddle::optional<std::vector<const tensor_type*>>&`):
// statically checks ordering, gathers the optional pointer list covering this
// input slot's range, and recurses with in_idx advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(tensor_type)  \
  template <typename... Tail>                                                 \
  struct KernelCallHelper<                                                    \
      const paddle::optional<std::vector<const tensor_type*>>&,               \
      Tail...> {                                                              \
    template <int dev_ctx_idx,                                                \
              int in_idx,                                                     \
              int attr_idx,                                                   \
              int out_idx,                                                    \
              typename... PreviousArgs>                                       \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {         \
      static_assert(attr_idx == 0,                                            \
                    "Kernel's Input should appear before Attributes.");       \
      static_assert(out_idx == 0,                                             \
                    "Kernel's Input should appear before Outputs.");          \
      const std::pair<int, int>& range = ctx->InputRangeAt(in_idx);           \
      paddle::optional<std::vector<const tensor_type*>> arg =                 \
          ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
      KernelCallHelper<Tail...>::                                             \
          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(       \
              ctx, pargs..., arg);                                            \
    }                                                                         \
  }

// Declares a KernelCallHelper specialization for a by-value attribute of
// `attr_type`: statically checks attributes appear before outputs, reads the
// attribute stored at attr_idx, and recurses with attr_idx advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type)           \
  template <typename... Tail>                                             \
  struct KernelCallHelper<attr_type, Tail...> {                           \
    template <int dev_ctx_idx,                                            \
              int in_idx,                                                 \
              int attr_idx,                                               \
              int out_idx,                                                \
              typename... PreviousArgs>                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {     \
      static_assert(out_idx == 0,                                         \
                    "Kernel's Attributes should appear before Outputs."); \
      attr_type arg = ctx->AttrAt<attr_type>(attr_idx);                   \
      KernelCallHelper<Tail...>::                                         \
          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(   \
              ctx, pargs..., arg);                                        \
    }                                                                     \
  }

// Declares a KernelCallHelper specialization for a `const attr_type&`
// attribute (used for expensive-to-copy attribute types such as std::string
// and std::vector): statically checks attributes appear before outputs,
// binds a reference to the attribute at attr_idx, and recurses with attr_idx
// advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
  template <typename... Tail>                                             \
  struct KernelCallHelper<const attr_type&, Tail...> {                    \
    template <int dev_ctx_idx,                                            \
              int in_idx,                                                 \
              int attr_idx,                                               \
              int out_idx,                                                \
              typename... PreviousArgs>                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {     \
      static_assert(out_idx == 0,                                         \
                    "Kernel's Attributes should appear before Outputs."); \
      const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx);            \
      KernelCallHelper<Tail...>::                                         \
          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(   \
              ctx, pargs..., arg);                                        \
    }                                                                     \
  }

// Declares a KernelCallHelper specialization for a single mutable output
// (`tensor_type*`): reads the output stored at this output slot's range and
// recurses with out_idx advanced. No ordering check is needed here because
// outputs are the last parameter group.
#define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)           \
  template <typename... Tail>                                            \
  struct KernelCallHelper<tensor_type*, Tail...> {                       \
    template <int dev_ctx_idx,                                           \
              int in_idx,                                                \
              int attr_idx,                                              \
              int out_idx,                                               \
              typename... PreviousArgs>                                  \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {    \
      const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx);    \
      tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first); \
      KernelCallHelper<Tail...>::                                        \
          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(  \
              ctx, pargs..., arg);                                       \
    }                                                                    \
  }

// Declares a KernelCallHelper specialization for a vector-of-outputs
// parameter (`std::vector<tensor_type*>`): gathers the mutable output
// pointers covering this output slot's [first, second) range and recurses
// with out_idx advanced.
#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type)          \
  template <typename... Tail>                                                 \
  struct KernelCallHelper<std::vector<tensor_type*>, Tail...> {               \
    template <int dev_ctx_idx,                                                \
              int in_idx,                                                     \
              int attr_idx,                                                   \
              int out_idx,                                                    \
              typename... PreviousArgs>                                       \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {         \
      const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx);         \
      std::vector<tensor_type*> arg = std::move(                              \
          ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
      KernelCallHelper<Tail...>::                                             \
          template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>(       \
              ctx, pargs..., arg);                                            \
    }                                                                         \
  }

// Empty tag type used as the end marker when recursing over a kernel's
// argument list at compile time (see KernelCallHelper's end case below).
template <typename T>
struct TypeTag {};

// Primary template; only the function-pointer specialization defined below
// is ever instantiated.
template <typename Fn, Fn fn>
struct KernelImpl;

229 230 231 232 233
template <typename Return,
          typename DevCtx,
          typename... Args,
          Return (*kernel_fn)(DevCtx, Args...)>
struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
234
  static void Compute(KernelContext* ctx) {
235 236
    KernelCallHelper<DevCtx, Args..., TypeTag<int>>::
        template Compute<0, 0, 0, 0>(ctx);
237 238 239 240
  }

  static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) {
    return kernel_fn(static_cast<DevCtx>(dev_ctx), std::forward<Args>(args)...);
241 242 243 244 245 246 247 248
  }

 private:
  template <typename... RemainingArgs>
  struct KernelCallHelper;

  /* DeviceContext Helpers */

249
  PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CPUContext);
250
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
251
  PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(GPUContext);
252 253
#endif
#ifdef PADDLE_WITH_XPU
254
  PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext);
255
#endif
256
#ifdef PADDLE_WITH_CUSTOM_DEVICE
257
  PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CustomContext);
258
#endif
259 260 261
#ifdef PADDLE_WITH_MKLDNN
  PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(OneDNNContext);
#endif
262 263
  /* Input Helpers */

264 265 266 267
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
268
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(ExtendedTensor);
269
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(ExtendedTensor);
Y
YuanRisheng 已提交
270 271
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows);
272
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
273
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(DenseTensor);
274

275 276 277
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCooTensor);
278

279 280 281
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCsrTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCsrTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCsrTensor);
282

J
Jack Zhou 已提交
283 284 285 286
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(StringTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(StringTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(StringTensor);

287 288 289
  PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(TensorArray);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorArray);

290 291
  /* Attribute Helpers */

292 293 294 295 296 297 298 299 300
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(float);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place);
301 302 303 304 305 306 307 308 309 310 311
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<bool>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int64_t>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<float>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<double>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(
      std::vector<std::string>);
  PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<Scalar>);
312 313 314

  /* Output Helpers */

315 316 317
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);
318

319 320
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);
321

322 323
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCsrTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCsrTensor);
324

J
Jack Zhou 已提交
325 326
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor);
  PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor);
327

328
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray);
329
  PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(ExtendedTensor);
330

331 332 333 334
  /* End case */
  template <typename T>
  struct KernelCallHelper<TypeTag<T>> {
    template <int dev_ctx_idx, int in_idx, int attr_idx, int out_idx>
335 336 337
    static void Compute(KernelContext* ctx UNUSED,
                        DevCtx dev_ctx,
                        Args&... args) {
338 339
      static_assert(dev_ctx_idx > 0,
                    "Kernel should pass DeviceContext as argument.");
340
      return kernel_fn(dev_ctx, args...);
341 342 343 344
    }
  };
};

W
wanghuancoder 已提交
345 346 347
inline bool recompute_reduce_all(const DenseTensor& x,
                                 const IntArray& dims,
                                 bool reduce_all = false) {
zhouweiwei2014's avatar
zhouweiwei2014 已提交
348 349 350
  if (dims.size() == 0 || x.dims().size() == 0 ||
      static_cast<int>(dims.size()) == x.dims().size() || reduce_all) {
    // when input 0D, it can only reduce_all
W
wanghuancoder 已提交
351 352 353 354 355 356
    return true;
  } else {
    return false;
  }
}

}  // namespace phi