diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 12ddd2146f1619bb78bcf03e73b3ac8cf6576044..d36e6e6a2c2c05e3e663c9ca68b4f2e938def06b 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -35,15 +35,15 @@ void gemm(const CBLAS_TRANSPOSE transA, PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> void gemm( const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, - const int lda, const double* B, const int ldb, const double beta, double* C, - const int ldc, platform::DeviceContext* context) { + const double* B, const double beta, double* C, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -54,7 +54,7 @@ void gemm( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <>