/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/math_function.h"
#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function_impl.h"

namespace paddle {
namespace operators {
namespace math {
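// CPU specializations, backed by CBLAS, of the math functions declared in
// math_function.h. gemm computes C = alpha * op(A) * op(B) + beta * C for
// row-major matrices; the leading dimensions below are derived from the
// transpose flags accordingly.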

template <>
void gemm<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, ldc);
}

template <>
void gemm<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, ldc);
}

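// gemm overload taking explicit leading dimensions (lda/ldb/ldc) and bool
// transpose flags; useful when A, B, or C is a view into a larger row-major
// buffer.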
template <>
void gemm<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K, const float alpha,
    const float* A, const int lda, const float* B, const int ldb,
    const float beta, float* C, const int ldc) {
  cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
              lda, B, ldb, beta, C, ldc);
}

template <>
void gemm<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const double alpha, const double* A, const int lda, const double* B,
    const int ldb, const double beta, double* C, const int ldc) {
  cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
              transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
              lda, B, ldb, beta, C, ldc);
}

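// matmul wraps gemm for 2-D framework::Tensor arguments. M and N are taken
// from the output dims and K from the (possibly transposed) first input; all
// tensors must already reside in CPU memory.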
template <>
void matmul<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float alpha,
    framework::Tensor* matrix_out, float beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
                     platform::is_cpu_place(matrix_b.place()) &&
                     platform::is_cpu_place(matrix_out->place()),
                 "Matrix must all be in CPUPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CPUDeviceContext, float>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>());
}

template <>
void matmul<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, double alpha,
    framework::Tensor* matrix_out, double beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
                     platform::is_cpu_place(matrix_b.place()) &&
                     platform::is_cpu_place(matrix_out->place()),
                 "Matrix must all be in CPUPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CPUDeviceContext, double>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>());
}

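// batched_gemm multiplies batchCount independent (A, B) pairs. A and B advance
// by strideA / strideB elements per batch, while C is laid out densely with a
// stride of M * N elements per batch.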
#ifdef PADDLE_WITH_MKLML
// Use MKL's cblas_{s,d}gemm_batch if available: run one group of size
// batchCount.
template <>
void batched_gemm<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  auto a_array = std::vector<const float*>(batchCount);
  auto b_array = std::vector<const float*>(batchCount);
  auto c_array = std::vector<float*>(batchCount);
  for (int k = 0; k < batchCount; ++k) {
    a_array[k] = &A[k * strideA];
    b_array[k] = &B[k * strideB];
    c_array[k] = &C[k * M * N];
  }
  cblas_sgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
                    a_array.data(), &lda, b_array.data(), &ldb, &beta,
                    c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}

template <>
void batched_gemm<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  auto a_array = std::vector<const double*>(batchCount);
  auto b_array = std::vector<const double*>(batchCount);
  auto c_array = std::vector<double*>(batchCount);
  for (int k = 0; k < batchCount; ++k) {
    a_array[k] = &A[k * strideA];
    b_array[k] = &B[k * strideB];
    c_array[k] = &C[k * M * N];
  }
  cblas_dgemm_batch(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha,
                    a_array.data(), &lda, b_array.data(), &ldb, &beta,
                    c_array.data(), &ldc, 1 /* group_count */, &batchCount);
}
#else
// The below is a naive but correct serial implementation that just loops
// over the batch dimension. This is a fallback for when the batched gemm
// functions of Intel MKL are not available. In the future, this computation
// should be parallelized.
template <>
void batched_gemm<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
  for (int k = 0; k < batchCount; ++k) {
    const float* Ak = &A[k * strideA];
    const float* Bk = &B[k * strideB];
    float* Ck = &C[k * M * N];
    gemm<platform::CPUDeviceContext, float>(context, transA, transB, M, N, K,
                                            alpha, Ak, Bk, beta, Ck);
  }
}

template <>
void batched_gemm<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
  for (int k = 0; k < batchCount; ++k) {
    const double* Ak = &A[k * strideA];
    const double* Bk = &B[k * strideB];
    double* Ck = &C[k * M * N];
    gemm<platform::CPUDeviceContext, double>(context, transA, transB, M, N, K,
                                             alpha, Ak, Bk, beta, Ck);
  }
}
#endif

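// gemv computes C = alpha * op(A) * B + beta * C, where A is an M x N
// row-major matrix and B and C are the input and output vectors.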
template <>
void gemv<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const bool trans_a, const int M,
    const int N, const float alpha, const float* A, const float* B,
    const float beta, float* C) {
  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}

template <>
void gemv<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const bool trans_a, const int M,
    const int N, const double alpha, const double* A, const double* B,
    const double beta, double* C) {
  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}

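// axpy computes y = alpha * x + y over n contiguous elements.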
template <>
void axpy<platform::CPUDeviceContext, float>(
    const platform::CPUDeviceContext& context, const int n, const float alpha,
    const float* x, float* y) {
  cblas_saxpy(n, alpha, x, 1, y, 1);
}

template <>
void axpy<platform::CPUDeviceContext, double>(
    const platform::CPUDeviceContext& context, const int n, const double alpha,
    const double* x, double* y) {
  cblas_daxpy(n, alpha, x, 1, y, 1);
}

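// Explicit instantiations of SetConstant for the element types used on CPU.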
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;

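// Instantiate Transpose for float and double tensors of rank 1 through 6.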
#define DEFINE_CPU_TRANS(RANK)                                        \
  template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
  template struct Transpose<platform::CPUDeviceContext, double, RANK>;

DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
DEFINE_CPU_TRANS(3);
DEFINE_CPU_TRANS(4);
DEFINE_CPU_TRANS(5);
DEFINE_CPU_TRANS(6);

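// Type visitor used with framework::VisitDataType below: fills a CPU tensor
// with `value` cast to the tensor's element type.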
struct TensorSetConstantCPU {
  TensorSetConstantCPU(framework::Tensor* tensor, float value)
      : tensor_(tensor), value_(value) {}
  template <typename T>
  void operator()() const {
    auto cpu = platform::CPUPlace();
    auto* begin = tensor_->mutable_data<T>(cpu);
    std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
  }
  framework::Tensor* tensor_;
  float value_;
};

template <>
void set_constant_with_place<platform::CPUPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  framework::VisitDataType(framework::ToDataType(tensor->type()),
                           TensorSetConstantCPU(tensor, value));
}

struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
  TensorSetConstantWithPlace(const platform::DeviceContext& context,
                             framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename Place>
  void operator()(Place place) const {
    set_constant_with_place<Place>(context_, tensor_, value_);
  }

  const platform::DeviceContext& context_;
  framework::Tensor* tensor_;
  float value_;
};

void set_constant(const platform::DeviceContext& context,
                  framework::Tensor* tensor, float value) {
  TensorSetConstantWithPlace func(context, tensor, value);
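  // With CUDA enabled, dispatch on the tensor's runtime place; otherwise the
  // only possible place is CPUPlace.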
#ifdef PADDLE_WITH_CUDA
  tensor->place().apply_visitor(func);
#else
  func(platform::CPUPlace());
#endif
}

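// Explicit instantiations of RowwiseAdd and ColwiseSum for CPU.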
template struct RowwiseAdd<platform::CPUDeviceContext, float>;
template struct RowwiseAdd<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, float>;
template struct ColwiseSum<platform::CPUDeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle