/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

using float16 = paddle::platform::float16;

template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
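  // A row-major matrix is the same memory interpreted column-major as its
  // transpose, so the row-major product C = op(A) * op(B) is evaluated here
  // as the column-major product C^T = op(B)^T * op(A)^T: B and A are passed
  // in swapped order below, with the dimensions and leading dimensions
  // adjusted to match.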
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas fp16 gemm requires GPU compute capability >= 53");

#if CUDA_VERSION >= 8000
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

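  // Start from the default GEMM algorithm; on compute capability >= 70
  // (Volta), switch cuBLAS into tensor-op math mode so cublasGemmEx below
  // may use tensor cores.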
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
  if (context.GetComputeCapability() >= 70) {
    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
                                                        CUBLAS_TENSOR_OP_MATH));
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  } else {
    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
                                                        CUBLAS_DEFAULT_MATH));
  }
#endif  // CUDA_VERSION >= 9000

  // cublasHgemm does true FP16 computation, which is slow on non-Volta
  // GPUs. So we use cublasGemmEx instead, which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores on Volta GPUs.
  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
      CUDA_R_32F, algo));
#else
  // CUDA 7.5 does not support cublasGemmEx, so we fall back to cublasHgemm.
  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      h_A, lda, &h_beta, h_C, N));
#endif  // CUDA_VERSION >= 8000
}
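
// A minimal usage sketch (hypothetical device pointers d_A, d_B, d_C and
// shapes): a row-major fp16 product C[M x N] = A[M x K] * B[K x N] would be
//   gemm<platform::CUDADeviceContext, float16>(
//       ctx, CblasNoTrans, CblasNoTrans, M, N, K, float16(1.0f), d_A, d_B,
//       float16(0.0f), d_C);
// where the buffers are assumed to be already allocated on the device.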

template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, N));
}

template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, N));
}

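// The overloads below take explicit leading dimensions (lda/ldb/ldc) instead
// of deriving them from M/N/K, which lets callers run GEMM on sub-blocks of
// larger row-major buffers.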
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const int lda, const float16* B,
    const int ldb, const float16 beta, float16* C, const int ldc) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;

  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");
  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      h_A, lda, &h_beta, h_C, ldc));
}

template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K, const float alpha,
    const float* A, const int lda, const float* B, const int ldb,
    const float beta, float* C, const int ldc) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, ldc));
}

template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const double alpha, const double* A, const int lda, const double* B,
    const int ldb, const double beta, double* C, const int ldc) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, ldc));
}

template <>
void matmul<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
    framework::Tensor* matrix_out, float16 beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CUDADeviceContext, float16>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}

template <>
void matmul<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float alpha,
    framework::Tensor* matrix_out, float beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul must be matrices");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "The matrices must all be in CUDAPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CUDADeviceContext, float>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>());
}

template <>
void matmul<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, double alpha,
    framework::Tensor* matrix_out, double beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul must be matrices");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "The matrices must all be in CUDAPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CUDADeviceContext, double>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>());
}

template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C, const int batchCount, const int64_t strideA,
    const int64_t strideB) {
#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
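  // Each output matrix in the batch is M x N and the batch is laid out
  // contiguously, so consecutive C matrices are M * N elements apart.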
  const int64_t strideC = M * N;

  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on CUDA <= 7.5");
#endif
}

template <>
void batched_gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int64_t strideA,
    const int64_t strideB) {
#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

  PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on CUDA <= 7.5");
#endif
}

template <>
void batched_gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int64_t strideA,
    const int64_t strideB) {
#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int64_t strideC = M * N;

  PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on CUDA <= 7.5");
#endif
}

template <>
void gemv<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const float alpha, const float* A, const float* B,
    const float beta, float* C) {
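  // The flag looks inverted because cublas sees the row-major A as the
  // column-major A^T: a non-transposed row-major gemv therefore maps to
  // CUBLAS_OP_T on the column-major view, and vice versa.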
  cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;

  PADDLE_ENFORCE(platform::dynload::cublasSgemv(context.cublas_handle(),
                                                cuTransA, N, M, &alpha, A, N, B,
                                                1, &beta, C, 1));
}

template <>
void gemv<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const double alpha, const double* A, const double* B,
    const double beta, double* C) {
  cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
  PADDLE_ENFORCE(platform::dynload::cublasDgemv(context.cublas_handle(),
                                                cuTransA, N, M, &alpha, A, N, B,
                                                1, &beta, C, 1));
}

template <>
void axpy<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const int n, const float alpha,
    const float* x, float* y) {
  PADDLE_ENFORCE(platform::dynload::cublasSaxpy(context.cublas_handle(), n,
                                                &alpha, x, 1, y, 1));
}

template <>
void axpy<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const int n, const double alpha,
    const double* x, double* y) {
  PADDLE_ENFORCE(platform::dynload::cublasDaxpy(context.cublas_handle(), n,
                                                &alpha, x, 1, y, 1));
}

template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;

#define DEFINE_GPU_TRANS(RANK)                                         \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>;

DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);

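// TensorSetConstantGPU is a visitor for framework::VisitDataType: given the
// runtime data type of a tensor, operator()<T> is invoked with the matching
// compile-time T to fill the tensor with a constant.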
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
                       framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void operator()() const {
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(reinterpret_cast<const platform::CUDADeviceContext&>(context_),
            tensor_, static_cast<T>(value_));
  }

  const platform::DeviceContext& context_;
  framework::Tensor* tensor_;
  float value_;
};

template <>
void set_constant_with_place<platform::CUDAPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  framework::VisitDataType(framework::ToDataType(tensor->type()),
                           TensorSetConstantGPU(context, tensor, value));
}

template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  // Precompute the reciprocal of the row width so the row index can be
  // recovered with a multiply instead of an integer division per element.
  T tmp = 1.0 / width;
  // Grid-stride loop: each thread handles multiple elements when the grid
  // does not cover the whole input.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    int h = i * tmp;
    int w = i - h * width;
    c[i] = a[i] + b[w];
  }
}

template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
    int blocks = 512;
    int grids = (input.numel() + blocks - 1) / blocks;
    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
        input.data<T>(), vector.data<T>(), output->data<T>(),
        static_cast<int>(in_dims[1]), static_cast<int>(input.numel()));
  }
};

template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The ColwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case, so it is reimplemented below.
template <>
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), size);
  framework::Tensor one;
  one.mutable_data<double>({in_dims[0]}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
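  // Column-wise sums are computed as a matrix-vector product: multiplying
  // the transposed input by a vector of ones yields one sum per column.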
  gemv<platform::CUDADeviceContext, double>(
      context, true, static_cast<int>(in_dims[0]), static_cast<int>(in_dims[1]),
      1.0, input.data<double>(), one.data<double>(), 0.0,
      vector->data<double>());
}

template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// TODO(zcd): Following ColwiseSum format, need to confirm.
// The RowwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case, so it is reimplemented below.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  gemv<platform::CUDADeviceContext, double>(
      context, true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]),
      1.0, one.data<double>(), input.data<double>(), 0.0,
      vector->data<double>());
}

template struct RowwiseMean<platform::CUDADeviceContext, float>;
template struct RowwiseMean<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle