Unverified commit d5a0d31a, authored by Leo Chen, committed by GitHub

[bf16] pten matmul cuda kernel support bf16 (#39485)

* pten matmul cuda kernel support bf16

* fix pten kernel name

* add matmul_grad bf16 kernel

* add emptylike bf16 kernel

* fix compile

* support rocm

* fix error

* fix rocm

* add bf16 header file

* fix compile
Parent f31c2426
...@@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(op_proto_->type(), GetInputArgsNames(), return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
GetAttrsArgsNames(), GetOutputArgsNames()); GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
} }
std::once_flag kernel_sig_map_init_flag; std::once_flag kernel_sig_map_init_flag;
......
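For readers skimming the hunk above: the KernelSignature is now built from pten::TransToPtenKernelName(op_proto_->type()) rather than the raw fluid op type, so ops such as matmul_v2 resolve to the pten kernel name. Below is a minimal, hypothetical C++ sketch of that kind of mapping; the alias table is an illustrative assumption, not Paddle's actual implementation.

// Hypothetical stand-in for pten::TransToPtenKernelName, for illustration only.
#include <iostream>
#include <string>
#include <unordered_map>

std::string TransToPtenKernelNameSketch(const std::string& op_type) {
  // Assumed alias table: fluid op type -> pten kernel name.
  static const std::unordered_map<std::string, std::string> kAlias = {
      {"matmul_v2", "matmul"},            // assumption for illustration
      {"matmul_v2_grad", "matmul_grad"},  // assumption for illustration
  };
  auto it = kAlias.find(op_type);
  return it == kAlias.end() ? op_type : it->second;
}

int main() {
  std::cout << TransToPtenKernelNameSketch("matmul_v2") << "\n";  // matmul
  std::cout << TransToPtenKernelNameSketch("relu") << "\n";       // relu (unchanged)
  return 0;
}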
...@@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 80 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80, "
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
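As a reference for the call pattern used in the two specializations above (bf16 storage with fp32 accumulation, operands passed as (B, A) with M and N swapped because cuBLAS is column-major), here is a minimal standalone sketch, not Paddle code; it assumes CUDA 11+ and a compute-capability-80 GPU, and omits error checking for brevity.

// Standalone bf16 GEMM via cublasGemmEx: row-major C(MxN) = A(MxK) * B(KxN).
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 3, K = 4;
  std::vector<__nv_bfloat16> hA(M * K, __float2bfloat16(1.0f));
  std::vector<__nv_bfloat16> hB(K * N, __float2bfloat16(2.0f));
  std::vector<__nv_bfloat16> hC(M * N);

  __nv_bfloat16 *dA, *dB, *dC;
  cudaMalloc(&dA, sizeof(__nv_bfloat16) * M * K);
  cudaMalloc(&dB, sizeof(__nv_bfloat16) * K * N);
  cudaMalloc(&dC, sizeof(__nv_bfloat16) * M * N);
  cudaMemcpy(dA, hA.data(), sizeof(__nv_bfloat16) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), sizeof(__nv_bfloat16) * K * N, cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Row-major C = A * B is computed as column-major C^T = B^T * A^T, hence the
  // (B, A) operand order and the (N, M, K) dimension order, as in the code above.
  float alpha = 1.0f, beta = 0.0f;
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
               dB, CUDA_R_16BF, N,         // ldb = N for non-transposed B
               dA, CUDA_R_16BF, K,         // lda = K for non-transposed A
               &beta, dC, CUDA_R_16BF, N,  // ldc = N
               CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

  cudaMemcpy(hC.data(), dC, sizeof(__nv_bfloat16) * M * N, cudaMemcpyDeviceToHost);
  printf("C[0] = %f (expected %d)\n", __bfloat162float(hC[0]), 2 * K);

  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}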
template <> template <>
template <> template <>
inline void Blas<platform::CUDADeviceContext>::GEMM( inline void Blas<platform::CUDADeviceContext>::GEMM(
...@@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N, ...@@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
} }
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat16 gemv, we use the bfloat16 GEMM
// (cublasGemmEx) to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat16 gemv, we use the bfloat16 GEMM
// (cublasGemmEx) to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
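The two GEMV specializations above lean on the fact that a matrix-vector product is simply a GEMM whose output has one dimension equal to 1. Here is a CPU-only sketch (plain C++, illustrative names, no cuBLAS) of the mapping used in the non-transposed branch.

// y(M) = A(MxN) * x(N) expressed as GEMM(M, 1, N) on row-major data.
#include <cstdio>
#include <vector>

// Row-major reference GEMM: C(MxN) = A(MxK) * B(KxN).
static void RefGemm(int M, int N, int K, const float* A, const float* B, float* C) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = acc;
    }
}

int main() {
  const int M = 3, N = 4;
  std::vector<float> A(M * N), x(N), y(M);
  for (int i = 0; i < M * N; ++i) A[i] = 0.5f * i;
  for (int j = 0; j < N; ++j) x[j] = 1.0f + j;
  // GEMV as GEMM: the vector x is treated as an Nx1 matrix.
  RefGemm(M, /*N=*/1, /*K=*/N, A.data(), x.data(), y.data());
  for (int i = 0; i < M; ++i) printf("y[%d] = %g\n", i, y[i]);
  return 0;
}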
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM( void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...@@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 9010 #endif // CUDA_VERSION >= 9010
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}
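The strided-batched specializations above assume the batch of matrices is laid out contiguously, with strideC fixed to M*N for the output. Below is a CPU-only reference (illustrative, not Paddle code) of the semantics cublasGemmStridedBatchedEx provides: one GEMM per batch entry, each operand offset by its stride.

#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major reference GEMM: C(MxN) = A(MxK) * B(KxN), beta = 0 for brevity.
static void RefGemm(int M, int N, int K, const float* A, const float* B, float* C) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = acc;
    }
}

int main() {
  const int M = 2, N = 3, K = 4, batchCount = 5;
  const int64_t strideA = int64_t(M) * K;  // distance to the next A matrix
  const int64_t strideB = int64_t(K) * N;  // distance to the next B matrix
  const int64_t strideC = int64_t(M) * N;  // matches the fixed M * N above
  std::vector<float> A(batchCount * strideA, 1.f);
  std::vector<float> B(batchCount * strideB, 2.f);
  std::vector<float> C(batchCount * strideC, 0.f);
  for (int b = 0; b < batchCount; ++b)
    RefGemm(M, N, K, A.data() + b * strideA, B.data() + b * strideB,
            C.data() + b * strideC);
  printf("C.back() = %g (expected %d)\n", C.back(), 2 * K);
  return 0;
}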
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM( void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...@@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM( ...@@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
} }
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
......
...@@ -550,6 +550,84 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -550,6 +550,84 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
rocblas_datatype_f16_r, N, rocblas_datatype_f32_r); rocblas_datatype_f16_r, N, rocblas_datatype_f32_r);
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(zhiqiu): does compute capability 80 have the same meaning for rocm as for cuda?
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"rocblas bf16 gemm requires GPU compute capability >= 80, "
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta,
C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(zhiqiu): does compute capability 80 have the same meaning for rocm as for cuda?
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"rocblas bf16 gemm requires GPU compute capability >= 80, "
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta,
C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <> template <>
template <> template <>
inline void Blas<platform::CUDADeviceContext>::GEMM( inline void Blas<platform::CUDADeviceContext>::GEMM(
...@@ -874,6 +952,39 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N, ...@@ -874,6 +952,39 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
} }
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM( void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...@@ -898,6 +1009,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -898,6 +1009,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
ldc, strideC, batchCount); ldc, strideC, batchCount);
}); });
} }
template <> template <>
template <typename T> template <typename T>
void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
...@@ -925,6 +1037,70 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -925,6 +1037,70 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
}); });
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
const int64_t strideC = M * N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::rocblas_gemm_strided_batched_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r,
lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C,
rocblas_datatype_bf16_r, ldc, strideC, batchCount,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
const int64_t strideC = M * N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::rocblas_gemm_strided_batched_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r,
lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C,
rocblas_datatype_bf16_r, ldc, strideC, batchCount,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM( void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...@@ -935,6 +1111,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -935,6 +1111,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
C[k]); C[k]);
} }
} }
template <> template <>
template <typename T> template <typename T>
void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
...@@ -973,6 +1150,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM( ...@@ -973,6 +1150,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
} }
} }
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
......
...@@ -16,8 +16,10 @@ limitations under the License. */ ...@@ -16,8 +16,10 @@ limitations under the License. */
// NOTE(): support float16 to half in header file. // NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16 #define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/enforce.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -61,6 +63,19 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, ...@@ -61,6 +63,19 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
static_cast<unsigned>(delta), width)); static_cast<unsigned>(delta), width));
} }
template <>
__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask,
bfloat16 val, int delta,
int width) {
#if defined(PADDLE_CUDA_BF16)
return bfloat16(__shfl_down_sync(mask, static_cast<nv_bfloat16>(val),
static_cast<unsigned>(delta), width));
#else
PADDLE_ENFORCE(
false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
#endif
}
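For context on how a shuffle-down specialization like the one above is typically consumed, here is a minimal, hypothetical warp-sum kernel (standalone CUDA, not Paddle code). It widens each bf16 lane to fp32 and shuffles the fp32 value, so it does not depend on the native bf16 shuffle overload guarded by PADDLE_CUDA_BF16 above.

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void WarpSumBf16(const __nv_bfloat16* in, float* out) {
  // Each of the 32 lanes loads one bf16 value and widens it to fp32.
  float v = __bfloat162float(in[threadIdx.x]);
  // Classic warp tree reduction via shuffle-down.
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffffu, v, offset);
  if (threadIdx.x == 0) *out = v;
}

int main() {
  __nv_bfloat16 h_in[32];
  for (int i = 0; i < 32; ++i) h_in[i] = __float2bfloat16(1.0f);
  __nv_bfloat16* d_in;
  float* d_out;
  float h_out = 0.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSumBf16<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f (expected 32)\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}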
template <> template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync( __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) { unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
// NOTE(): support float16 to half in header file. // NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16 #define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -59,6 +60,14 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, ...@@ -59,6 +60,14 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
static_cast<unsigned>(delta), width)); static_cast<unsigned>(delta), width));
} }
template <>
__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask,
bfloat16 val, int delta,
int width) {
return bfloat16(__shfl_down(static_cast<float>(val),
static_cast<unsigned>(delta), width));
}
template <> template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync( __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) { unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
......
...@@ -94,6 +94,7 @@ PT_REGISTER_KERNEL(empty_like, ...@@ -94,6 +94,7 @@ PT_REGISTER_KERNEL(empty_like,
int64_t, int64_t,
bool, bool,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
#endif #endif
...@@ -26,6 +26,7 @@ PT_REGISTER_KERNEL(matmul_grad, ...@@ -26,6 +26,7 @@ PT_REGISTER_KERNEL(matmul_grad,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
......
...@@ -27,5 +27,6 @@ PT_REGISTER_KERNEL(matmul, ...@@ -27,5 +27,6 @@ PT_REGISTER_KERNEL(matmul,
float, float,
double, double,
paddle::platform::float16, paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
...@@ -71,7 +71,7 @@ class TestMatMulV2Op(OpTest): ...@@ -71,7 +71,7 @@ class TestMatMulV2Op(OpTest):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op): class TestMatMulOp2(TestMatMulV2Op):
""" """
case 2 case 2
""" """
...@@ -83,7 +83,7 @@ class TestMatMuklOp2(TestMatMulV2Op): ...@@ -83,7 +83,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp3(TestMatMulV2Op): class TestMatMulOp3(TestMatMulV2Op):
""" """
case 3 case 3
""" """
...@@ -95,7 +95,7 @@ class TestMatMuklOp3(TestMatMulV2Op): ...@@ -95,7 +95,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op): class TestMatMulOp4(TestMatMulV2Op):
""" """
case 4 case 4
""" """
...@@ -107,7 +107,7 @@ class TestMatMuklOp4(TestMatMulV2Op): ...@@ -107,7 +107,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op): class TestMatMulOp5(TestMatMulV2Op):
""" """
case 5 case 5
""" """
...@@ -119,7 +119,7 @@ class TestMatMuklOp5(TestMatMulV2Op): ...@@ -119,7 +119,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op): class TestMatMulOp6(TestMatMulV2Op):
""" """
case 6 case 6
""" """
...@@ -131,7 +131,7 @@ class TestMatMuklOp6(TestMatMulV2Op): ...@@ -131,7 +131,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op): class TestMatMulOp7(TestMatMulV2Op):
""" """
case 7 case 7
""" """
...@@ -143,7 +143,7 @@ class TestMatMuklOp7(TestMatMulV2Op): ...@@ -143,7 +143,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op): class TestMatMulOp8(TestMatMulV2Op):
""" """
case 8 case 8
""" """
...@@ -155,7 +155,7 @@ class TestMatMuklOp8(TestMatMulV2Op): ...@@ -155,7 +155,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op): class TestMatMulOp9(TestMatMulV2Op):
""" """
case 9 case 9
""" """
...@@ -167,7 +167,7 @@ class TestMatMuklOp9(TestMatMulV2Op): ...@@ -167,7 +167,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op): class TestMatMulOp10(TestMatMulV2Op):
""" """
case 10 case 10
""" """
...@@ -179,7 +179,7 @@ class TestMatMuklOp10(TestMatMulV2Op): ...@@ -179,7 +179,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op): class TestMatMulOp11(TestMatMulV2Op):
""" """
case 11 case 11
""" """
...@@ -191,7 +191,7 @@ class TestMatMuklOp11(TestMatMulV2Op): ...@@ -191,7 +191,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op): class TestMatMulOp12(TestMatMulV2Op):
""" """
case 12 case 12
""" """
...@@ -203,7 +203,7 @@ class TestMatMuklOp12(TestMatMulV2Op): ...@@ -203,7 +203,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op): class TestMatMulOp13(TestMatMulV2Op):
""" """
case 13 case 13
""" """
...@@ -215,7 +215,7 @@ class TestMatMuklOp13(TestMatMulV2Op): ...@@ -215,7 +215,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op): class TestMatMulOp14(TestMatMulV2Op):
""" """
case 14_1 case 14_1
""" """
...@@ -227,7 +227,7 @@ class TestMatMuklOp14(TestMatMulV2Op): ...@@ -227,7 +227,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op): class TestMatMulOp15(TestMatMulV2Op):
""" """
case 14_2 case 14_2
""" """
...@@ -239,7 +239,7 @@ class TestMatMuklOp15(TestMatMulV2Op): ...@@ -239,7 +239,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp16(TestMatMulV2Op): class TestMatMulOp16(TestMatMulV2Op):
""" """
case 16 : to check the gradient for special case case 16 : to check the gradient for special case
""" """
...@@ -251,7 +251,7 @@ class TestMatMuklOp16(TestMatMulV2Op): ...@@ -251,7 +251,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op): class TestMatMulOp17(TestMatMulV2Op):
""" """
case 17 : to check the gradient for special case case 17 : to check the gradient for special case
""" """
...@@ -263,7 +263,7 @@ class TestMatMuklOp17(TestMatMulV2Op): ...@@ -263,7 +263,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op): class TestMatMulOpBroadcast1(TestMatMulV2Op):
""" """
case 14_3 case 14_3
""" """
...@@ -275,7 +275,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op): ...@@ -275,7 +275,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op): class TestMatMulOpBroadcast2(TestMatMulV2Op):
""" """
case 14_4 case 14_4
""" """
...@@ -310,22 +310,22 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5): ...@@ -310,22 +310,22 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
create_test_fp16_class(TestMatMulV2Op) create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2) create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMuklOp3) create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMuklOp4) create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMuklOp5) create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMuklOp6) create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMuklOp7) create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMuklOp8) create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMuklOp9) create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMuklOp10) create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMuklOp11) create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMuklOp12) create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMuklOp13) create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMuklOp14) create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMuklOp15) create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMuklOp16) create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMuklOp17) create_test_fp16_class(TestMatMulOp17)
class TestMatMulV2API(unittest.TestCase): class TestMatMulV2API(unittest.TestCase):
......
...@@ -1658,7 +1658,7 @@ class OpTest(unittest.TestCase): ...@@ -1658,7 +1658,7 @@ class OpTest(unittest.TestCase):
for grad in analytic_grads: for grad in analytic_grads:
if grad.dtype == np.uint16: if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad) grad = convert_uint16_to_float(grad)
max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error
fp32_analytic_grads.append(grad) fp32_analytic_grads.append(grad)
analytic_grads = fp32_analytic_grads analytic_grads = fp32_analytic_grads
...@@ -1666,7 +1666,7 @@ class OpTest(unittest.TestCase): ...@@ -1666,7 +1666,7 @@ class OpTest(unittest.TestCase):
for grad in numeric_grads: for grad in numeric_grads:
if grad.dtype == np.uint16: if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad) grad = convert_uint16_to_float(grad)
max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error
fp32_numeric_grads.append(grad) fp32_numeric_grads.append(grad)
numeric_grads = fp32_numeric_grads numeric_grads = fp32_numeric_grads
......
...@@ -16,7 +16,8 @@ from __future__ import print_function ...@@ -16,7 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from paddle.fluid.tests.unittests.testsuite import create_op
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle import paddle
...@@ -73,17 +74,32 @@ class TestMatMulV2Op(OpTest): ...@@ -73,17 +74,32 @@ class TestMatMulV2Op(OpTest):
self.init_kernel_type() self.init_kernel_type()
self.config() self.config()
self.op_type = "matmul_v2" self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype) if self.is_bfloat16_op():
y = np.random.random(self.y_shape).astype(self.dtype) x = np.random.random(self.x_shape).astype(np.float32)
# -0.1 ~ 0.1 y = np.random.random(self.y_shape).astype(np.float32)
x = -0.1 + 0.2 * x else:
y = -0.1 + 0.2 * y x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
x = -0.1 + 0.2 * x
y = -0.1 + 0.2 * y
result = reference_matmul(x, y, self.trans_x, self.trans_y) result = reference_matmul(x, y, self.trans_x, self.trans_y)
result = result.astype(self.dtype) if self.is_bfloat16_op():
self.inputs = { result = result.astype(np.float32)
'X': x, self.inputs = {
'Y': y, 'X': convert_float_to_uint16(x),
} 'Y': convert_float_to_uint16(y),
}
self.inputs_fp32 = {
'X': x,
'Y': y,
}
else:
result = result.astype(self.dtype)
self.inputs = {
'X': x,
'Y': y,
}
self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y} self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
self.outputs = {'Out': result} self.outputs = {'Out': result}
...@@ -97,7 +113,7 @@ class TestMatMulV2Op(OpTest): ...@@ -97,7 +113,7 @@ class TestMatMulV2Op(OpTest):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op): class TestMatMulOp2(TestMatMulV2Op):
""" """
case 2 case 2
""" """
...@@ -109,7 +125,7 @@ class TestMatMuklOp2(TestMatMulV2Op): ...@@ -109,7 +125,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp3(TestMatMulV2Op): class TestMatMulOp3(TestMatMulV2Op):
""" """
case 3 case 3
""" """
...@@ -121,7 +137,7 @@ class TestMatMuklOp3(TestMatMulV2Op): ...@@ -121,7 +137,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op): class TestMatMulOp4(TestMatMulV2Op):
""" """
case 4 case 4
""" """
...@@ -133,7 +149,7 @@ class TestMatMuklOp4(TestMatMulV2Op): ...@@ -133,7 +149,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op): class TestMatMulOp5(TestMatMulV2Op):
""" """
case 5 case 5
""" """
...@@ -145,7 +161,7 @@ class TestMatMuklOp5(TestMatMulV2Op): ...@@ -145,7 +161,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op): class TestMatMulOp6(TestMatMulV2Op):
""" """
case 6 case 6
""" """
...@@ -157,7 +173,7 @@ class TestMatMuklOp6(TestMatMulV2Op): ...@@ -157,7 +173,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op): class TestMatMulOp7(TestMatMulV2Op):
""" """
case 7 case 7
""" """
...@@ -169,7 +185,7 @@ class TestMatMuklOp7(TestMatMulV2Op): ...@@ -169,7 +185,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op): class TestMatMulOp8(TestMatMulV2Op):
""" """
case 8 case 8
""" """
...@@ -181,7 +197,7 @@ class TestMatMuklOp8(TestMatMulV2Op): ...@@ -181,7 +197,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op): class TestMatMulOp9(TestMatMulV2Op):
""" """
case 9 case 9
""" """
...@@ -193,7 +209,7 @@ class TestMatMuklOp9(TestMatMulV2Op): ...@@ -193,7 +209,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op): class TestMatMulOp10(TestMatMulV2Op):
""" """
case 10 case 10
""" """
...@@ -205,7 +221,7 @@ class TestMatMuklOp10(TestMatMulV2Op): ...@@ -205,7 +221,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op): class TestMatMulOp11(TestMatMulV2Op):
""" """
case 11 case 11
""" """
...@@ -217,7 +233,7 @@ class TestMatMuklOp11(TestMatMulV2Op): ...@@ -217,7 +233,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op): class TestMatMulOp12(TestMatMulV2Op):
""" """
case 12 case 12
""" """
...@@ -229,7 +245,7 @@ class TestMatMuklOp12(TestMatMulV2Op): ...@@ -229,7 +245,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op): class TestMatMulOp13(TestMatMulV2Op):
""" """
case 13 case 13
""" """
...@@ -241,7 +257,7 @@ class TestMatMuklOp13(TestMatMulV2Op): ...@@ -241,7 +257,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op): class TestMatMulOp14(TestMatMulV2Op):
""" """
case 14_1 case 14_1
""" """
...@@ -253,7 +269,7 @@ class TestMatMuklOp14(TestMatMulV2Op): ...@@ -253,7 +269,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op): class TestMatMulOp15(TestMatMulV2Op):
""" """
case 14_2 case 14_2
""" """
...@@ -265,7 +281,7 @@ class TestMatMuklOp15(TestMatMulV2Op): ...@@ -265,7 +281,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp16(TestMatMulV2Op): class TestMatMulOp16(TestMatMulV2Op):
""" """
case 16 : to check the gradient for special case case 16 : to check the gradient for special case
""" """
...@@ -277,7 +293,7 @@ class TestMatMuklOp16(TestMatMulV2Op): ...@@ -277,7 +293,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op): class TestMatMulOp17(TestMatMulV2Op):
""" """
case 17 : to check the gradient for special case case 17 : to check the gradient for special case
""" """
...@@ -289,7 +305,7 @@ class TestMatMuklOp17(TestMatMulV2Op): ...@@ -289,7 +305,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op): class TestMatMulOpBroadcast1(TestMatMulV2Op):
""" """
case 14_3 case 14_3
""" """
...@@ -301,7 +317,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op): ...@@ -301,7 +317,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op): class TestMatMulOpBroadcast2(TestMatMulV2Op):
""" """
case 14_4 case 14_4
""" """
...@@ -343,22 +359,90 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0): ...@@ -343,22 +359,90 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
create_test_fp16_class(TestMatMulV2Op) create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2) create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMuklOp3) create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMuklOp4) create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMuklOp5) create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMuklOp6) create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMuklOp7) create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMuklOp8) create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMuklOp9) create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMuklOp10) create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMuklOp11) create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMuklOp12) create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMuklOp13) create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMuklOp14) create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMuklOp15) create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMuklOp16) create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMuklOp17) create_test_fp16_class(TestMatMulOp17)
#--------------------test matmul bf16--------------------
def create_test_bf16_class(parent, atol=0.01):
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0"
)
class TestMatMulOpBf16Case(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(scope, self.op_type, self.inputs, self.outputs,
self.attrs)
return get_numeric_gradient(place, scope, op, self.inputs_fp32,
check_name, ['Out'])
def init_kernel_type(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad_x(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'X')
self.check_grad_with_place(
place, ['X'],
'Out',
no_grad_set=set(['Y']),
user_defined_grads=[numeric_grads])
def test_check_grad_y(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Y')
self.check_grad_with_place(
place, ['Y'],
'Out',
no_grad_set=set(['X']),
user_defined_grads=[numeric_grads])
def test_check_grad(self):
pass
cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
TestMatMulOpBf16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpBf16Case
create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)
class TestMatMulV2API(unittest.TestCase): class TestMatMulV2API(unittest.TestCase):
......
...@@ -97,7 +97,7 @@ class TestMatMulV2Op(XPUOpTest): ...@@ -97,7 +97,7 @@ class TestMatMulV2Op(XPUOpTest):
self.check_grad_with_place(place, ['X', 'Y'], 'Out') self.check_grad_with_place(place, ['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op): class TestMatMulOp2(TestMatMulV2Op):
""" """
case 2 case 2
""" """
...@@ -109,7 +109,7 @@ class TestMatMuklOp2(TestMatMulV2Op): ...@@ -109,7 +109,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp3(TestMatMulV2Op): class TestMatMulOp3(TestMatMulV2Op):
""" """
case 3 case 3
""" """
...@@ -121,7 +121,7 @@ class TestMatMuklOp3(TestMatMulV2Op): ...@@ -121,7 +121,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op): class TestMatMulOp4(TestMatMulV2Op):
""" """
case 4 case 4
""" """
...@@ -133,7 +133,7 @@ class TestMatMuklOp4(TestMatMulV2Op): ...@@ -133,7 +133,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op): class TestMatMulOp5(TestMatMulV2Op):
""" """
case 5 case 5
""" """
...@@ -145,7 +145,7 @@ class TestMatMuklOp5(TestMatMulV2Op): ...@@ -145,7 +145,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op): class TestMatMulOp6(TestMatMulV2Op):
""" """
case 6 case 6
""" """
...@@ -157,7 +157,7 @@ class TestMatMuklOp6(TestMatMulV2Op): ...@@ -157,7 +157,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op): class TestMatMulOp7(TestMatMulV2Op):
""" """
case 7 case 7
""" """
...@@ -169,7 +169,7 @@ class TestMatMuklOp7(TestMatMulV2Op): ...@@ -169,7 +169,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op): class TestMatMulOp8(TestMatMulV2Op):
""" """
case 8 case 8
""" """
...@@ -181,7 +181,7 @@ class TestMatMuklOp8(TestMatMulV2Op): ...@@ -181,7 +181,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op): class TestMatMulOp9(TestMatMulV2Op):
""" """
case 9 case 9
""" """
...@@ -193,7 +193,7 @@ class TestMatMuklOp9(TestMatMulV2Op): ...@@ -193,7 +193,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op): class TestMatMulOp10(TestMatMulV2Op):
""" """
case 10 case 10
""" """
...@@ -205,7 +205,7 @@ class TestMatMuklOp10(TestMatMulV2Op): ...@@ -205,7 +205,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op): class TestMatMulOp11(TestMatMulV2Op):
""" """
case 11 case 11
""" """
...@@ -217,7 +217,7 @@ class TestMatMuklOp11(TestMatMulV2Op): ...@@ -217,7 +217,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op): class TestMatMulOp12(TestMatMulV2Op):
""" """
case 12 case 12
""" """
...@@ -229,7 +229,7 @@ class TestMatMuklOp12(TestMatMulV2Op): ...@@ -229,7 +229,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op): class TestMatMulOp13(TestMatMulV2Op):
""" """
case 13 case 13
""" """
...@@ -241,7 +241,7 @@ class TestMatMuklOp13(TestMatMulV2Op): ...@@ -241,7 +241,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op): class TestMatMulOp14(TestMatMulV2Op):
""" """
case 14_1 case 14_1
""" """
...@@ -253,7 +253,7 @@ class TestMatMuklOp14(TestMatMulV2Op): ...@@ -253,7 +253,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op): class TestMatMulOp15(TestMatMulV2Op):
""" """
case 14_2 case 14_2
""" """
...@@ -265,7 +265,7 @@ class TestMatMuklOp15(TestMatMulV2Op): ...@@ -265,7 +265,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = True self.trans_y = True
class TestMatMuklOp16(TestMatMulV2Op): class TestMatMulOp16(TestMatMulV2Op):
""" """
case 16 : to check the big data case 16 : to check the big data
""" """
...@@ -277,7 +277,7 @@ class TestMatMuklOp16(TestMatMulV2Op): ...@@ -277,7 +277,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op): class TestMatMulOp17(TestMatMulV2Op):
""" """
case 17 : to check the gradient for special case case 17 : to check the gradient for special case
""" """
...@@ -289,7 +289,7 @@ class TestMatMuklOp17(TestMatMulV2Op): ...@@ -289,7 +289,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False self.trans_y = False
# class TestMatMuklOpBroadcast1(TestMatMulV2Op): # class TestMatMulOpBroadcast1(TestMatMulV2Op):
# """ # """
# case 14_3 # case 14_3
# """ # """
...@@ -300,7 +300,7 @@ class TestMatMuklOp17(TestMatMulV2Op): ...@@ -300,7 +300,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
# self.trans_x = True # self.trans_x = True
# self.trans_y = True # self.trans_y = True
# class TestMatMuklOpBroadcast2(TestMatMulV2Op): # class TestMatMulOpBroadcast2(TestMatMulV2Op):
# """ # """
# case 14_4 # case 14_4
# """ # """
......