/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MathFunctions.h"
#include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh"
#include "paddle/utils/DynamicLoad.h"

namespace dynload {

std::once_flag lapack_dso_flag;
void* lapack_dso_handle = nullptr;

/**
 * The following macro generates, for each wrapped function, a struct
 * that dynamically loads the corresponding LAPACK routine and forwards
 * calls to it via operator overloading.
 *
 * Note: routines are resolved from the default dynamically linked libs.
 */
#define DYNAMIC_LOAD_LAPACK_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    auto operator()(Args... args)->decltype(__name(args...)) {                 \
      using lapack_func = decltype(__name(args...)) (*)(Args...);              \
      std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \
      void* p_##__name = dlsym(lapack_dso_handle, #__name);                    \
      return reinterpret_cast<lapack_func>(p_##__name)(args...);               \
    }                                                                          \
  } __name;  // struct DynLoad__##__name
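
// Illustrative expansion (a sketch, not compiler output): for
// DYNAMIC_LOAD_LAPACK_WRAP(LAPACKE_sgetrf) the macro produces roughly
//
//   struct DynLoad__LAPACKE_sgetrf {
//     template <typename... Args>
//     auto operator()(Args... args) -> decltype(LAPACKE_sgetrf(args...)) {
//       using lapack_func = decltype(LAPACKE_sgetrf(args...)) (*)(Args...);
//       std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle);
//       void* p_LAPACKE_sgetrf = dlsym(lapack_dso_handle, "LAPACKE_sgetrf");
//       return reinterpret_cast<lapack_func>(p_LAPACKE_sgetrf)(args...);
//     }
//   } LAPACKE_sgetrf;
//
// Call sites keep the normal LAPACKE_sgetrf(...) syntax while the symbol
// is resolved lazily, exactly once, from the LAPACK shared library.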

// clang-format off
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  #define LAPACK_ROUTINE_EACH(__macro)        \
    __macro(clapack_sgetrf)                   \
    __macro(clapack_dgetrf)                   \
    __macro(clapack_sgetri)                   \
    __macro(clapack_dgetri)
#else
  #define LAPACK_ROUTINE_EACH(__macro)        \
    __macro(LAPACKE_sgetrf)                   \
    __macro(LAPACKE_dgetrf)                   \
    __macro(LAPACKE_sgetri)                   \
    __macro(LAPACKE_dgetri)
#endif

LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)

#endif

// clang-format on
}  // namespace dynload

namespace paddle {

template <>
void gemm<float>(const CBLAS_TRANSPOSE transA,
                 const CBLAS_TRANSPOSE transB,
                 const int M,
                 const int N,
                 const int K,
                 const float alpha,
                 const float* A,
                 const int lda,
                 const float* B,
                 const int ldb,
                 const float beta,
                 float* C,
                 const int ldc) {
  cblas_sgemm(CblasRowMajor,
              transA,
              transB,
              M,
              N,
              K,
              alpha,
              A,
              lda,
              B,
              ldb,
              beta,
              C,
              ldc);
}

template <>
void gemm<double>(const CBLAS_TRANSPOSE transA,
                  const CBLAS_TRANSPOSE transB,
                  const int M,
                  const int N,
                  const int K,
                  const double alpha,
                  const double* A,
                  const int lda,
                  const double* B,
                  const int ldb,
                  const double beta,
                  double* C,
                  const int ldc) {
  cblas_dgemm(CblasRowMajor,
              transA,
              transB,
              M,
              N,
              K,
              alpha,
              A,
              lda,
              B,
              ldb,
              beta,
              C,
              ldc);
}
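
// Usage sketch (hypothetical shapes): multiply a row-major 2x3 matrix A by
// a row-major 3x2 matrix B into C, i.e. C = 1.0 * A * B + 0.0 * C; each
// leading dimension equals the row width of its row-major operand:
//
//   float A[6] = {1, 2, 3, 4, 5, 6};
//   float B[6] = {7, 8, 9, 10, 11, 12};
//   float C[4] = {0};
//   paddle::gemm<float>(CblasNoTrans, CblasNoTrans,
//                       2, 2, 3, 1.0f, A, 3, B, 2, 0.0f, C, 2);
//   // C == {58, 64, 139, 154}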

template <>
int getrf<float>(const CBLAS_ORDER order,
                 const int M,
                 const int N,
                 float* A,
                 const int lda,
                 int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return dynload::clapack_sgetrf(order, M, N, A, lda, ipiv);
#else
  return dynload::LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
#endif
#else
  LOG(FATAL) << "Not implemented";
#endif
  return 0;
}

template <>
int getrf<double>(const CBLAS_ORDER order,
                  const int M,
                  const int N,
                  double* A,
                  const int lda,
                  int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return dynload::clapack_dgetrf(order, M, N, A, lda, ipiv);
#else
  return dynload::LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
#endif
#else
  LOG(FATAL) << "Not implemented";
#endif
  return 0;
}

template <>
int getri<float>(const CBLAS_ORDER order,
                 const int N,
                 float* A,
                 const int lda,
                 const int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return dynload::clapack_sgetri(order, N, A, lda, ipiv);
#else
  return dynload::LAPACKE_sgetri(order, N, A, lda, ipiv);
#endif
#else
  LOG(FATAL) << "Not implemented";
#endif
  return 0;
}

template <>
int getri<double>(const CBLAS_ORDER order,
                  const int N,
                  double* A,
                  const int lda,
                  const int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return dynload::clapack_dgetri(order, N, A, lda, ipiv);
#else
  return dynload::LAPACKE_dgetri(order, N, A, lda, ipiv);
#endif
#else
  LOG(FATAL) << "Not implemented";
#endif
  return 0;
}
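
// Usage sketch: inverting a row-major n x n matrix A in place via the LU
// pair above (assumes PADDLE_USE_LAPACK; getrf factorizes, getri inverts):
//
//   std::vector<int> ipiv(n);
//   paddle::getrf<float>(CblasRowMajor, n, n, A, n, ipiv.data());
//   paddle::getri<float>(CblasRowMajor, n, A, n, ipiv.data());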

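// BLAS level-1 wrappers: axpy computes y = alpha * x + y and dotProduct
// returns the inner product of x and y; both assume contiguous (stride-1)
// vectors of length n.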
template <>
void axpy<float>(const int n, const float alpha, const float* x, float* y) {
  cblas_saxpy(n, alpha, x, 1, y, 1);
}

template <>
void axpy<double>(const int n, const double alpha, const double* x, double* y) {
  cblas_daxpy(n, alpha, x, 1, y, 1);
}

template <>
float dotProduct<float>(const int n, const float* x, const float* y) {
  return cblas_sdot(n, x, 1, y, 1);
}

template <>
double dotProduct<double>(const int n, const double* x, const double* y) {
  return cblas_ddot(n, x, 1, y, 1);
}

#ifdef PADDLE_USE_MKL
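// With MKL available, the element-wise vector functions delegate to MKL's
// VML routines (vsExp/vdExp, vsPowx/vdPowx, ...); otherwise the generic
// hl_cpu_apply_* implementations after the #else are used.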

template <>
void vExp<float>(const int n, const float* a, float* r) {
  vsExp(n, a, r);
}

template <>
void vExp<double>(const int n, const double* a, double* r) {
  vdExp(n, a, r);
}

template <>
void vPow<float>(const int n, const float* a, const float b, float* r) {
  vsPowx(n, a, b, r);
}

template <>
void vPow<double>(const int n, const double* a, const double b, double* r) {
  vdPowx(n, a, b, r);
}

template <>
void vLog<float>(const int n, const float* a, float* r) {
  vsLn(n, a, r);
}

template <>
void vLog<double>(const int n, const double* a, double* r) {
  vdLn(n, a, r);
}

template <>
void vAdd<float>(const int n, const float* a, const float* b, float* r) {
  vsAdd(n, a, b, r);
}

template <>
void vAdd<double>(const int n, const double* a, const double* b, double* r) {
  vdAdd(n, a, b, r);
}

template <>
void vInvSqrt<float>(const int n, const float* a, float* r) {
  vsInvSqrt(n, a, r);
}

template <>
void vInvSqrt<double>(const int n, const double* a, double* r) {
  vdInvSqrt(n, a, r);
}

template <>
void vLog1p<float>(const int n, const float* a, float* r) {
  vsLog1p(n, a, r);
}

template <>
void vLog1p<double>(const int n, const double* a, double* r) {
  vdLog1p(n, a, r);
}

template <>
void vTanh<float>(const int n, const float* a, float* r) {
  vsTanh(n, a, r);
}

template <>
void vTanh<double>(const int n, const double* a, double* r) {
  vdTanh(n, a, r);
}
#else

DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a));
template <class T>
void vExp(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vExp<T>, 0, 0>(
      binary::vExp<T>(), const_cast<T*>(a), r, 1, n, n, n);
}
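
// Note: the hl_cpu_apply_* helpers are matrix kernels; each length-n vector
// is passed as a 1 x n matrix (height 1, width n, leading dimension n) so
// the same kernels serve flat arrays. The same pattern is used below.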

DEFINE_MATRIX_BINARY_OP(vLog, b = std::log(a));
template <class T>
void vLog(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vLog<T>, 0, 0>(
      binary::vLog<T>(), const_cast<T*>(a), r, 1, n, n, n);
}

DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
template <class T>
void vInvSqrt(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vInvSqrt<T>, 0, 0>(
      binary::vInvSqrt<T>(), const_cast<T*>(a), r, 1, n, n, n);
}

DEFINE_MATRIX_BINARY_OP(vLog1p, b = std::log(1.0f + a));
template <class T>
void vLog1p(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vLog1p<T>, 0, 0>(
      binary::vLog1p<T>(), const_cast<T*>(a), r, 1, n, n, n);
}

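// tanh(a) is evaluated as 2 / (1 + exp(-2a)) - 1; clamping -2a at
// EXP_MAX_INPUT keeps std::exp from overflowing for large negative a.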
DEFINE_MATRIX_BINARY_OP(vTanh, T tmp = -2.0 * a;
                        tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
                        b = 2.0 / (1.0 + std::exp(tmp)) - 1.0);
template <class T>
void vTanh(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vTanh<T>, 0, 0>(
      binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
}

DEFINE_MATRIX_BINARY_PARAMETER_OP(vPow, ONE_PARAMETER, b = std::pow(a, p));
template <class T>
void vPow(const int n, const T* a, const T b, T* r) {
  hl_cpu_apply_binary_op<T, binary::vPow<T>, 0, 0>(
      binary::vPow<T>(b), const_cast<T*>(a), r, 1, n, n, n);
}

DEFINE_MATRIX_TERNARY_OP(vAdd, c = a + b);
template <class T>
void vAdd(const int n, const T* a, const T* b, T* r) {
  hl_cpu_apply_ternary_op<T, ternary::vAdd<T>, 0, 0>(ternary::vAdd<T>(),
                                                     const_cast<T*>(a),
                                                     const_cast<T*>(b),
                                                     r,
                                                     1,
                                                     n,
                                                     n,
                                                     n,
                                                     n);
}

template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const float* a, float* r);
template void vLog1p(const int n, const float* a, float* r);
template void vLog1p(const int n, const double* a, double* r);
template void vTanh(const int n, const float* a, float* r);
template void vTanh(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);

#endif

}  // namespace paddle