/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include <cstdint>
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
namespace jit {

typedef enum {
  kNone = 0,
  // sorted alphabetically
  kAdam = 1,
  kCRFDecoding,
  kEmbSeqPool,
  kGRUH1,
  kGRUHtPart1,
  kGRUHtPart2,
  kHMax,  // horizontal max
  kHSum,  // horizontal sum
  kLSTMCtHt,
  kLSTMC1H1,
  kLayerNorm,
  kMatMul,
  kNCHW16CMulNC,
  kSeqPool,
  kSgd,
  kSoftmax,
  kStrideASum,
  kStrideScal,
  kVAdd,
  kVAddBias,
  kVAddRelu,
  kVBroadcast,
  kVCopy,
  kVExp,
  kVIdentity,
  kVMul,
  kVRelu,
  kVScal,
  kVSigmoid,
  kVSquare,
  kVSub,
  kVTanh,
} KernelType;

typedef enum {
  kNonePoolType = 0,
  kSum = 1,
  kAvg,
  kSqrt,
} SeqPoolType;

// x, y, z, n
template <typename T>
struct XYZNTuple {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, const T*, T*, int);
};
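
// For illustration: a minimal reference kernel matching
// XYZNTuple<float>::func_type for kVAdd (the name vadd_ref is hypothetical):
//   void vadd_ref(const float* x, const float* y, float* z, int n) {
//     for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
//   }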

// a, x, y, n
template <typename T>
struct AXYNTuple : public XYZNTuple<T> {};

// a, x, y, n, stride
template <typename T>
struct AXYNSTuple {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, const T*, T*, int, int);
};

// x, y, n
template <typename T>
struct XYNTuple {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, T*, int);
};

// x, returned value, n
template <typename T>
struct XRNTuple : public XYNTuple<T> {};

// x, returned value, n, stride
template <typename T>
struct XRNSTuple {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, T*, int, int);
};

#define DECLARE_KERNELTUPLE(kernel_tuple, type)        \
  template <typename T>                                \
  struct type##Tuple : public kernel_tuple<T> {        \
    static constexpr KernelType kernel_type = k##type; \
  }

// Each Tuple should correspond to its KernelType.
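// For example, DECLARE_KERNELTUPLE(XYZNTuple, VMul) expands to:
//   template <typename T>
//   struct VMulTuple : public XYZNTuple<T> {
//     static constexpr KernelType kernel_type = kVMul;
//   };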
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);

DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);

DECLARE_KERNELTUPLE(AXYNSTuple, StrideScal);

DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);

DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);

DECLARE_KERNELTUPLE(XRNSTuple, StrideASum);

typedef struct {
  void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
  const void* ct_1;
  void* ct;
  void* ht;
  /* weight_peephole and checked data are only used when peephole is enabled */
  const void* wp{nullptr};  // W_ic, W_fc, W_oc
  void* checked{nullptr};   // size: 2 * d
} lstm_t;

typedef struct {
  void* gates;  // gates: {x_update, x_reset; x_state}
  const void* ht_1;
  void* ht;
} gru_t;

struct rnn_attr_s {
  int d;
  KernelType act_gate, act_cand;
  rnn_attr_s() = default;
  explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};

struct lstm_attr_s : public rnn_attr_s {
  bool use_peephole;
  KernelType act_cell;
  lstm_attr_s() = default;
  explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
                       KernelType _act_cell, bool _use_peephole = false)
      : rnn_attr_s(_d, _act_gate, _act_cand),
        use_peephole(_use_peephole),
        act_cell(_act_cell) {}
};

typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
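
// For example, an LSTM cell of width d with sigmoid gates and tanh
// candidate/cell activations (a typical configuration) would use:
//   lstm_attr_t attr(d, kVSigmoid, kVTanh, kVTanh, /*use_peephole=*/false);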

template <typename T>
struct LSTMTuple {
  typedef T data_type;
  typedef lstm_attr_t attr_type;
  typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};

template <typename T>
struct GRUTuple {
  typedef T data_type;
  typedef gru_attr_t attr_type;
  typedef void (*func_type)(gru_t*, const gru_attr_t*);
};

DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);

DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);

#undef DECLARE_KERNELTUPLE

template <typename T>
struct VBroadcastTuple {
  static constexpr KernelType kernel_type = kVBroadcast;
  typedef T data_type;
  typedef int64_t attr_type;
  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
};
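
// The two trailing int64_t arguments are the output height and the row
// width; a reference implementation would copy the w-wide input row h times:
//   for (int64_t i = 0; i < h; ++i) std::memcpy(y + i * w, x, w * sizeof(T));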

typedef struct seq_pool_attr_s {
  int h, w;  // h must remain the first member
  SeqPoolType type;
  seq_pool_attr_s() = default;
  explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
      : h(height), w(width), type(pool_type) {}
} seq_pool_attr_t;
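
// For example, assuming the kernel reduces h rows of width w to one row,
// a sum-pool attribute would be built as:
//   seq_pool_attr_t attr(w, SeqPoolType::kSum, h);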

template <typename T>
struct SeqPoolTuple {
  static constexpr KernelType kernel_type = kSeqPool;
  typedef T data_type;
  typedef seq_pool_attr_t attr_type;
  typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};

typedef struct emb_seq_pool_attr_s {
  int64_t table_height, table_width;
  int64_t index_height, index_width;
  int64_t out_width;
  SeqPoolType pool_type;
  emb_seq_pool_attr_s() = default;
  explicit emb_seq_pool_attr_s(int64_t tbl_height, int64_t tbl_width,
                               int64_t idx_height, int64_t idx_width,
                               int64_t output_width,
                               SeqPoolType seqpool_type = SeqPoolType::kSum)
      : table_height(tbl_height),
        table_width(tbl_width),
        index_height(idx_height),
        index_width(idx_width),
        out_width(output_width),
        pool_type(seqpool_type) {}
} emb_seq_pool_attr_t;

template <typename T>
struct EmbSeqPoolTuple {
  static constexpr KernelType kernel_type = kEmbSeqPool;
  typedef T data_type;
  typedef emb_seq_pool_attr_t attr_type;
  typedef void (*func_type)(const T*, const int64_t*, T*,
                            const emb_seq_pool_attr_t*);
};

typedef struct sgd_attr_s {
  int64_t param_height, param_width;
  int64_t grad_height, grad_width;
  int64_t selected_rows_size;
  sgd_attr_s() = default;
  explicit sgd_attr_s(int64_t param_h, int64_t param_w, int64_t grad_h,
                      int64_t grad_w, int64_t selected_rows_sz)
      : param_height(param_h),
        param_width(param_w),
        grad_height(grad_h),
        grad_width(grad_w),
        selected_rows_size(selected_rows_sz) {}
} sgd_attr_t;

template <typename T>
struct SgdTuple {
  static constexpr KernelType kernel_type = kSgd;
  typedef T data_type;
  typedef sgd_attr_t attr_type;
  typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
                            const sgd_attr_t*);
};

typedef struct adam_attr_s {
  float beta1, beta2;
  adam_attr_s() = default;
  explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {}
} adam_attr_t;

template <typename T>
struct AdamTuple {
  static constexpr KernelType kernel_type = kAdam;
  typedef T data_type;
  typedef adam_attr_t attr_type;
  typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*,
                            const T*, T*, T*, T*);
};

typedef struct matmul_attr_s {
  int m, n, k;
  void* packed_weight{nullptr};
  matmul_attr_s() = default;
  explicit matmul_attr_s(int m_, int n_, int k_, void* packed_weight_ = nullptr)
      : m(m_), n(n_), k(k_), packed_weight(packed_weight_) {}
} matmul_attr_t;
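
// For example, for C(m x n) = A(m x k) * B(k x n) with no pre-packed weight:
//   matmul_attr_t attr(m, n, k);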

template <typename T>
struct MatMulTuple {
  static constexpr KernelType kernel_type = kMatMul;
  typedef T data_type;
  typedef matmul_attr_t attr_type;
  typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
};

template <typename T>
struct CRFDecodingTuple {
  static constexpr KernelType kernel_type = kCRFDecoding;
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};

template <typename T>
struct LayerNormTuple {
  static constexpr KernelType kernel_type = kLayerNorm;
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
                            const float, int);
};

template <typename T>
struct SoftmaxTuple {
  static constexpr KernelType kernel_type = kSoftmax;
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, T*, int, int, int);
};

// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuple {
  static constexpr KernelType kernel_type = kNCHW16CMulNC;
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, const T*, T*, int, int);
};

// Non-template base class so kernels can be added to the kernel pool.
class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;
  virtual const char* ImplType() const = 0;
  DISABLE_COPY_AND_ASSIGN(Kernel);
};

template <typename KernelTuple>
class KernelMore : public Kernel {
 public:
  using T = typename KernelTuple::data_type;
  using Func = typename KernelTuple::func_type;
  using Attr = typename KernelTuple::attr_type;
  virtual Func GetFunc() const { return func; }
  // Reports whether this kernel can be used; if it returns true, using the
  // kernel must not fail.
  virtual bool CanBeUsed(const Attr& attr) const = 0;

 protected:
  Func func{nullptr};
};

template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuple> {
 public:
  // Refer code can always be used
  bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
    return true;
  }
  const char* ImplType() const override { return "Refer"; }
};
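
// A hypothetical backend kernel derived from KernelMore would set `func` in
// its constructor and gate availability through CanBeUsed(), e.g.:
//   template <typename T>
//   class MyVAddKernel : public KernelMore<VAddTuple<T>> {
//    public:
//     MyVAddKernel() { this->func = /* jit-generated or intrinsic impl */; }
//     bool CanBeUsed(const int& attr) const override { return attr > 0; }
//     const char* ImplType() const override { return "MyImpl"; }
//   };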

}  // namespace jit
}  // namespace operators
}  // namespace paddle