diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 44fd739fb1d161c6c7d6ab1cc611c59220280a4e..249ab6e9037c4d996bd0e94ba22723a8a4328c71 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/float16.h" @@ -161,7 +162,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int strideA, const int strideB) { + float16* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { PADDLE_THROW("float16 batched_gemm not supported on CPU"); } @@ -172,7 +174,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int strideA, const int strideB) { + float* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; @@ -194,7 +197,8 @@ void batched_gemm( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int strideA, const int strideB) { + double* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { int lda = (transA == CblasNoTrans) ? K : M; int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 9badf26c9bb80acad029be3d1b63377cef63d929..2aa819625e0f5213a6001908e715bcc73d4747c3 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" @@ -267,7 +268,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C, const int batchCount, const int strideA, const int strideB) { + float16* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -278,7 +280,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; const half h_alpha = static_cast(alpha); const half h_beta = static_cast(beta); @@ -303,7 +305,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, - float* C, const int batchCount, const int strideA, const int strideB) { + float* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -314,7 +317,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, @@ -329,7 +332,8 @@ void batched_gemm( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, const int batchCount, const int strideA, const int strideB) { + double* C, const int batchCount, const int64_t strideA, + const int64_t strideB) { #if CUDA_VERSION >= 8000 // Note that cublas follows fortran order, so the order is different from // the cblas convention. @@ -340,7 +344,7 @@ void batched_gemm( (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const int strideC = M * N; + const int64_t strideC = M * N; PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index cdbc7bfb37e83c6c2b696ba010277c9eec49f2a8..cdd02974722045457aacdfa517c147751185f332 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -26,7 +26,7 @@ limitations under the License. */ #ifndef LAPACK_FOUND extern "C" { -#include +#include // NOLINT int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, int* ipiv); int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, @@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #endif #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/tensor.h" @@ -78,8 +79,8 @@ template void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, - const T beta, T* C, const int batchCount, const int strideA, - const int strideB); + const T beta, T* C, const int batchCount, + const int64_t strideA, const int64_t strideB); template void gemv(const DeviceContext& context, const bool trans_a, const int M,