math_function.cu 19.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#define EIGEN_USE_GPU
Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
19
#include "paddle/fluid/platform/float16.h"
Q
qijun 已提交
20

Q
qijun 已提交
21 22 23 24
namespace paddle {
namespace operators {
namespace math {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
using float16 = paddle::platform::float16;

template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      h_A, lda, &h_beta, h_C, N));
}

Q
qijun 已提交
53
template <>
Q
QI JUN 已提交
54 55 56 57 58
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
Q
qijun 已提交
59 60
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
Q
qijun 已提交
61 62
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
Q
qijun 已提交
63
  cublasOperation_t cuTransA =
Q
qijun 已提交
64
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
Q
qijun 已提交
65
  cublasOperation_t cuTransB =
Q
qijun 已提交
66
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
Q
qijun 已提交
67

Q
qijun 已提交
68
  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
Q
QI JUN 已提交
69 70
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, N));
Q
qijun 已提交
71 72 73
}

template <>
Q
QI JUN 已提交
74 75 76 77 78
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
Q
qijun 已提交
79 80
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
Q
qijun 已提交
81 82
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
Q
qijun 已提交
83
  cublasOperation_t cuTransA =
Q
qijun 已提交
84
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
Q
qijun 已提交
85
  cublasOperation_t cuTransB =
Q
qijun 已提交
86
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
Q
qijun 已提交
87
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
Q
QI JUN 已提交
88 89
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, N));
Q
qijun 已提交
90 91
}

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
template <>
void gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const int lda, const float16* B,
    const int ldb, const float16 beta, float16* C, const int ldc) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;

  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      h_A, lda, &h_beta, h_C, ldc));
}

G
guosheng 已提交
114
template <>
Q
QI JUN 已提交
115 116 117 118 119
void gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K, const float alpha,
    const float* A, const int lda, const float* B, const int ldb,
    const float beta, float* C, const int ldc) {
G
guosheng 已提交
120 121 122 123 124
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasSgemm(
Q
QI JUN 已提交
125 126
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, ldc));
G
guosheng 已提交
127 128 129
}

template <>
Q
QI JUN 已提交
130 131 132 133 134
void gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool transA,
    const bool transB, const int M, const int N, const int K,
    const double alpha, const double* A, const int lda, const double* B,
    const int ldb, const double beta, double* C, const int ldc) {
G
guosheng 已提交
135 136 137 138 139
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
  PADDLE_ENFORCE(platform::dynload::cublasDgemm(
Q
QI JUN 已提交
140 141
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
      lda, &beta, C, ldc));
G
guosheng 已提交
142 143
}

144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
template <>
void matmul<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
    framework::Tensor* matrix_out, float16 beta) {
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
                 "Matrix must all be in CUDAPlace");

  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;

  gemm<platform::CUDADeviceContext, float16>(
      context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
}

Q
qijun 已提交
173
template <>
Q
QI JUN 已提交
174 175 176 177
void matmul<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, float alpha,
178
    framework::Tensor* matrix_out, float beta) {
Q
qijun 已提交
179 180 181 182 183 184 185 186 187
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
D
dzhwinter 已提交
188
                 "Matrix must all be in CUDAPlace");
Q
qijun 已提交
189

Q
qijun 已提交
190 191 192
  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
Q
qijun 已提交
193

Q
qijun 已提交
194 195
  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Q
qijun 已提交
196

Q
QI JUN 已提交
197
  gemm<platform::CUDADeviceContext, float>(
198 199
      context, transA, transB, M, N, K, alpha, matrix_a.data<float>(),
      matrix_b.data<float>(), beta, matrix_out->data<float>());
Q
qijun 已提交
200 201 202
}

template <>
Q
QI JUN 已提交
203 204 205 206
void matmul<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context,
    const framework::Tensor& matrix_a, bool trans_a,
    const framework::Tensor& matrix_b, bool trans_b, double alpha,
207
    framework::Tensor* matrix_out, double beta) {
Q
qijun 已提交
208 209 210 211 212 213 214 215 216
  auto dim_a = matrix_a.dims();
  auto dim_b = matrix_b.dims();
  auto dim_out = matrix_out->dims();
  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                 "The input and output of matmul be matrix");

  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
                     platform::is_gpu_place(matrix_b.place()) &&
                     platform::is_gpu_place(matrix_out->place()),
D
dzhwinter 已提交
217
                 "Matrix must all be in CUDAPlace");
Q
qijun 已提交
218

Q
qijun 已提交
219 220 221 222 223 224
  int M = dim_out[0];
  int N = dim_out[1];
  int K = (trans_a == false) ? dim_a[1] : dim_a[0];

  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
Q
qijun 已提交
225

Q
QI JUN 已提交
226
  gemm<platform::CUDADeviceContext, double>(
227 228
      context, transA, transB, M, N, K, alpha, matrix_a.data<double>(),
      matrix_b.data<double>(), beta, matrix_out->data<double>());
Q
qijun 已提交
229
}
Q
qijun 已提交
230

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
template <>
void batched_gemm<platform::CUDADeviceContext, float16>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C, const int batchCount, const int strideA, const int strideB) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

  const half h_alpha = static_cast<const half>(alpha);
  const half h_beta = static_cast<const half>(beta);
  const half* h_A = reinterpret_cast<const half*>(A);
  const half* h_B = reinterpret_cast<const half*>(B);
  half* h_C = reinterpret_cast<half*>(C);

  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
}

M
Markus Kliegl 已提交
259
template <>
Q
QI JUN 已提交
260 261
void batched_gemm<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
M
Markus Kliegl 已提交
262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

  PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
Q
QI JUN 已提交
277 278
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
M
Markus Kliegl 已提交
279 280 281
}

template <>
Q
QI JUN 已提交
282 283
void batched_gemm<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
M
Markus Kliegl 已提交
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  cublasOperation_t cuTransA =
      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  cublasOperation_t cuTransB =
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

  PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
Q
QI JUN 已提交
299 300
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
M
Markus Kliegl 已提交
301 302
}

303
template <>
Q
QI JUN 已提交
304 305 306 307
void gemv<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const float alpha, const float* A, const float* B,
    const float beta, float* C) {
308 309
  cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;

Q
QI JUN 已提交
310 311 312
  PADDLE_ENFORCE(platform::dynload::cublasSgemv(context.cublas_handle(),
                                                cuTransA, N, M, &alpha, A, N, B,
                                                1, &beta, C, 1));
313 314 315
}

template <>
Q
QI JUN 已提交
316 317 318 319
void gemv<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const bool trans_a, const int M,
    const int N, const double alpha, const double* A, const double* B,
    const double beta, double* C) {
320
  cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
Q
QI JUN 已提交
321 322 323
  PADDLE_ENFORCE(platform::dynload::cublasDgemv(context.cublas_handle(),
                                                cuTransA, N, M, &alpha, A, N, B,
                                                1, &beta, C, 1));
324 325
}

326
template <>
Q
QI JUN 已提交
327 328 329 330 331
void axpy<platform::CUDADeviceContext, float>(
    const platform::CUDADeviceContext& context, const int n, const float alpha,
    const float* x, float* y) {
  PADDLE_ENFORCE(platform::dynload::cublasSaxpy(context.cublas_handle(), n,
                                                &alpha, x, 1, y, 1));
332 333 334
}

template <>
Q
QI JUN 已提交
335 336 337 338 339
void axpy<platform::CUDADeviceContext, double>(
    const platform::CUDADeviceContext& context, const int n, const double alpha,
    const double* x, double* y) {
  PADDLE_ENFORCE(platform::dynload::cublasDaxpy(context.cublas_handle(), n,
                                                &alpha, x, 1, y, 1));
340 341
}

Q
QI JUN 已提交
342 343 344 345 346
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
347

Q
QI JUN 已提交
348 349 350
#define DEFINE_GPU_TRANS(RANK)                                         \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>;
351 352 353 354 355 356 357

DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);
Q
qijun 已提交
358

359 360
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
D
dangqingqing 已提交
361
                       framework::Tensor* tensor, float value)
362 363 364 365
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void operator()() const {
Q
QI JUN 已提交
366 367 368
    SetConstant<platform::CUDADeviceContext, T> functor;
    functor(reinterpret_cast<const platform::CUDADeviceContext&>(context_),
            tensor_, static_cast<T>(value_));
369 370 371 372 373 374 375 376
  }

  const platform::DeviceContext& context_;
  framework::Tensor* tensor_;
  float value_;
};

template <>
D
dzhwinter 已提交
377
void set_constant_with_place<platform::CUDAPlace>(
378 379 380
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  framework::VisitDataType(framework::ToDataType(tensor->type()),
381
                           TensorSetConstantGPU(context, tensor, value));
382 383
}

Q
qingqing01 已提交
384
template <typename T>
Q
qingqing01 已提交
385 386 387
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  T tmp = 1.0 / width;
Q
qingqing01 已提交
388 389
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
Q
qingqing01 已提交
390 391 392
    int h = i * tmp;
    int w = i - h * width;
    c[i] = a[i] + b[w];
Q
qingqing01 已提交
393 394 395 396 397 398 399 400 401
  }
}

template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
Q
qingqing01 已提交
402 403 404
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
Q
qingqing01 已提交
405 406 407
    int blocks = 512;
    int grids = (input.numel() + blocks - 1) / blocks;
    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
Q
qingqing01 已提交
408 409
        input.data<T>(), vector.data<T>(), output->data<T>(),
        static_cast<int>(in_dims[1]), static_cast<int>(input.numel()));
Q
qingqing01 已提交
410 411 412
  }
};

Q
QI JUN 已提交
413 414 415 416 417
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The ColwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
418 419
// and only failed for this case. So reimplemented it.
template <>
Q
QI JUN 已提交
420 421
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
422 423 424 425 426 427
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), size);
  framework::Tensor one;
  one.mutable_data<double>({in_dims[0]}, context.GetPlace());
Q
QI JUN 已提交
428
  SetConstant<platform::CUDADeviceContext, double> set;
429
  set(context, &one, static_cast<double>(1.0));
Q
QI JUN 已提交
430 431 432 433
  gemv<platform::CUDADeviceContext, double>(
      context, true, static_cast<int>(in_dims[0]), static_cast<int>(in_dims[1]),
      1.0, input.data<double>(), one.data<double>(), 0.0,
      vector->data<double>());
434
}
435

C
chengduoZH 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// TODO(zcd): Following ColwiseSum format, need to confirm.
// The RowwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  gemv<platform::CUDADeviceContext, double>(
      context, true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]),
      1.0, one.data<double>(), input.data<double>(), 0.0,
      vector->data<double>());
}

template struct RowwiseMean<platform::CUDADeviceContext, float>;
template struct RowwiseMean<platform::CUDADeviceContext, double>;

Q
qijun 已提交
461 462 463
}  // namespace math
}  // namespace operators
}  // namespace paddle