From 8b7d48bc0ef4ee029f8cea087500624cf4dc01c1 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 06:47:56 +0000 Subject: [PATCH] fix gpu build error --- paddle/operators/math/math_function.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 12ddd2146..d36e6e6a2 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -35,15 +35,15 @@ void gemm(const CBLAS_TRANSPOSE transA, PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> void gemm( const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, - const int lda, const double* B, const int ldb, const double beta, double* C, - const int ldc, platform::DeviceContext* context) { + const double* B, const double beta, double* C, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -54,7 +54,7 @@ void gemm( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> -- GitLab