/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#define EIGEN_USE_GPU
Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
19
#include "paddle/fluid/platform/float16.h"
Q
qijun 已提交
20

Q
qijun 已提交
21 22 23 24
namespace paddle {
namespace operators {
namespace math {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// Local shorthand for paddle::platform::float16, used by the half-precision
// specializations in this file.
using float16 = paddle::platform::float16;

// Half-precision GEMM specialization that forwards to cublasHgemm.
//
// Leading dimensions are derived from the transpose flags: lda is K when A is
// not transposed and M otherwise; ldb is N when B is not transposed and K
// otherwise. The output leading dimension passed to cuBLAS is N.
// The compute capability is checked to be >= 53 before calling cuBLAS, and
// the cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C) {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // cuBLAS takes `half`; scalars are converted by value and the buffers are
  // reinterpreted in place.
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, lead_b, a_h,
      lead_a, &beta_h, c_h, N));
}

Q
qijun 已提交
56
// Single-precision GEMM specialization that forwards to cublasSgemm.
//
// lda is K when A is not transposed and M otherwise; ldb is N when B is not
// transposed and K otherwise; the output leading dimension is N.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);

  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, lead_b, A,
      lead_a, &beta, C, N));
}

// Double-precision GEMM specialization that forwards to cublasDgemm.
//
// lda is K when A is not transposed and M otherwise; ldb is N when B is not
// transposed and K otherwise; the output leading dimension is N.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);

  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, lead_b, A,
      lead_a, &beta, C, N));
}

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
// Half-precision GEMM specialization with caller-supplied leading dimensions
// (lda, ldb, ldc); forwards to cublasHgemm.
//
// The compute capability is checked to be >= 53 before calling cuBLAS, and
// the cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const int lda, const float16* B,
    const int ldb, const float16 beta, float16* C, const int ldc) {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  // cuBLAS takes `half`; scalars are converted by value and the buffers are
  // reinterpreted in place.
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, ldb, a_h,
      lda, &beta_h, c_h, ldc));
}

G
guosheng 已提交
120
// Single-precision GEMM specialization with caller-supplied leading
// dimensions (lda, ldb, ldc); forwards to cublasSgemm.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K, const float alpha,
    const float* A, const int lda, const float* B, const int ldb,
    const float beta, float* C, const int ldc) {
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

// Double-precision GEMM specialization with caller-supplied leading
// dimensions (lda, ldb, ldc); forwards to cublasDgemm.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const double alpha, const double* A, const int lda, const double* B,
    const int ldb, const double beta, double* C, const int ldc) {
  const cublasOperation_t op_a = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t op_b = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
      &beta, C, ldc));
}

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
// Half-precision tensor-level matrix multiply: validates the operands and
// delegates to gemm<CUDADeviceContext, float16>.
//
// All three tensors must be rank 2 and reside on a GPU place. M and N are
// taken from the output shape; K is taken from matrix_a (its second dimension
// when trans_a is false, its first otherwise).
// NOTE(review): the shape of matrix_b is not cross-checked against M/N/K here.
template <>
void matmul<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
    framework::Tensor* matrix_out, float16 beta) {
  auto shape_a = matrix_a.dims();
  auto shape_b = matrix_b.dims();
  auto shape_out = matrix_out->dims();
  PADDLE_ENFORCE(
      shape_a.size() == 2 && shape_b.size() == 2 && shape_out.size() == 2,
      "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int rows = shape_out[0];
  const int cols = shape_out[1];
  const int inner = trans_a ? shape_a[0] : shape_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  gemm<platform::CUDADeviceContext, float16>(
      context, op_a, op_b, rows, cols, inner, alpha, matrix_a.data<float16>(),
      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}

Q
qijun 已提交
179
// Single-precision tensor-level matrix multiply: validates the operands and
// delegates to gemm<CUDADeviceContext, float>.
//
// All three tensors must be rank 2 and reside on a GPU place. M and N are
// taken from the output shape; K is taken from matrix_a (its second dimension
// when trans_a is false, its first otherwise).
// NOTE(review): the shape of matrix_b is not cross-checked against M/N/K here.
template <>
void matmul<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float alpha,
    framework::Tensor* matrix_out, float beta) {
  auto shape_a = matrix_a.dims();
  auto shape_b = matrix_b.dims();
  auto shape_out = matrix_out->dims();
  PADDLE_ENFORCE(
      shape_a.size() == 2 && shape_b.size() == 2 && shape_out.size() == 2,
      "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int rows = shape_out[0];
  const int cols = shape_out[1];
  const int inner = trans_a ? shape_a[0] : shape_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  gemm<platform::CUDADeviceContext, float>(
      context, op_a, op_b, rows, cols, inner, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>());
}

// Double-precision tensor-level matrix multiply: validates the operands and
// delegates to gemm<CUDADeviceContext, double>.
//
// All three tensors must be rank 2 and reside on a GPU place. M and N are
// taken from the output shape; K is taken from matrix_a (its second dimension
// when trans_a is false, its first otherwise).
// NOTE(review): the shape of matrix_b is not cross-checked against M/N/K here.
template <>
void matmul<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, double alpha,
    framework::Tensor* matrix_out, double beta) {
  auto shape_a = matrix_a.dims();
  auto shape_b = matrix_b.dims();
  auto shape_out = matrix_out->dims();
  PADDLE_ENFORCE(
      shape_a.size() == 2 && shape_b.size() == 2 && shape_out.size() == 2,
      "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  const int rows = shape_out[0];
  const int cols = shape_out[1];
  const int inner = trans_a ? shape_a[0] : shape_a[1];

  const CBLAS_TRANSPOSE op_a = trans_a ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE op_b = trans_b ? CblasTrans : CblasNoTrans;

  gemm<platform::CUDADeviceContext, double>(
      context, op_a, op_b, rows, cols, inner, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>());
}
Q
qijun 已提交
236

237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
// Half-precision strided-batched GEMM; forwards to
// cublasHgemmStridedBatched.
//
// lda/ldb are derived from the transpose flags (K or M for A, N or K for B),
// ldc is N, and consecutive output matrices are M * N elements apart.
// strideA / strideB are the caller-supplied element distances between
// consecutive A and B matrices.
// The compute capability is checked to be >= 53 before calling cuBLAS, and
// the cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C, const int batchCount, const int strideA, const int strideB) {
  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);
  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const int lead_c = N;
  const int stride_c = M * N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // cuBLAS takes `half`; scalars are converted by value and the buffers are
  // reinterpreted in place.
  const half alpha_h = static_cast<const half>(alpha);
  const half beta_h = static_cast<const half>(beta);
  const half* a_h = reinterpret_cast<const half*>(A);
  const half* b_h = reinterpret_cast<const half*>(B);
  half* c_h = reinterpret_cast<half*>(C);

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha_h, b_h, lead_b,
      strideB, a_h, lead_a, strideA, &beta_h, c_h, lead_c, stride_c,
      batchCount));
}

M
Markus Kliegl 已提交
268
// Single-precision strided-batched GEMM; forwards to
// cublasSgemmStridedBatched.
//
// lda/ldb are derived from the transpose flags (K or M for A, N or K for B),
// ldc is N, and consecutive output matrices are M * N elements apart.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void batched_gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);

  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const int lead_c = N;
  const int stride_c = M * N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, lead_b, strideB,
      A, lead_a, strideA, &beta, C, lead_c, stride_c, batchCount));
}

// Double-precision strided-batched GEMM; forwards to
// cublasDgemmStridedBatched.
//
// lda/ldb are derived from the transpose flags (K or M for A, N or K for B),
// ldc is N, and consecutive output matrices are M * N elements apart.
// The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void batched_gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
  const bool a_plain = (transA == CblasNoTrans);
  const bool b_plain = (transB == CblasNoTrans);

  const int lead_a = a_plain ? K : M;
  const int lead_b = b_plain ? N : K;
  const int lead_c = N;
  const int stride_c = M * N;
  const cublasOperation_t op_a = a_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t op_b = b_plain ? CUBLAS_OP_N : CUBLAS_OP_T;

  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention: B is passed before A and N before M.
  PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
      context.cublas_handle(), op_b, op_a, N, M, K, &alpha, B, lead_b, strideB,
      A, lead_a, strideA, &beta, C, lead_c, stride_c, batchCount));
}

312
// Single-precision matrix-vector product; forwards to cublasSgemv.
//
// The cuBLAS operation is the inverse of trans_a (trans_a == false maps to
// CUBLAS_OP_T), and the matrix is described to cuBLAS as N x M with leading
// dimension N -- presumably because A is stored row-major while cuBLAS is
// column-major. B is the input vector and C the output vector, both with
// unit stride. The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemv<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const float alpha, const float* A, const float* B,
    const float beta, float* C) {
  const cublasOperation_t op = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int lead_a = N;
  const int unit_stride = 1;

  PADDLE_ENFORCE(platform::dynload::cublasSgemv(
      context.cublas_handle(), op, N, M, &alpha, A, lead_a, B, unit_stride,
      &beta, C, unit_stride));
}

// Double-precision matrix-vector product; forwards to cublasDgemv.
//
// The cuBLAS operation is the inverse of trans_a (trans_a == false maps to
// CUBLAS_OP_T), and the matrix is described to cuBLAS as N x M with leading
// dimension N -- presumably because A is stored row-major while cuBLAS is
// column-major. B is the input vector and C the output vector, both with
// unit stride. The cuBLAS status is checked with PADDLE_ENFORCE.
template <>
void gemv<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const double alpha, const double* A, const double* B,
    const double beta, double* C) {
  const cublasOperation_t op = trans_a ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int lead_a = N;
  const int unit_stride = 1;

  PADDLE_ENFORCE(platform::dynload::cublasDgemv(
      context.cublas_handle(), op, N, M, &alpha, A, lead_a, B, unit_stride,
      &beta, C, unit_stride));
}

335
// Single-precision AXPY over n contiguous elements (unit stride for both x
// and y); forwards to cublasSaxpy. The cuBLAS status is checked with
// PADDLE_ENFORCE.
template <>
void axpy<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const int n, const float alpha,
    const float* x, float* y) {
  const int unit_stride = 1;
  PADDLE_ENFORCE(platform::dynload::cublasSaxpy(
      context.cublas_handle(), n, &alpha, x, unit_stride, y, unit_stride));
}

// Double-precision AXPY over n contiguous elements (unit stride for both x
// and y); forwards to cublasDaxpy. The cuBLAS status is checked with
// PADDLE_ENFORCE.
template <>
void axpy<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const int n, const double alpha,
    const double* x, double* y) {
  const int unit_stride = 1;
  PADDLE_ENFORCE(platform::dynload::cublasDaxpy(
      context.cublas_handle(), n, &alpha, x, unit_stride, y, unit_stride));
}

Q
QI JUN 已提交
351 352 353 354 355
// Explicit instantiations of the SetConstant functor for CUDADeviceContext,
// one per supported element type (float, double, int, int64_t, bool).
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
356

Q
QI JUN 已提交
357 358 359
// Explicitly instantiates the Transpose functor on CUDADeviceContext for
// float and double at the given RANK.
// The expansion intentionally has no trailing semicolon: the semicolon is
// written at the point of use (`DEFINE_GPU_TRANS(1);`), so no stray empty
// declaration is left behind at namespace scope.
#define DEFINE_GPU_TRANS(RANK)                                         \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>
360 361 362 363 364 365 366

// Instantiate the float and double GPU Transpose functors for ranks 1 to 6.
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);
Q
qijun 已提交
367

368 369
// Visitor that fills a tensor with a constant on the GPU.
//
// It is constructed with the device context, the target tensor and the fill
// value, and is then invoked as operator()<T>() with the tensor's element
// type T (see its use with framework::VisitDataType in
// set_constant_with_place<CUDAPlace>).
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
                       framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  // Runs SetConstant<CUDADeviceContext, T> on the stored tensor, converting
  // the stored float value to T.
  template <typename T>
  void operator()() const {
    SetConstant<platform::CUDADeviceContext, T> functor;
    // Downcast from the DeviceContext base with static_cast so the compiler
    // verifies the class relationship (reinterpret_cast would not).
    // Assumes context_ really refers to a CUDADeviceContext -- the caller is
    // expected to guarantee this.
    functor(static_cast<const platform::CUDADeviceContext&>(context_), tensor_,
            static_cast<T>(value_));
  }

  const platform::DeviceContext& context_;  // not owned; must outlive this
  framework::Tensor* tensor_;               // tensor to fill; not owned
  float value_;                             // fill value, cast to T on use
};

// CUDAPlace specialization: fills `tensor` with `value` by dispatching on the
// tensor's runtime data type to TensorSetConstantGPU.
template <>
void set_constant_with_place<platform::CUDAPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  TensorSetConstantGPU fill_visitor(context, tensor, value);
  framework::VisitDataType(framework::ToDataType(tensor->type()),
                           fill_visitor);
}

Q
qingqing01 已提交
393
// Adds the vector `b` to every row of `a`, writing the result to `c`:
//   c[i] = a[i] + b[i % width]   for i in [0, num)
//
// a, c : arrays of `num` elements, treated as rows of `width` elements.
// b    : array of `width` elements.
// Uses a grid-stride loop over a 1-D grid, so any grid/block size is valid.
// Requires width > 0.
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    // Use integer arithmetic for the column index. Deriving it from a
    // floating-point reciprocal (i * (1.0 / width)) can round just below an
    // integer, yielding w == width and an out-of-bounds read of b.
    int w = i % width;
    c[i] = a[i] + b[w];
  }
}

// GPU functor that adds `vector` to every row of `input`, storing the result
// in `output`.
//
// `input` is treated as in_dims[0] rows of `size = numel / in_dims[0]`
// elements. `vector` must hold exactly `size` elements and `output` must have
// the same dims as `input`; both are checked with PADDLE_ENFORCE_EQ.
// The kernel is launched asynchronously on context.stream().
template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);

    auto numel = input.numel();
    // Nothing to do for an empty tensor. This also avoids dividing by a zero
    // in_dims[0] and launching a kernel with a zero-sized grid.
    if (numel == 0) return;

    auto size = numel / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);

    const int threads = 512;
    const int grid = (numel + threads - 1) / threads;
    // The row width is `size` (the validated length of `vector`), so the
    // kernel indexes the vector consistently even when input has rank > 2.
    RowwiseAddKernel<T><<<grid, threads, 0, context.stream()>>>(
        input.data<T>(), vector.data<T>(), output->data<T>(),
        static_cast<int>(size), static_cast<int>(numel));
  }
};

Q
QI JUN 已提交
422 423 424
// Explicit instantiations of the GPU RowwiseAdd functor (float, double) and
// of ColwiseSum for float.
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
Y
yangyaming 已提交
425 426
// Explicit instantiations of the GPU ColwiseSum functor for integer types.
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
Q
QI JUN 已提交
427 428
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The generic ColwiseSum<platform::CUDADeviceContext, double> failed in debug
// mode, and only for this type, so it is reimplemented here on top of gemv.
//
// Computes the per-column sums of `input` into `vector` by multiplying the
// transposed input with a vector of ones. `vector` must hold
// numel / in_dims[0] elements (checked with PADDLE_ENFORCE_EQ).
template <>
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto dims = input.dims();
  auto per_row = input.numel() / dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), per_row);

  // Temporary device vector of ones, one entry per input row.
  framework::Tensor ones;
  ones.mutable_data<double>({dims[0]}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> fill;
  fill(context, &ones, static_cast<double>(1.0));

  const int rows = static_cast<int>(dims[0]);
  const int cols = static_cast<int>(dims[1]);
  // vector = input^T * ones
  gemv<platform::CUDADeviceContext, double>(
      context, true, rows, cols, 1.0, input.data<double>(),
      ones.data<double>(), 0.0, vector->data<double>());
}
446

C
chengduoZH 已提交
447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471
template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// Following the ColwiseSum<platform::CUDADeviceContext, double> approach, the
// double case is reimplemented here on top of gemv instead of being
// instantiated from the generic template.
//
// Computes the per-row sums of `input` into `vector`:
//   vector = input * ones
// where `input` is treated as in_dims[0] rows of `size = numel / in_dims[0]`
// elements and `ones` is a temporary device vector of `size` ones.
// `vector` must hold in_dims[0] elements (checked with PADDLE_ENFORCE_EQ).
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // The matrix operand must be `input` (in_dims[0] x size, not transposed)
  // and the vector operand must be `one`. Passing them the other way round
  // makes cuBLAS read in_dims[0] * in_dims[1] elements from the `size`-element
  // ones buffer and does not produce row sums.
  gemv<platform::CUDADeviceContext, double>(
      context, false, static_cast<int>(in_dims[0]), static_cast<int>(size),
      1.0, input.data<double>(), one.data<double>(), 0.0,
      vector->data<double>());
}

// Explicit instantiations of the GPU RowwiseMean functor (float, double).
template struct RowwiseMean<platform::CUDADeviceContext, float>;
template struct RowwiseMean<platform::CUDADeviceContext, double>;

Q
qijun 已提交
472 473 474
}  // namespace math
}  // namespace operators
}  // namespace paddle