math_function.cc 4.4 KB
Newer Older
Q
qijun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <>
22 23
void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
                                     const CBLAS_TRANSPOSE transA,
Q
qijun 已提交
24 25 26
                                     const CBLAS_TRANSPOSE transB, const int M,
                                     const int N, const int K,
                                     const float alpha, const float* A,
27 28
                                     const float* B, const float beta,
                                     float* C) {
D
dongzhihong 已提交
29 30
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
Q
qijun 已提交
31
  int ldc = N;
Q
qijun 已提交
32 33
  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, ldc);
Q
qijun 已提交
34 35 36
}

template <>
37 38
void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
                                      const CBLAS_TRANSPOSE transA,
Q
qijun 已提交
39 40 41 42
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const double alpha, const double* A,
                                      const double* B, const double beta,
43
                                      double* C) {
D
dongzhihong 已提交
44 45
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
Q
qijun 已提交
46
  int ldc = N;
Q
qijun 已提交
47 48
  cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, ldc);
Q
qijun 已提交
49 50
}

Q
qijun 已提交
51
template <>
52 53 54 55
void matmul<platform::CPUPlace, float>(
    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha,
    framework::Tensor* matrix_out, float beta) {
Q
qijun 已提交
56 57 58 59 60 61 62 63 64
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
                     platform::is_cpu_place(matrix_b.place()) &&
                     platform::is_cpu_place(matrix_out->place()),
Q
qijun 已提交
65 66
                 "Matrix must all be in CPUPlace");

Q
qijun 已提交
67 68 69
  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
Q
qijun 已提交
70

Q
qijun 已提交
71 72
  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Q
qijun 已提交
73

Q
qijun 已提交
74
  gemm<platform::CPUPlace, float>(
75 76
      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>());
Q
qijun 已提交
77 78 79
}

template <>
80 81 82 83
void matmul<platform::CPUPlace, double>(
    const platform::DeviceContext& context, const framework::Tensor& matrix_a,
    bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha,
    framework::Tensor* matrix_out, double beta) {
Q
qijun 已提交
84 85 86 87 88 89 90 91 92
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
                     platform::is_cpu_place(matrix_b.place()) &&
                     platform::is_cpu_place(matrix_out->place()),
Q
qijun 已提交
93 94
                 "Matrix must all be in CPUPlace");

Q
qijun 已提交
95 96 97 98 99 100
  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Q
qijun 已提交
101

Q
qijun 已提交
102
  gemm<platform::CPUPlace, double>(
103 104
      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>());
Q
qijun 已提交
105 106
}

Q
qijun 已提交
107 108 109
}  // namespace math
}  // namespace operators
}  // namespace paddle