batched_gemm.cc 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/cuda/math/batched_gemm.h"
#include <iostream>
#include "lite/core/device_info.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

24 25 26 27 28
/**
 * Prepares the cuBLAS handle and the device-side pointer table used by run().
 *
 * @param trans_a        if true, op(A) = A^T; otherwise op(A) = A.
 * @param trans_b        if true, op(B) = B^T; otherwise op(B) = B.
 * @param max_batch_size number of batches the pointer table is sized for; the
 *                       table holds 3 * max_batch_size pointers (the A, B and
 *                       C pointer arrays laid out back to back).
 * @param ctx            CUDA context whose exec stream is used for all work.
 * @return true; failures abort through CHECK / CUDA_CALL / CUBLAS_CALL.
 */
template <typename PtypeIn, typename PtypeOut>
bool BatchedGemm<PtypeIn, PtypeOut>::init(const bool trans_a,
                                          const bool trans_b,
                                          const int max_batch_size,
                                          Context<TARGET(kCUDA)> *ctx) {
  CHECK(ctx != nullptr);
  CHECK(max_batch_size > 0) << "max_batch_size must be positive, got "
                            << max_batch_size;

  // Track the stream of the context passed in on every call (not only the
  // first) so the pointer-table copies in run() and the cuBLAS work issued
  // through cu_handle_ always share the caller's current stream.
  this->exe_stream_ = ctx->exec_stream();
  if (cu_handle_ == nullptr) {
    CUBLAS_CALL(cublasCreate(&cu_handle_));
  }
  CUBLAS_CALL(cublasSetStream(cu_handle_, this->exe_stream_));

  cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Drop the table from a previous init() before reallocating, and clear the
  // member so a failed allocation below never leaves a dangling pointer.
  if (A_ != nullptr) {
    CUDA_CALL(cudaFree(A_));
    A_ = nullptr;
  }
  CUDA_CALL(cudaMalloc(
      reinterpret_cast<void **>(&A_),
      3 * static_cast<size_t>(max_batch_size) * sizeof(PtypeIn *)));
  return true;
}

/**
 * Row-major batched SGEMM:
 *   c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i],  i in [0, batch_size)
 * where op() follows the transposes passed to init(), and op(a[i]) is m x k,
 * op(b[i]) is k x n, c[i] is m x n.
 *
 * `a`, `b` and `c` are host arrays of batch_size matrix pointers that are
 * copied into the device pointer table allocated by init(). cuBLAS is
 * column-major, so the product is issued as C^T = op(B)^T * op(A)^T, which is
 * why B/A and n/m appear swapped in the cuBLAS call. All work is enqueued on
 * exe_stream_; nothing here synchronizes.
 *
 * NOTE(review): batch_size is not checked against init()'s max_batch_size;
 * callers must keep it within that bound.
 */
template <>
bool BatchedGemm<float, float>::run(const float alpha,
                                    const float beta,
                                    const float *a[],
                                    const float *b[],
                                    float *c[],
                                    const int m,
                                    const int n,
                                    const int k,
                                    const int batch_size) {
  CHECK(a != nullptr);
  CHECK(b != nullptr);
  CHECK(c != nullptr);
  CHECK(A_ != nullptr) << "BatchedGemm::init() must be called before run()";
  CHECK(batch_size >= 0) << "batch_size must be non-negative, got "
                         << batch_size;
  // Leading dimensions of the row-major operands.
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  if (batch_size == 0) {
    return true;  // Nothing to copy or compute.
  }

  const size_t table_bytes = static_cast<size_t>(batch_size) * sizeof(float *);
  // Device pointer table layout: [A pointers | B pointers | C pointers].
  CUDA_CALL(cudaMemcpyAsync(
      A_, a, table_bytes, cudaMemcpyHostToDevice, exe_stream_));
  CUDA_CALL(cudaMemcpyAsync(
      A_ + batch_size, b, table_bytes, cudaMemcpyHostToDevice, exe_stream_));
  CUDA_CALL(cudaMemcpyAsync(A_ + batch_size * 2,
                            c,
                            table_bytes,
                            cudaMemcpyHostToDevice,
                            exe_stream_));
  CUBLAS_CALL(cublasSgemmBatched(cu_handle_,
                                 cu_trans_b_,
                                 cu_trans_a_,
                                 n_,
                                 m_,
                                 k_,
                                 &alpha,
                                 const_cast<const float **>(A_ + batch_size),
                                 ldb_,
                                 const_cast<const float **>(A_),
                                 lda_,
                                 &beta,
                                 A_ + batch_size * 2,
                                 ldc_,
                                 batch_size));
  return true;
}

96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
/**
 * Row-major batched HGEMM:
 *   c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i],  i in [0, batch_size)
 * where op() follows the transposes passed to init(), and op(a[i]) is m x k,
 * op(b[i]) is k x n, c[i] is m x n.
 *
 * `a`, `b` and `c` are host arrays of batch_size matrix pointers that are
 * copied into the device pointer table allocated by init(). cuBLAS is
 * column-major, so the product is issued as C^T = op(B)^T * op(A)^T, which is
 * why B/A and n/m appear swapped in the cuBLAS call. All work is enqueued on
 * exe_stream_; nothing here synchronizes.
 *
 * NOTE(review): batch_size is not checked against init()'s max_batch_size;
 * callers must keep it within that bound.
 */
template <>
bool BatchedGemm<half, half>::run(const half alpha,
                                  const half beta,
                                  const half *a[],
                                  const half *b[],
                                  half *c[],
                                  const int m,
                                  const int n,
                                  const int k,
                                  const int batch_size) {
  CHECK(a != nullptr);
  CHECK(b != nullptr);
  CHECK(c != nullptr);
  CHECK(A_ != nullptr) << "BatchedGemm::init() must be called before run()";
  CHECK(batch_size >= 0) << "batch_size must be non-negative, got "
                         << batch_size;
  // Leading dimensions of the row-major operands.
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  if (batch_size == 0) {
    return true;  // Nothing to copy or compute.
  }

  const size_t table_bytes = static_cast<size_t>(batch_size) * sizeof(half *);
  // Device pointer table layout: [A pointers | B pointers | C pointers].
  CUDA_CALL(cudaMemcpyAsync(
      A_, a, table_bytes, cudaMemcpyHostToDevice, exe_stream_));
  CUDA_CALL(cudaMemcpyAsync(
      A_ + batch_size, b, table_bytes, cudaMemcpyHostToDevice, exe_stream_));
  CUDA_CALL(cudaMemcpyAsync(A_ + batch_size * 2,
                            c,
                            table_bytes,
                            cudaMemcpyHostToDevice,
                            exe_stream_));
  CUBLAS_CALL(cublasHgemmBatched(cu_handle_,
                                 cu_trans_b_,
                                 cu_trans_a_,
                                 n_,
                                 m_,
                                 k_,
                                 &alpha,
                                 const_cast<const half **>(A_ + batch_size),
                                 ldb_,
                                 const_cast<const half **>(A_),
                                 lda_,
                                 &beta,
                                 A_ + batch_size * 2,
                                 ldc_,
                                 batch_size));
  return true;
}

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
/**
 * Row-major batched SGEMM taking one packed pointer table:
 *   a[0 .. batch_size)                 -> A matrices (op(A) is m x k)
 *   a[batch_size .. 2*batch_size)      -> B matrices (op(B) is k x n)
 *   a[2*batch_size .. 3*batch_size)    -> C matrices (m x n, overwritten)
 * computing C_i = alpha * op(A_i) * op(B_i) + beta * C_i.
 *
 * The table is copied with cudaMemcpyDefault, so `a` may be host or device
 * memory (this relies on unified virtual addressing). cuBLAS is column-major,
 * so the call computes C^T = op(B)^T * op(A)^T (B/A and n/m swapped). All
 * work is enqueued on exe_stream_; nothing here synchronizes.
 *
 * NOTE(review): batch_size is not checked against init()'s max_batch_size;
 * callers must keep it within that bound.
 */
template <>
bool BatchedGemm<float, float>::run(const float alpha,
                                    const float beta,
                                    const float *a[],
                                    const int m,
                                    const int n,
                                    const int k,
                                    const int batch_size) {
  CHECK(a != nullptr);
  CHECK(A_ != nullptr) << "BatchedGemm::init() must be called before run()";
  CHECK(batch_size >= 0) << "batch_size must be non-negative, got "
                         << batch_size;
  // Leading dimensions of the row-major operands.
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  if (batch_size == 0) {
    return true;  // Nothing to copy or compute.
  }

  CUDA_CALL(cudaMemcpyAsync(
      A_,
      a,
      3 * static_cast<size_t>(batch_size) * sizeof(const float *),
      cudaMemcpyDefault,
      exe_stream_));
  CUBLAS_CALL(cublasSgemmBatched(cu_handle_,
                                 cu_trans_b_,
                                 cu_trans_a_,
                                 n_,
                                 m_,
                                 k_,
                                 &alpha,
                                 const_cast<const float **>(A_ + batch_size),
                                 ldb_,
                                 const_cast<const float **>(A_),
                                 lda_,
                                 &beta,
                                 A_ + batch_size * 2,
                                 ldc_,
                                 batch_size));
  return true;
}

186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
/**
 * Row-major batched HGEMM taking one packed pointer table:
 *   a[0 .. batch_size)                 -> A matrices (op(A) is m x k)
 *   a[batch_size .. 2*batch_size)      -> B matrices (op(B) is k x n)
 *   a[2*batch_size .. 3*batch_size)    -> C matrices (m x n, overwritten)
 * computing C_i = alpha * op(A_i) * op(B_i) + beta * C_i.
 *
 * The table is copied with cudaMemcpyDefault, so `a` may be host or device
 * memory (this relies on unified virtual addressing). cuBLAS is column-major,
 * so the call computes C^T = op(B)^T * op(A)^T (B/A and n/m swapped). All
 * work is enqueued on exe_stream_; nothing here synchronizes.
 *
 * NOTE(review): batch_size is not checked against init()'s max_batch_size;
 * callers must keep it within that bound.
 */
template <>
bool BatchedGemm<half, half>::run(const half alpha,
                                  const half beta,
                                  const half *a[],
                                  const int m,
                                  const int n,
                                  const int k,
                                  const int batch_size) {
  CHECK(a != nullptr);
  CHECK(A_ != nullptr) << "BatchedGemm::init() must be called before run()";
  CHECK(batch_size >= 0) << "batch_size must be non-negative, got "
                         << batch_size;
  // Leading dimensions of the row-major operands.
  lda_ = (cu_trans_a_ == CUBLAS_OP_N) ? k : m;
  ldb_ = (cu_trans_b_ == CUBLAS_OP_N) ? n : k;
  ldc_ = n;
  m_ = m;
  n_ = n;
  k_ = k;
  if (batch_size == 0) {
    return true;  // Nothing to copy or compute.
  }

  CUDA_CALL(cudaMemcpyAsync(
      A_,
      a,
      3 * static_cast<size_t>(batch_size) * sizeof(const half *),
      cudaMemcpyDefault,
      exe_stream_));
  CUBLAS_CALL(cublasHgemmBatched(cu_handle_,
                                 cu_trans_b_,
                                 cu_trans_a_,
                                 n_,
                                 m_,
                                 k_,
                                 &alpha,
                                 const_cast<const half **>(A_ + batch_size),
                                 ldb_,
                                 const_cast<const half **>(A_),
                                 lda_,
                                 &beta,
                                 A_ + batch_size * 2,
                                 ldc_,
                                 batch_size));
  return true;
}

// Explicit instantiations for the element-type pairs this file provides
// run() specializations for (float/float and half/half); other pairs are not
// instantiated here.
template class BatchedGemm<float, float>;
template class BatchedGemm<half, half>;

227 228 229 230
}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle