/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#define EIGEN_USE_GPU
Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
19
#include "paddle/fluid/platform/float16.h"
Q
qijun 已提交
20

Q
qijun 已提交
21 22 23 24
namespace paddle {
namespace operators {
namespace math {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
using float16 = paddle::platform::float16;

// Half-precision GEMM on the GPU: C = alpha * op(A) * op(B) + beta * C.
//
// M, N, K are the GEMM extents as passed by the caller. The leading
// dimensions are derived from the transpose flags (lda is K or M, ldb is N or
// K, ldc is N), i.e. the operands are treated as densely packed.
// Requires a device with compute capability >= 53.
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas fp16 gemm requires GPU compute capability >= 53");

#if CUDA_VERSION >= 8000
  // cublasGemmEx takes the scaling factors in the compute type (fp32 here).
  const float alpha_f32 = static_cast<float>(alpha);
  const float beta_f32 = static_cast<float>(beta);

  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
  // On compute capability >= 70 switch the handle to tensor-op math and pick
  // the matching algorithm; otherwise put the handle back to default math.
  const bool use_tensor_op = context.GetComputeCapability() >= 70;
  PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
      context.cublas_handle(),
      use_tensor_op ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH));
  if (use_tensor_op) {
    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
  }
#endif  // CUDA_VERSION >= 9000

  // cublasHgemm does true FP16 computation which is slow for non-Volta
  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
  // input/output in fp16, computation in fp32, which can also be accelerated
  // using tensor cores in volta GPUs.
  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_f32, B, CUDA_R_16F,
      ldb, A, CUDA_R_16F, lda, &beta_f32, C, CUDA_R_16F, N, CUDA_R_32F, algo));
#else
  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, ldb, a_h,
      lda, &beta_h, c_h, N));
#endif  // CUDA_VERSION >= 8000
}

Q
qijun 已提交
84
// Single-precision GEMM on the GPU via cublasSgemm:
// C = alpha * op(A) * op(B) + beta * C.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K, ldc is N).
template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const int ldc = N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

// Double-precision GEMM on the GPU via cublasDgemm:
// C = alpha * op(A) * op(B) + beta * C.
// Leading dimensions are derived from the transpose flags (lda is K or M,
// ldb is N or K, ldc is N).
template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const int ldc = N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
// Half-precision GEMM with caller-supplied leading dimensions, implemented
// with cublasHgemm. Requires a device with compute capability >= 53.
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const int lda, const float16* B,
    const int ldb, const float16 beta, float16* C, const int ldc) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Re-express the float16 scalars and buffers as CUDA's native half type.
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");
  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, ldb, a_h,
      lda, &beta_h, c_h, ldc));
}

G
guosheng 已提交
148
// Single-precision GEMM with caller-supplied leading dimensions, implemented
// with cublasSgemm.
template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K, const float alpha,
    const float* A, const int lda, const float* B, const int ldb,
    const float beta, float* C, const int ldc) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

// Double-precision GEMM with caller-supplied leading dimensions, implemented
// with cublasDgemm.
template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const double alpha, const double* A, const int lda, const double* B,
    const int ldb, const double beta, double* C, const int ldc) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
// Tensor-level half-precision matrix product:
// matrix_out = alpha * op(matrix_a) * op(matrix_b) + beta * matrix_out.
// All three tensors must be rank-2 and placed on the GPU. M and N are taken
// from the output shape, K from matrix_a (honouring trans_a).
template <>
void matmul<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
    framework::Tensor* matrix_out, float16 beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int M = dim_out[0];
  const int N = dim_out[1];
  const int K = trans_a ? dim_a[0] : dim_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  const float16* a_data = matrix_a.data<float16>();
  const float16* b_data = matrix_b.data<float16>();
  float16* out_data = matrix_out->data<float16>();

  gemm<platform::CUDADeviceContext, float16>(context, op_a, op_b, M, N, K,
                                             alpha, a_data, b_data, beta,
                                             out_data);
}

Q
qijun 已提交
207
// Tensor-level single-precision matrix product:
// matrix_out = alpha * op(matrix_a) * op(matrix_b) + beta * matrix_out.
// All three tensors must be rank-2 and placed on the GPU. M and N are taken
// from the output shape, K from matrix_a (honouring trans_a).
template <>
void matmul<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float alpha,
    framework::Tensor* matrix_out, float beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int M = dim_out[0];
  const int N = dim_out[1];
  const int K = trans_a ? dim_a[0] : dim_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  const float* a_data = matrix_a.data<float>();
  const float* b_data = matrix_b.data<float>();
  float* out_data = matrix_out->data<float>();

  gemm<platform::CUDADeviceContext, float>(context, op_a, op_b, M, N, K, alpha,
                                           a_data, b_data, beta, out_data);
}

// Tensor-level double-precision matrix product:
// matrix_out = alpha * op(matrix_a) * op(matrix_b) + beta * matrix_out.
// All three tensors must be rank-2 and placed on the GPU. M and N are taken
// from the output shape, K from matrix_a (honouring trans_a).
template <>
void matmul<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, double alpha,
    framework::Tensor* matrix_out, double beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int M = dim_out[0];
  const int N = dim_out[1];
  const int K = trans_a ? dim_a[0] : dim_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  const double* a_data = matrix_a.data<double>();
  const double* b_data = matrix_b.data<double>();
  double* out_data = matrix_out->data<double>();

  gemm<platform::CUDADeviceContext, double>(context, op_a, op_b, M, N, K,
                                            alpha, a_data, b_data, beta,
                                            out_data);
}
Q
qijun 已提交
264

265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287
// Strided-batched half-precision GEMM: for each of the batchCount problems,
// C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where consecutive A/B
// matrices are strideA/strideB elements apart and consecutive outputs are
// M * N elements apart. Requires compute capability >= 53 and CUDA >= 8.0.
template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C, const int batchCount, const int strideA, const int strideB) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const int ldc = N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

  // Re-express the float16 scalars and buffers as CUDA's native half type.
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, ldb, strideB,
      a_h, lda, strideA, &beta_h, c_h, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}

M
Markus Kliegl 已提交
301
// Strided-batched single-precision GEMM: for each of the batchCount problems,
// C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where consecutive A/B
// matrices are strideA/strideB elements apart and consecutive outputs are
// M * N elements apart. Requires CUDA >= 8.0.
template <>
void batched_gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const int ldc = N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, strideB, A,
      lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}

// Strided-batched double-precision GEMM: for each of the batchCount problems,
// C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where consecutive A/B
// matrices are strideA/strideB elements apart and consecutive outputs are
// M * N elements apart. Requires CUDA >= 8.0.
template <>
void batched_gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
  // cuBLAS follows Fortran ordering, so relative to the cblas convention the
  // operands are handed over swapped: B before A, and N before M.
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const int ldc = N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, strideB, A,
      lda, strideA, &beta, C, ldc, strideC, batchCount));
#else
  PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
#endif
}

353
// Single-precision matrix-vector product via cublasSgemv:
// C = alpha * op(A) * B + beta * C, with unit strides on B and C.
// A is handed to cuBLAS as an N x M matrix with leading dimension N, and the
// transpose flag is inverted (trans_a == false selects CUBLAS_OP_T).
template <>
void gemv<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const float alpha, const float* A, const float* B,
    const float beta, float* C) {
  const cublasOperation_t op_a = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int lda = N;
  const int inc = 1;

  PADDLE_ENFORCE(platform::dynload::cublasSgemv(context.cublas_handle(), op_a,
                                                N, M, &alpha, A, lda, B, inc,
                                                &beta, C, inc));
}

// Double-precision matrix-vector product via cublasDgemv:
// C = alpha * op(A) * B + beta * C, with unit strides on B and C.
// A is handed to cuBLAS as an N x M matrix with leading dimension N, and the
// transpose flag is inverted (trans_a == false selects CUBLAS_OP_T).
template <>
void gemv<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const double alpha, const double* A, const double* B,
    const double beta, double* C) {
  const cublasOperation_t op_a = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int lda = N;
  const int inc = 1;

  PADDLE_ENFORCE(platform::dynload::cublasDgemv(context.cublas_handle(), op_a,
                                                N, M, &alpha, A, lda, B, inc,
                                                &beta, C, inc));
}

376
// Single-precision AXPY via cublasSaxpy: y = alpha * x + y over n elements,
// with unit stride on both vectors.
template <>
void axpy<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const int n, const float alpha,
    const float* x, float* y) {
  const int inc = 1;
  PADDLE_ENFORCE(platform::dynload::cublasSaxpy(context.cublas_handle(), n,
                                                &alpha, x, inc, y, inc));
}

// Double-precision AXPY via cublasDaxpy: y = alpha * x + y over n elements,
// with unit stride on both vectors.
template <>
void axpy<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const int n, const double alpha,
    const double* x, double* y) {
  const int inc = 1;
  PADDLE_ENFORCE(platform::dynload::cublasDaxpy(context.cublas_handle(), n,
                                                &alpha, x, inc, y, inc));
}

K
Kexin Zhao 已提交
392
// Explicit instantiations of the GPU SetConstant functor.
// Floating-point element types:
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
// Integral and boolean element types:
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
398

Q
QI JUN 已提交
399 400 401
// Explicitly instantiates the GPU Transpose functor for float and double at
// the given tensor rank. The semicolon written after each use terminates the
// second instantiation.
#define DEFINE_GPU_TRANS(RANK)                                         \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>

// Ranks 1 through 6 are instantiated.
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);
Q
qijun 已提交
409

410 411
// Visitor that fills a tensor with a constant on the GPU. The element type T
// is supplied by the caller of operator()<T>; the stored float value is cast
// to T before being written.
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
                       framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void operator()() const {
    // The generic device context is reinterpreted as the CUDA context that
    // SetConstant expects.
    const auto& cuda_ctx =
        reinterpret_cast<const platform::CUDADeviceContext&>(context_);
    const T typed_value = static_cast<T>(value_);
    SetConstant<platform::CUDADeviceContext, T> fill;
    fill(cuda_ctx, tensor_, typed_value);
  }

  const platform::DeviceContext& context_;
  framework::Tensor* tensor_;
  float value_;
};

// CUDAPlace specialization: dispatches on the tensor's runtime data type and
// fills the tensor with `value` through TensorSetConstantGPU.
template <>
void set_constant_with_place<platform::CUDAPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  const auto data_type = framework::ToDataType(tensor->type());
  framework::VisitDataType(data_type,
                           TensorSetConstantGPU(context, tensor, value));
}

Q
qingqing01 已提交
435
// Element-wise kernel computing c[i] = a[i] + b[i % width] for i in [0, num).
//
// a and c hold `num` elements; b is indexed with i % width, so it must hold
// at least `width` elements. The kernel uses a grid-stride loop, so any 1-D
// launch configuration covers the whole range.
//
// The column index is computed with integer arithmetic. The previous version
// multiplied i by a reciprocal of width stored in T; the rounded reciprocal
// can yield a row index that is one too small, giving a column index equal to
// `width` and an out-of-bounds read of b. For float, i also loses precision
// once it exceeds 2^24.
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    const int w = i % width;  // column of element i within its row
    c[i] = a[i] + b[w];
  }
}

// GPU functor adding `vector` to every row of `input` and writing the result
// to `output`. `vector` must have numel(input) / dims[0] elements and
// `output` must have the same dims as `input`.
template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);

    // One thread per element, 512 threads per block, rounded up.
    const int threads_per_block = 512;
    const int num_blocks =
        (input.numel() + threads_per_block - 1) / threads_per_block;

    const int width = static_cast<int>(in_dims[1]);
    const int total = static_cast<int>(input.numel());

    RowwiseAddKernel<T><<<num_blocks, threads_per_block, 0, context.stream()>>>(
        input.data<T>(), vector.data<T>(), output->data<T>(), width, total);
  }
};

Q
QI JUN 已提交
464 465 466
// Explicit instantiations of the GPU row-wise add functor.
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;

// Explicit instantiations of the GPU column-wise sum functor for the element
// types that use the generic implementation.
template struct ColwiseSum<platform::CUDADeviceContext, float>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
Q
QI JUN 已提交
469 470
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The generic ColwiseSum<platform::CUDADeviceContext, double> failed in debug
// mode (and only for this case), so the double version is implemented here
// as a matrix-vector product with a vector of ones.
template <>
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), size);

  // Device vector of in_dims[0] ones.
  framework::Tensor ones;
  ones.mutable_data<double>({in_dims[0]}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> fill;
  fill(context, &ones, static_cast<double>(1.0));

  const int rows = static_cast<int>(in_dims[0]);
  const int cols = static_cast<int>(in_dims[1]);
  gemv<platform::CUDADeviceContext, double>(
      context, true, rows, cols, 1.0, input.data<double>(),
      ones.data<double>(), 0.0, vector->data<double>());
}
488

C
chengduoZH 已提交
489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513
template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// The generic RowwiseSum<platform::CUDADeviceContext, double> failed in debug
// mode (and only for this case), so the double version is implemented here
// as a matrix-vector product with a vector of ones.
//
// Computes vector[i] = sum over j of input[i, j]; `vector` must have
// in_dims[0] elements.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  // Device vector of `size` ones, one per element of a row.
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // input is the matrix operand and the ones vector is the vector operand:
  // with trans_a == false and (M, N) = (in_dims[0], in_dims[1]) gemv yields
  // one sum per row. The previous call passed the ones vector as the matrix
  // operand, which made cuBLAS read past the end of `one` and ignore all but
  // the first in_dims[1] elements of input.
  gemv<platform::CUDADeviceContext, double>(
      context, false, static_cast<int>(in_dims[0]),
      static_cast<int>(in_dims[1]), 1.0, input.data<double>(),
      one.data<double>(), 0.0, vector->data<double>());
}

// Explicit instantiations of the GPU row-wise mean functor.
template struct RowwiseMean<platform::CUDADeviceContext, double>;
template struct RowwiseMean<platform::CUDADeviceContext, float>;

Q
qijun 已提交
514 515 516
}  // namespace math
}  // namespace operators
}  // namespace paddle