hl_tensor_ops.h 10.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
H
hedaoyuan 已提交
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
H
hedaoyuan 已提交
14 15 16 17 18 19 20 21 22 23

#ifndef HL_TENSOR_OPS_H_
#define HL_TENSOR_OPS_H_

#include <cmath>
#include "hl_matrix_type.cuh"

namespace hppl {
namespace unary {

H
hedaoyuan 已提交
24 25
template <class T>
class add_scale {
W
Wu Yi 已提交
26
 private:
H
hedaoyuan 已提交
27
  const T p;
H
hedaoyuan 已提交
28

W
Wu Yi 已提交
29
 public:
H
hedaoyuan 已提交
30 31 32 33
  INLINE add_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a + p; }
};

H
hedaoyuan 已提交
34
template <class T>
H
hedaoyuan 已提交
35
class sub_scale {
W
Wu Yi 已提交
36
 private:
H
hedaoyuan 已提交
37
  const T p;
H
hedaoyuan 已提交
38

W
Wu Yi 已提交
39
 public:
H
hedaoyuan 已提交
40 41 42 43
  INLINE sub_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a - p; }
};

H
hedaoyuan 已提交
44
template <class T>
H
hedaoyuan 已提交
45
class mul_scale {
W
Wu Yi 已提交
46
 private:
H
hedaoyuan 已提交
47
  const T p;
H
hedaoyuan 已提交
48

W
Wu Yi 已提交
49
 public:
H
hedaoyuan 已提交
50 51 52 53
  INLINE mul_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a * p; }
};

H
hedaoyuan 已提交
54
template <class T>
H
hedaoyuan 已提交
55
class div_scale {
W
Wu Yi 已提交
56
 private:
H
hedaoyuan 已提交
57
  const T p;
H
hedaoyuan 已提交
58

W
Wu Yi 已提交
59
 public:
H
hedaoyuan 已提交
60 61 62 63
  INLINE div_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a / p; }
};

H
hedaoyuan 已提交
64
template <class T>
H
hedaoyuan 已提交
65
class neg {
W
Wu Yi 已提交
66
 public:
H
hedaoyuan 已提交
67 68 69
  INLINE T operator()(const T a) const { return -a; }
};

H
hedaoyuan 已提交
70
template <class T>
H
hedaoyuan 已提交
71
class exp_op {
W
Wu Yi 已提交
72
 public:
H
hedaoyuan 已提交
73 74 75
  INLINE T operator()(const T a) const { return std::exp(a); }
};

H
hedaoyuan 已提交
76
template <class T>
H
hedaoyuan 已提交
77
class log_op {
W
Wu Yi 已提交
78
 public:
H
hedaoyuan 已提交
79 80 81
  INLINE T operator()(const T a) const { return std::log(a); }
};

H
hedaoyuan 已提交
82
template <class T>
H
hedaoyuan 已提交
83
class sqrt_op {
W
Wu Yi 已提交
84
 public:
H
hedaoyuan 已提交
85 86 87
  INLINE T operator()(const T a) const { return std::sqrt(a); }
};

H
hedaoyuan 已提交
88
template <class T>
H
hedaoyuan 已提交
89
class square {
W
Wu Yi 已提交
90
 public:
H
hedaoyuan 已提交
91 92 93
  INLINE T operator()(const T a) const { return a * a; }
};

H
hedaoyuan 已提交
94
template <class T>
H
hedaoyuan 已提交
95
class reciprocal {
W
Wu Yi 已提交
96
 public:
H
hedaoyuan 已提交
97 98 99
  INLINE T operator()(const T a) const { return T(1) / a; }
};

H
hedaoyuan 已提交
100
template <class T>
H
hedaoyuan 已提交
101
class abs {
W
Wu Yi 已提交
102
 public:
H
hedaoyuan 已提交
103 104 105
  INLINE T operator()(const T a) const { return a > 0 ? a : -a; }
};

H
hedaoyuan 已提交
106
template <class T>
H
hedaoyuan 已提交
107
class sign {
W
Wu Yi 已提交
108
 public:
H
hedaoyuan 已提交
109 110 111
  INLINE T operator()(const T a) const { return (a > 0) - (a < 0); }
};

H
hedaoyuan 已提交
112
template <class T>
H
hedaoyuan 已提交
113
class min {
W
Wu Yi 已提交
114
 private:
H
hedaoyuan 已提交
115
  const T p;
H
hedaoyuan 已提交
116

W
Wu Yi 已提交
117
 public:
H
hedaoyuan 已提交
118 119 120 121
  INLINE min(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a > p ? p : a; }
};

H
hedaoyuan 已提交
122
template <class T>
H
hedaoyuan 已提交
123
class max {
W
Wu Yi 已提交
124
 private:
H
hedaoyuan 已提交
125
  const T p;
H
hedaoyuan 已提交
126

W
Wu Yi 已提交
127
 public:
H
hedaoyuan 已提交
128 129 130 131
  INLINE max(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a < p ? p : a; }
};

H
hedaoyuan 已提交
132
template <class T>
H
hedaoyuan 已提交
133
class pow_op {
W
Wu Yi 已提交
134
 private:
H
hedaoyuan 已提交
135
  const T p;
H
hedaoyuan 已提交
136

W
Wu Yi 已提交
137
 public:
H
hedaoyuan 已提交
138 139 140 141
  INLINE pow_op(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return std::pow(a, p); }
};

H
hedaoyuan 已提交
142
template <class T>
H
hedaoyuan 已提交
143
class constant {
W
Wu Yi 已提交
144
 private:
H
hedaoyuan 已提交
145
  const T p;
H
hedaoyuan 已提交
146

W
Wu Yi 已提交
147
 public:
H
hedaoyuan 已提交
148 149 150 151 152
  INLINE constant(const T s) : p(s) {}
  INLINE T operator()(int i) const { return p; }
  INLINE T operator()(int i, int j) const { return p; }
};

H
hedaoyuan 已提交
153
template <class T>
H
hedaoyuan 已提交
154
class cmp_eq {
W
Wu Yi 已提交
155
 private:
H
hedaoyuan 已提交
156
  const T p;
H
hedaoyuan 已提交
157

W
Wu Yi 已提交
158
 public:
H
hedaoyuan 已提交
159 160 161 162
  INLINE cmp_eq(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a == p; }
};

H
hedaoyuan 已提交
163
template <class T>
H
hedaoyuan 已提交
164
class cmp_ne {
W
Wu Yi 已提交
165
 private:
H
hedaoyuan 已提交
166
  const T p;
H
hedaoyuan 已提交
167

W
Wu Yi 已提交
168
 public:
H
hedaoyuan 已提交
169 170 171 172
  INLINE cmp_ne(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a != p; }
};

H
hedaoyuan 已提交
173
template <class T>
H
hedaoyuan 已提交
174
class cmp_le {
W
Wu Yi 已提交
175
 private:
H
hedaoyuan 已提交
176
  const T p;
H
hedaoyuan 已提交
177

W
Wu Yi 已提交
178
 public:
H
hedaoyuan 已提交
179 180 181 182
  INLINE cmp_le(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a <= p; }
};

H
hedaoyuan 已提交
183
template <class T>
H
hedaoyuan 已提交
184
class cmp_lt {
W
Wu Yi 已提交
185
 private:
H
hedaoyuan 已提交
186
  const T p;
H
hedaoyuan 已提交
187

W
Wu Yi 已提交
188
 public:
H
hedaoyuan 已提交
189 190 191 192
  INLINE cmp_lt(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a < p; }
};

H
hedaoyuan 已提交
193
template <class T>
H
hedaoyuan 已提交
194
class cmp_ge {
W
Wu Yi 已提交
195
 private:
H
hedaoyuan 已提交
196
  const T p;
H
hedaoyuan 已提交
197

W
Wu Yi 已提交
198
 public:
H
hedaoyuan 已提交
199 200 201 202
  INLINE cmp_ge(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a >= p; }
};

H
hedaoyuan 已提交
203
template <class T>
H
hedaoyuan 已提交
204
class cmp_gt {
W
Wu Yi 已提交
205
 private:
H
hedaoyuan 已提交
206
  const T p;
H
hedaoyuan 已提交
207

W
Wu Yi 已提交
208
 public:
H
hedaoyuan 已提交
209 210 211 212
  INLINE cmp_gt(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a > p; }
};

H
hedaoyuan 已提交
213
template <class T>
H
hedaoyuan 已提交
214
class and_op {
W
Wu Yi 已提交
215
 private:
H
hedaoyuan 已提交
216
  const T p;
H
hedaoyuan 已提交
217

W
Wu Yi 已提交
218
 public:
H
hedaoyuan 已提交
219 220 221 222
  INLINE and_op(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a && p; }
};

H
hedaoyuan 已提交
223
template <class T>
H
hedaoyuan 已提交
224
class or_op {
W
Wu Yi 已提交
225
 private:
H
hedaoyuan 已提交
226
  const T p;
H
hedaoyuan 已提交
227

W
Wu Yi 已提交
228
 public:
H
hedaoyuan 已提交
229 230 231 232 233 234 235
  INLINE or_op(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a || p; }
};

}  // namespace unary

namespace binary {
H
hedaoyuan 已提交
236
template <class T>
H
hedaoyuan 已提交
237
class add {
W
Wu Yi 已提交
238
 public:
H
hedaoyuan 已提交
239 240 241
  INLINE T operator()(const T a, const T b) const { return a + b; }
};

H
hedaoyuan 已提交
242
template <class T>
H
hedaoyuan 已提交
243
class add_scale {
W
Wu Yi 已提交
244
 private:
H
hedaoyuan 已提交
245 246
  const T p1;
  const T p2;
H
hedaoyuan 已提交
247

W
Wu Yi 已提交
248
 public:
H
hedaoyuan 已提交
249
  INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {}
H
hedaoyuan 已提交
250
  INLINE T operator()(const T a, const T b) const { return p1 * a + p2 * b; }
H
hedaoyuan 已提交
251 252
};

H
hedaoyuan 已提交
253
template <class T>
H
hedaoyuan 已提交
254
class sub {
W
Wu Yi 已提交
255
 public:
H
hedaoyuan 已提交
256 257 258
  INLINE T operator()(const T a, const T b) const { return a - b; }
};

H
hedaoyuan 已提交
259
template <class T>
H
hedaoyuan 已提交
260
class mul {
W
Wu Yi 已提交
261
 public:
H
hedaoyuan 已提交
262 263 264
  INLINE T operator()(const T a, const T b) const { return a * b; }
};

H
hedaoyuan 已提交
265
template <class T>
H
hedaoyuan 已提交
266
class div {
W
Wu Yi 已提交
267
 public:
H
hedaoyuan 已提交
268
  INLINE T operator()(const T a, const T b) const { return a / b; }
H
hedaoyuan 已提交
269 270
};

H
hedaoyuan 已提交
271
template <class T>
H
hedaoyuan 已提交
272
class cmp_eq {
W
Wu Yi 已提交
273
 public:
H
hedaoyuan 已提交
274 275 276
  INLINE bool operator()(const T a, const T b) const { return a == b; }
};

H
hedaoyuan 已提交
277
template <class T>
H
hedaoyuan 已提交
278
class cmp_ne {
W
Wu Yi 已提交
279
 public:
H
hedaoyuan 已提交
280 281 282
  INLINE bool operator()(const T a, const T b) const { return a != b; }
};

H
hedaoyuan 已提交
283
template <class T>
H
hedaoyuan 已提交
284
class cmp_le {
W
Wu Yi 已提交
285
 public:
H
hedaoyuan 已提交
286 287 288
  INLINE bool operator()(const T a, const T b) const { return a <= b; }
};

H
hedaoyuan 已提交
289
template <class T>
H
hedaoyuan 已提交
290
class cmp_lt {
W
Wu Yi 已提交
291
 public:
H
hedaoyuan 已提交
292 293 294
  INLINE bool operator()(const T a, const T b) const { return a < b; }
};

H
hedaoyuan 已提交
295
template <class T>
H
hedaoyuan 已提交
296
class cmp_ge {
W
Wu Yi 已提交
297
 public:
H
hedaoyuan 已提交
298 299 300
  INLINE bool operator()(const T a, const T b) const { return a >= b; }
};

H
hedaoyuan 已提交
301
template <class T>
H
hedaoyuan 已提交
302
class cmp_gt {
W
Wu Yi 已提交
303
 public:
H
hedaoyuan 已提交
304 305 306
  INLINE bool operator()(const T a, const T b) const { return a > b; }
};

H
hedaoyuan 已提交
307
template <class T>
H
hedaoyuan 已提交
308
class and_op {
W
Wu Yi 已提交
309
 public:
H
hedaoyuan 已提交
310 311 312
  INLINE bool operator()(const T a, const T b) const { return a && b; }
};

H
hedaoyuan 已提交
313
template <class T>
H
hedaoyuan 已提交
314
class or_op {
W
Wu Yi 已提交
315
 public:
H
hedaoyuan 已提交
316 317 318
  INLINE bool operator()(const T a, const T b) const { return a || b; }
};

H
hedaoyuan 已提交
319
template <class T>
H
hedaoyuan 已提交
320
class min {
W
Wu Yi 已提交
321
 public:
H
hedaoyuan 已提交
322 323 324
  INLINE T operator()(const T a, const T b) const { return a > b ? b : a; }
};

H
hedaoyuan 已提交
325
template <class T>
H
hedaoyuan 已提交
326
class max {
W
Wu Yi 已提交
327
 public:
H
hedaoyuan 已提交
328 329 330
  INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
};

331 332 333 334
#ifdef PADDLE_USE_SSE3
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<__m128> {
W
Wu Yi 已提交
335
 public:
336 337 338 339 340 341 342
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_add_ps(a, b);
  }
};

template <>
class add_scale<__m128> {
W
Wu Yi 已提交
343
 private:
344 345 346
  const __m128 p1;
  const __m128 p2;

W
Wu Yi 已提交
347
 public:
348 349 350 351 352 353 354 355
  INLINE add_scale(const __m128 s1, const __m128 s2) : p1(s1), p2(s2) {}
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_add_ps(_mm_mul_ps(p1, a), _mm_mul_ps(p2, b));
  }
};

template <>
class sub<__m128> {
W
Wu Yi 已提交
356
 public:
357 358 359 360 361 362 363
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_sub_ps(a, b);
  }
};

template <>
class mul<__m128> {
W
Wu Yi 已提交
364
 public:
365 366 367 368 369 370 371
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_mul_ps(a, b);
  }
};

template <>
class div<__m128> {
W
Wu Yi 已提交
372
 public:
373 374 375 376 377 378 379
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_div_ps(a, b);
  }
};

template <>
class min<__m128> {
W
Wu Yi 已提交
380
 public:
381 382 383 384 385 386 387
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_min_ps(a, b);
  }
};

template <>
class max<__m128> {
W
Wu Yi 已提交
388
 public:
389 390 391 392 393 394 395
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_max_ps(a, b);
  }
};
#else
template <>
class add<__m128d> {
W
Wu Yi 已提交
396
 public:
397 398 399 400 401 402 403
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_add_pd(a, b);
  }
};

template <>
class add_scale<__m128d> {
W
Wu Yi 已提交
404
 private:
405 406 407
  const __m128d p1;
  const __m128d p2;

W
Wu Yi 已提交
408
 public:
409 410 411 412 413 414 415 416
  INLINE add_scale(const __m128d s1, const __m128d s2) : p1(s1), p2(s2) {}
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_add_pd(_mm_mul_pd(p1, a), _mm_mul_pd(p2, b));
  }
};

template <>
class sub<__m128d> {
W
Wu Yi 已提交
417
 public:
418 419 420 421 422 423 424
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_sub_pd(a, b);
  }
};

template <>
class mul<__m128d> {
W
Wu Yi 已提交
425
 public:
426 427 428 429 430 431 432
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_mul_pd(a, b);
  }
};

template <>
class div<__m128d> {
W
Wu Yi 已提交
433
 public:
434 435 436 437 438 439 440
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_div_pd(a, b);
  }
};

template <>
class min<__m128d> {
W
Wu Yi 已提交
441
 public:
442 443 444 445 446 447 448
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_min_pd(a, b);
  }
};

template <>
class max<__m128d> {
W
Wu Yi 已提交
449
 public:
450 451 452 453 454 455 456 457 458 459 460
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_max_pd(a, b);
  }
};
#endif  // PADDLE_TYPE_DOUBLE
#endif  // PADDLE_USE_SSE3

#ifdef PADDLE_USE_NEON
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<float32x4_t> {
W
Wu Yi 已提交
461
 public:
462 463
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
H
hedaoyuan 已提交
464
    return vaddq_f32(a, b);
465 466 467 468 469
  }
};

template <>
class add_scale<float32x4_t> {
W
Wu Yi 已提交
470
 private:
471 472 473
  const float32x4_t p1;
  const float32x4_t p2;

W
Wu Yi 已提交
474
 public:
475 476 477 478 479 480 481 482 483 484
  INLINE add_scale(const float32x4_t s1, const float32x4_t s2)
      : p1(s1), p2(s2) {}
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vaddq_f32(vmulq_f32(p1, a), vmulq_f32(p2, b));
  }
};

template <>
class sub<float32x4_t> {
W
Wu Yi 已提交
485
 public:
486 487 488 489 490 491 492 493
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vsubq_f32(a, b);
  }
};

template <>
class mul<float32x4_t> {
W
Wu Yi 已提交
494
 public:
495 496 497 498 499 500 501 502
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vmulq_f32(a, b);
  }
};

template <>
class div<float32x4_t> {
W
Wu Yi 已提交
503
 public:
504 505 506 507 508 509 510 511 512
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    float32x4_t tmp = vrecpeq_f32(b);
    return vmulq_f32(a, tmp);
  }
};

template <>
class min<float32x4_t> {
W
Wu Yi 已提交
513
 public:
514 515 516 517 518 519 520 521
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vminq_f32(a, b);
  }
};

template <>
class max<float32x4_t> {
W
Wu Yi 已提交
522
 public:
523 524 525 526
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vmaxq_f32(a, b);
  }
L
Liu Yiqun 已提交
527
};
528 529 530 531 532
#else
#error To be implemented
#endif  // PADDLE_TYPE_DOUBLE
#endif  // PADDLE_USE_NEON

H
hedaoyuan 已提交
533 534 535 536
}  // namespace binary
}  // namespace hppl

#endif  // HL_TENSOR_OPS_H_