From 52b52ba80cc1ddd47ed6c4e1a89d747f13fec283 Mon Sep 17 00:00:00 2001
From: qijun
Date: Thu, 10 Aug 2017 14:50:02 +0800
Subject: [PATCH] fix gpu build error

---
 paddle/operators/math/math_function.cu | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 12ddd2146f1..50fc9939b1a 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -35,15 +35,17 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
-      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
 template <>
-void gemm<platform::GPUPlace, double>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const double alpha, const double* A,
-    const int lda, const double* B, const int ldb, const double beta, double* C,
-    const int ldc, platform::DeviceContext* context) {
+void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB, const int M,
+                                      const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const double* B, const double beta,
+                                      double* C,
+                                      platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -54,7 +56,7 @@ void gemm<platform::GPUPlace, double>(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
-      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
+      cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
 }
 
 template <>
--
GitLab
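
Editor's note (not part of the patch): the in-code comment about cuBLAS using Fortran (column-major) order is the reason the patched code passes B before A and uses N as the leading dimension of C; with the new signature there is no separate ldc parameter, so the row-major output stride is simply N. The standalone sketch below illustrates that same swapped-operand trick against the plain cuBLAS API, assuming the patch's NoTrans/NoTrans case. The helper name row_major_sgemm, the matrix sizes, and the omission of error checking are all made up for this example and are not PaddlePaddle code. Build with something like: nvcc example.cu -lcublas

// Illustrative sketch only: row-major C[MxN] = A[MxK] * B[KxN] via the
// column-major cublasSgemm, mirroring the operand swap used in the patch.
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

// cuBLAS sees our row-major buffers as their column-major transposes, so we
// request C^T = B^T * A^T: pass B as the first operand (leading dim N) and A
// as the second (leading dim K); the output's leading dimension is then N,
// which is why the patch replaces `ldc` with `N`.
void row_major_sgemm(cublasHandle_t handle, int M, int N, int K,
                     const float* A, const float* B, float* C,
                     float alpha = 1.0f, float beta = 0.0f) {
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
              B, N,   // first operand: B^T in column-major view, ld = N
              A, K,   // second operand: A^T in column-major view, ld = K
              &beta,
              C, N);  // result: C^T in column-major view, ld = N
}

int main() {
  const int M = 2, N = 3, K = 4;
  std::vector<float> hA(M * K, 1.0f), hB(K * N, 2.0f), hC(M * N, 0.0f);

  float *dA, *dB, *dC;
  cudaMalloc(reinterpret_cast<void**>(&dA), hA.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&dB), hB.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&dC), hC.size() * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), hC.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  row_major_sgemm(handle, M, N, K, dA, dB, dC);
  cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);

  // A is all ones and B all twos, so every element of C should be 2 * K = 8.
  printf("C[0][0] = %f\n", hC[0]);

  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}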