MathFunctions.cpp 8.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MathFunctions.h"
#include "hl_matrix_apply.cuh"
Y
Yu Yang 已提交
17
#include "hl_matrix_ops.cuh"
Z
zhangjinchao01 已提交
18 19 20

namespace paddle {

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
// Single-precision specialization of gemm. Every argument is handed to
// cblas_sgemm unchanged; the storage order is always CblasRowMajor.
template <>
void gemm<float>(const CBLAS_TRANSPOSE transA,
                 const CBLAS_TRANSPOSE transB,
                 const int M,
                 const int N,
                 const int K,
                 const float alpha,
                 const float* A,
                 const int lda,
                 const float* B,
                 const int ldb,
                 const float beta,
                 float* C,
                 const int ldc) {
  // Arguments grouped as: layout/transposes, dimensions, then each operand
  // together with its scalar and leading dimension.
  cblas_sgemm(CblasRowMajor, transA, transB,
              M, N, K,
              alpha, A, lda,
              B, ldb,
              beta, C, ldc);
}

// Double-precision specialization of gemm. Every argument is handed to
// cblas_dgemm unchanged; the storage order is always CblasRowMajor.
template <>
void gemm<double>(const CBLAS_TRANSPOSE transA,
                  const CBLAS_TRANSPOSE transB,
                  const int M,
                  const int N,
                  const int K,
                  const double alpha,
                  const double* A,
                  const int lda,
                  const double* B,
                  const int ldb,
                  const double beta,
                  double* C,
                  const int ldc) {
  // Arguments grouped as: layout/transposes, dimensions, then each operand
  // together with its scalar and leading dimension.
  cblas_dgemm(CblasRowMajor, transA, transB,
              M, N, K,
              alpha, A, lda,
              B, ldb,
              beta, C, ldc);
}

// Single-precision specialization of getrf.
//
// Build-time dispatch:
//   - PADDLE_USE_LAPACK and PADDLE_USE_ATLAS: returns clapack_sgetrf(...).
//   - PADDLE_USE_LAPACK only:                 returns LAPACKE_sgetrf(...).
//   - otherwise: logs a FATAL "Not implemented" message.
template <>
int getrf<float>(const CBLAS_ORDER order,
                 const int M,
                 const int N,
                 float* A,
                 const int lda,
                 int* ipiv) {
#if defined(PADDLE_USE_LAPACK) && defined(PADDLE_USE_ATLAS)
  return clapack_sgetrf(order, M, N, A, lda, ipiv);
#elif defined(PADDLE_USE_LAPACK)
  return LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
#else
  LOG(FATAL) << "Not implemented";
#endif
  // Only reached textually after the LOG(FATAL) branch; keeps a return
  // statement on every preprocessor path.
  return 0;
}

100 101 102 103 104 105 106
// Double-precision specialization of getrf.
//
// Build-time dispatch:
//   - PADDLE_USE_LAPACK and PADDLE_USE_ATLAS: returns clapack_dgetrf(...).
//   - PADDLE_USE_LAPACK only:                 returns LAPACKE_dgetrf(...).
//   - otherwise: logs a FATAL "Not implemented" message.
//
// Returns the status code of the underlying routine when LAPACK is enabled.
template <>
int getrf<double>(const CBLAS_ORDER order,
                  const int M,
                  const int N,
                  double* A,
                  const int lda,
                  int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return clapack_dgetrf(order, M, N, A, lda, ipiv);
#else
  return LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
#endif  // PADDLE_USE_ATLAS
#else
  // Statement must end with ';' (it previously ended with '.', which broke
  // builds without PADDLE_USE_LAPACK).
  LOG(FATAL) << "Not implemented";
#endif  // PADDLE_USE_LAPACK
  // Keeps a return statement on the non-LAPACK path.
  return 0;
}

119 120 121 122 123 124
// Single-precision specialization of getri.
//
// Build-time dispatch:
//   - PADDLE_USE_LAPACK and PADDLE_USE_ATLAS: returns clapack_sgetri(...).
//   - PADDLE_USE_LAPACK only:                 returns LAPACKE_sgetri(...).
//   - otherwise: logs a FATAL "Not implemented" message.
//
// Returns the status code of the underlying routine when LAPACK is enabled.
template <>
int getri<float>(const CBLAS_ORDER order,
                 const int N,
                 float* A,
                 const int lda,
                 const int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return clapack_sgetri(order, N, A, lda, ipiv);
#else
  return LAPACKE_sgetri(order, N, A, lda, ipiv);
#endif  // PADDLE_USE_ATLAS
#else
  // Statement must end with ';' (it previously ended with '.', which broke
  // builds without PADDLE_USE_LAPACK).
  LOG(FATAL) << "Not implemented";
#endif  // PADDLE_USE_LAPACK
  // Keeps a return statement on the non-LAPACK path.
  return 0;
}

137 138 139 140 141 142
// Double-precision specialization of getri.
//
// Build-time dispatch:
//   - PADDLE_USE_LAPACK and PADDLE_USE_ATLAS: returns clapack_dgetri(...).
//   - PADDLE_USE_LAPACK only:                 returns LAPACKE_dgetri(...).
//   - otherwise: logs a FATAL "Not implemented" message.
//
// Returns the status code of the underlying routine when LAPACK is enabled.
template <>
int getri<double>(const CBLAS_ORDER order,
                  const int N,
                  double* A,
                  const int lda,
                  const int* ipiv) {
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
  return clapack_dgetri(order, N, A, lda, ipiv);
#else
  return LAPACKE_dgetri(order, N, A, lda, ipiv);
#endif  // PADDLE_USE_ATLAS
#else
  // Statement must end with ';' (it previously ended with '.', which broke
  // builds without PADDLE_USE_LAPACK).
  LOG(FATAL) << "Not implemented";
#endif  // PADDLE_USE_LAPACK
  // Keeps a return statement on the non-LAPACK path.
  return 0;
}

155
template <>
Z
zhangjinchao01 已提交
156 157 158 159
void axpy<float>(const int n, const float alpha, const float* x, float* y) {
  cblas_saxpy(n, alpha, x, 1, y, 1);
}

160
template <>
Z
zhangjinchao01 已提交
161 162 163 164
void axpy<double>(const int n, const double alpha, const double* x, double* y) {
  cblas_daxpy(n, alpha, x, 1, y, 1);
}

165
template <>
Z
zhangjinchao01 已提交
166 167 168 169
float dotProduct<float>(const int n, const float* x, const float* y) {
  return cblas_sdot(n, x, 1, y, 1);
}

170
template <>
Z
zhangjinchao01 已提交
171 172 173 174 175 176
double dotProduct<double>(const int n, const double* x, const double* y) {
  return cblas_ddot(n, x, 1, y, 1);
}

#ifdef PADDLE_USE_MKL

177
template <>
Z
zhangjinchao01 已提交
178 179 180 181
void vExp<float>(const int n, const float* a, float* r) {
  vsExp(n, a, r);
}

182
template <>
Z
zhangjinchao01 已提交
183 184 185 186
void vExp<double>(const int n, const double* a, double* r) {
  vdExp(n, a, r);
}

187
template <>
Z
zhangjinchao01 已提交
188 189 190 191
void vPow<float>(const int n, const float* a, const float b, float* r) {
  vsPowx(n, a, b, r);
}

192
template <>
Z
zhangjinchao01 已提交
193 194 195 196
void vPow<double>(const int n, const double* a, const double b, double* r) {
  vdPowx(n, a, b, r);
}

197
template <>
Z
zhangjinchao01 已提交
198 199 200 201
void vLog<float>(const int n, const float* a, float* r) {
  vsLn(n, a, r);
}

202
template <>
Z
zhangjinchao01 已提交
203 204 205 206
void vLog<double>(const int n, const double* a, double* r) {
  vdLn(n, a, r);
}

207
template <>
Z
zhangjinchao01 已提交
208 209 210 211
void vAdd<float>(const int n, const float* a, const float* b, float* r) {
  vsAdd(n, a, b, r);
}

212
template <>
Z
zhangjinchao01 已提交
213 214 215 216
void vAdd<double>(const int n, const double* a, const double* b, double* r) {
  vdAdd(n, a, b, r);
}

217
template <>
Z
zhangjinchao01 已提交
218 219 220 221
void vInvSqrt<float>(const int n, const float* a, float* r) {
  vsInvSqrt(n, a, r);
}

222
template <>
Z
zhangjinchao01 已提交
223 224 225 226
void vInvSqrt<double>(const int n, const double* a, double* r) {
  vdInvSqrt(n, a, r);
}

227
template <>
Z
zhangjinchao01 已提交
228 229 230 231
void vLog1p<float>(const int n, const float* a, float* r) {
  vsLog1p(n, a, r);
}

232
template <>
Z
zhangjinchao01 已提交
233 234 235 236
void vLog1p<double>(const int n, const double* a, double* r) {
  vdLog1p(n, a, r);
}

237
template <>
Z
zhangjinchao01 已提交
238 239 240 241
void vTanh<float>(const int n, const float* a, float* r) {
  vsTanh(n, a, r);
}

242
template <>
Z
zhangjinchao01 已提交
243 244 245 246 247 248
void vTanh<double>(const int n, const double* a, double* r) {
  vdTanh(n, a, r);
}
#else

DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a));
249
template <class T>
Z
zhangjinchao01 已提交
250 251
void vExp(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vExp<T>, 0, 0>(
252
      binary::vExp<T>(), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
253 254 255
}

DEFINE_MATRIX_BINARY_OP(vLog, b = std::log(a));
256
template <class T>
Z
zhangjinchao01 已提交
257 258
void vLog(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vLog<T>, 0, 0>(
259
      binary::vLog<T>(), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
260 261 262
}

DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
263
template <class T>
Z
zhangjinchao01 已提交
264 265
void vInvSqrt(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vInvSqrt<T>, 0, 0>(
266
      binary::vInvSqrt<T>(), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
267 268 269
}

DEFINE_MATRIX_BINARY_OP(vLog1p, b = std::log(1.0f + a));
270
template <class T>
Z
zhangjinchao01 已提交
271 272
void vLog1p(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vLog1p<T>, 0, 0>(
273
      binary::vLog1p<T>(), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
274 275
}

276 277 278 279
DEFINE_MATRIX_BINARY_OP(vTanh, T tmp = -2.0 * a;
                        tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
                        b = 2.0 / (1.0 + std::exp(tmp)) - 1.0);
template <class T>
Z
zhangjinchao01 已提交
280 281
void vTanh(const int n, const T* a, T* r) {
  hl_cpu_apply_binary_op<T, binary::vTanh<T>, 0, 0>(
282
      binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
283 284 285
}

DEFINE_MATRIX_BINARY_PARAMETER_OP(vPow, ONE_PARAMETER, b = std::pow(a, p));
286
template <class T>
Z
zhangjinchao01 已提交
287 288
void vPow(const int n, const T* a, const T b, T* r) {
  hl_cpu_apply_binary_op<T, binary::vPow<T>, 0, 0>(
289
      binary::vPow<T>(b), const_cast<T*>(a), r, 1, n, n, n);
Z
zhangjinchao01 已提交
290 291 292
}

DEFINE_MATRIX_TERNARY_OP(vAdd, c = a + b);
293
template <class T>
Z
zhangjinchao01 已提交
294 295
void vAdd(const int n, const T* a, const T* b, T* r) {
  hl_cpu_apply_ternary_op<T, ternary::vAdd<T>, 0, 0>(ternary::vAdd<T>(),
296 297 298 299 300 301 302 303
                                                     const_cast<T*>(a),
                                                     const_cast<T*>(b),
                                                     r,
                                                     1,
                                                     n,
                                                     n,
                                                     n,
                                                     n);
Z
zhangjinchao01 已提交
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
}

template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const float* a, float* r);
template void vLog1p(const int n, const float* a, float* r);
template void vLog1p(const int n, const double* a, double* r);
template void vTanh(const int n, const float* a, float* r);
template void vTanh(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);

#endif

}  // namespace paddle