Unverified commit 64babc9a authored by Xin Pan, committed by GitHub

Merge pull request #10189 from reyoung/feature/fix_matmul_bug

Fix batch_gemm bugs
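
The bug: `batched_gemm` took its batch strides as `int`, so the per-batch element offset `k * strideA` was computed in 32-bit arithmetic and could wrap past `INT_MAX` before being used for pointer arithmetic. This change widens `strideA`/`strideB` (and the internal `strideC`) to `int64_t`. A minimal sketch of the failure mode, with hypothetical sizes not taken from the diff:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape: 2048 x 2048 output matrices, batch index 600.
  const int M = 2048, N = 2048;
  const int k = 600;

  const int strideC32 = M * N;  // 4'194'304 elements per matrix, fits in int
  // 600 * 4'194'304 = 2'516'582'400 > INT_MAX: signed overflow (undefined
  // behavior; in practice a wrapped, negative offset) -- the old bug.
  const int64_t offset32 = k * strideC32;

  const int64_t strideC64 = static_cast<int64_t>(M) * N;
  const int64_t offset64 = k * strideC64;  // multiply done in 64-bit, correct

  std::printf("32-bit: %lld  64-bit: %lld\n", (long long)offset32,
              (long long)offset64);
  return 0;
}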
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/math_function.h"
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
 #include "paddle/fluid/platform/float16.h"
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   PADDLE_THROW("float16 batched_gemm not supported on CPU");
 }
 
@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -220,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const float* Ak = &A[k * strideA];
     const float* Bk = &B[k * strideB];
@@ -235,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const double* Ak = &A[k * strideA];
     const double* Bk = &B[k * strideB];
...
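
For the non-MKL CPU fallback above, the fix matters in the per-batch pointer arithmetic: since `strideA`/`strideB` are now `int64_t`, the `int` loop index `k` is promoted before the multiply. A self-contained sketch of that loop shape, with a naive per-batch GEMM standing in for the real `gemm<...>` call (illustration only, not PaddlePaddle code):

#include <cstdint>
#include <vector>

// Naive per-batch GEMM standing in for the real gemm<...> call
// (row-major A: MxK, B: KxN, C: MxN).
static void naive_gemm(int M, int N, int K, const float* A, const float* B,
                       float* C) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      C[i * N + j] = acc;
    }
}

// Same loop shape as the CPU fallback: because the strides are int64_t,
// k is promoted before the multiply, so the offset cannot wrap past INT_MAX.
void batched_gemm_loop(int M, int N, int K, const float* A, const float* B,
                       float* C, int batchCount, int64_t strideA,
                       int64_t strideB) {
  const int64_t strideC = static_cast<int64_t>(M) * N;
  for (int k = 0; k < batchCount; ++k) {
    naive_gemm(M, N, K, &A[k * strideA], &B[k * strideB], &C[k * strideC]);
  }
}

int main() {
  const int M = 2, N = 2, K = 2, batch = 3;
  std::vector<float> A(batch * M * K, 1.f), B(batch * K * N, 1.f),
      C(batch * M * N, 0.f);
  batched_gemm_loop(M, N, K, A.data(), B.data(), C.data(), batch,
                    int64_t{M} * K, int64_t{K} * N);
  return 0;  // each C entry is now K = 2
}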
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   const half h_alpha = static_cast<const half>(alpha);
   const half h_beta = static_cast<const half>(beta);
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
...
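
The `int64_t` types also line up with cuBLAS: from CUDA 8.0, `cublasSgemmStridedBatched` takes its strides as `long long int`, so with 32-bit strides the overflow would already have happened on the caller side before the argument is widened. A minimal sketch of a direct strided-batched call using the same row-major-via-column-major trick the diff's comment mentions (hypothetical wrapper, error handling reduced to an assert):

#include <cassert>
#include <cublas_v2.h>

// Computes C[i] = A[i] * B[i] for batchCount row-major M x K and K x N
// matrices. cuBLAS is column-major, so we compute C^T = B^T * A^T by
// swapping the operands, as the diff's kernels do.
void sgemm_strided_batched(cublasHandle_t handle, int M, int N, int K,
                           const float* dA, const float* dB, float* dC,
                           int batchCount) {
  const float alpha = 1.0f, beta = 0.0f;
  // Strides formed in 64-bit to match the long long parameters.
  const long long strideA = static_cast<long long>(M) * K;
  const long long strideB = static_cast<long long>(K) * N;
  const long long strideC = static_cast<long long>(M) * N;
  cublasStatus_t st = cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
      dB, N, strideB,   // B first, leading dimension N
      dA, K, strideA,   // then A, leading dimension K
      &beta, dC, N, strideC, batchCount);
  assert(st == CUBLAS_STATUS_SUCCESS);
}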
@@ -26,7 +26,7 @@ limitations under the License. */
 #ifndef LAPACK_FOUND
 extern "C" {
-#include <cblas.h>
+#include <cblas.h>  // NOLINT
 int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
                    int* ipiv);
 int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
                    int* ipiv);
@@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
 #endif
 
 #include <cmath>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -78,8 +79,8 @@ template <typename DeviceContext, typename T>
 void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
                   const CBLAS_TRANSPOSE transB, const int M, const int N,
                   const int K, const T alpha, const T* A, const T* B,
-                  const T beta, T* C, const int batchCount, const int strideA,
-                  const int strideB);
+                  const T beta, T* C, const int batchCount,
+                  const int64_t strideA, const int64_t strideB);
 
 template <typename DeviceContext, typename T>
 void gemv(const DeviceContext& context, const bool trans_a, const int M,
...
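
One caller-side caveat with the new declaration: the `int64_t` parameters only help if the stride is computed in 64-bit to begin with; an `int * int` product overflows before it is widened at the call boundary. A hypothetical snippet of the distinction:

#include <cstdint>

// Hypothetical extreme shape where a single matrix holds more than
// INT_MAX elements, so the stride itself needs 64 bits.
const int M = 65536, K = 65536;

// Wrong: the product is computed as int and overflows *before* it is
// converted to the int64_t parameter.
// const int64_t strideA = M * K;

// Right: widen one operand first so the multiply happens in 64-bit.
const int64_t strideA = static_cast<int64_t>(M) * K;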