/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                     const CBLAS_TRANSPOSE transB, const int M,
                                     const int N, const int K,
                                     const float alpha, const float* A,
                                     const float* B, const float beta, float* C,
                                     platform::DeviceContext* context) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
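  // A, B, and C are stored row-major here, while cuBLAS expects column-major.
  // A row-major M x N matrix is bit-identical to a column-major N x M matrix,
  // so instead of C = op(A) * op(B) we ask cuBLAS for C^T = op(B)^T * op(A)^T:
  // the operands are passed in swapped order (B first, then A), and the
  // leading dimensions below are the row lengths of the stored A, B, and C
  // (hence ldc == N in the call at the bottom).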
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
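
// Usage sketch (hypothetical buffers): with d_A (2x4), d_B (4x3), and d_C
// (2x3) as row-major float buffers already resident on the GPU, and ctx a
// platform::DeviceContext* that actually points to a CUDADeviceContext,
// C = A * B could be computed as:
//
//   gemm<platform::GPUPlace, float>(CblasNoTrans, CblasNoTrans, 2, 3, 4, 1.0f,
//                                   d_A, d_B, 0.0f, d_C, ctx);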

template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const double alpha, const double* A,
                                      const double* B, const double beta,
                                      double* C,
                                      platform::DeviceContext* context) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
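  // The same row-major/column-major rearrangement as in the float
  // specialization applies: cuBLAS is asked for C^T = op(B)^T * op(A)^T.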
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}

template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
                                       bool trans_a,
                                       const framework::Tensor& matrix_b,
                                       bool trans_b, float alpha,
                                       framework::Tensor* matrix_out,
                                       float beta,
                                       platform::DeviceContext* context) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in GPUPlace");

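  // GEMM shapes: M and N come from the output, K is the contracted dimension
  // of A (its column count normally, its row count when A is transposed).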
  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::GPUPlace, float>(
      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
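
// Usage sketch (hypothetical tensors): given framework::Tensor objects a
// (M x K), b (K x N), and out (M x N) whose float data already lives in
// GPUPlace, and a CUDADeviceContext* ctx, out = a * b could be computed as:
//
//   matmul<platform::GPUPlace, float>(a, false, b, false, 1.0f, &out, 0.0f,
//                                     ctx);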

template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
                                        bool trans_a,
                                        const framework::Tensor& matrix_b,
                                        bool trans_b, double alpha,
                                        framework::Tensor* matrix_out,
                                        double beta,
                                        platform::DeviceContext* context) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in GPUPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::GPUPlace, double>(
      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}

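// Set fills the first n elements of a device buffer with the constant alpha,
// evaluated through Eigen on the GPU device of the given context. The
// DeviceContext passed in must actually be a CUDADeviceContext.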
template <>
void Set<platform::GPUPlace, float>(const int n, const float alpha,
                                    float* output,
                                    platform::DeviceContext* context) {
  auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
  framework::EigenVector<float>::Type out(output, n);
  out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha));
}

}  // namespace math
}  // namespace operators
}  // namespace paddle