diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index ea2e62d89f63da2bfe7e49c34e8aecad4e6138e0..2d2cc30497e288046256af5564620d40913cf3bf 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { } KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { - return KernelSignature(op_proto_->type(), GetInputArgsNames(), - GetAttrsArgsNames(), GetOutputArgsNames()); + return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()), + GetInputArgsNames(), GetAttrsArgsNames(), + GetOutputArgsNames()); } std::once_flag kernel_sig_map_init_flag; diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index f9a4e963c0c478e2d4e4bb35b2ddf63e0ac7e8b8..0e6b63be90ef695801c8dc820985d3562ab429ae 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -813,6 +813,102 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 8000 } +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, + platform::bfloat16 *C) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), 80, + platform::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A, + CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo)); + }); +#else + // raise error + PADDLE_THROW(platform::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, + int K, platform::bfloat16 alpha, + const platform::bfloat16 *A, + const platform::bfloat16 *B, + platform::bfloat16 beta, + platform::bfloat16 *C) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), 80, + platform::errors::InvalidArgument( + "cublas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A, + CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo)); + }); +#else + // raise error + PADDLE_THROW(platform::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + template <> template <> inline void Blas::GEMM( @@ -1208,6 +1304,42 @@ inline void Blas::GEMV(bool trans_a, int M, int N, } } +template <> +template <> +inline void Blas::GEMV( + bool trans_a, int M, int N, platform::bfloat16 alpha, + const platform::bfloat16 *A, const platform::bfloat16 *B, + platform::bfloat16 beta, platform::bfloat16 *C) const { + // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve + // it. + if (trans_a) { + this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, + alpha, B, A, beta, C); + } else { + this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, + alpha, A, B, beta, C); + } +} + +template <> +template <> +inline void Blas::GEMV(bool trans_a, int M, int N, + platform::bfloat16 alpha, + const platform::bfloat16 *A, + const platform::bfloat16 *B, + platform::bfloat16 beta, + platform::bfloat16 *C) const { + // Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve + // it. + if (trans_a) { + this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, + alpha, B, A, beta, C); + } else { + this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, + alpha, A, B, beta, C); + } +} + template <> template void Blas::BatchedGEMM( @@ -1306,6 +1438,91 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 9010 } +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, + int batchCount, int64_t strideA, int64_t strideB) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, + strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc, + strideC, batchCount, CUBLAS_COMPUTE_32F, algo)); + }); +#else + // raise error + PADDLE_THROW(platform::errors::Unimplemented( + "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " + "11")); +#endif // CUDA_VERSION >= 11000 +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, + int batchCount, int64_t strideA, int64_t strideB) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, + strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc, + strideC, batchCount, CUBLAS_COMPUTE_32F, algo)); + }); +#else + // raise error + PADDLE_THROW(platform::errors::Unimplemented( + "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " + "11")); +#endif // CUDA_VERSION >= 11000 +} + template <> template void Blas::BatchedGEMM( @@ -1356,6 +1573,32 @@ inline void Blas::BatchedGEMM( } } +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 **A, + const platform::bfloat16 **B, platform::bfloat16 beta, + platform::bfloat16 **C, int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, + A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 **A, + const platform::bfloat16 **B, platform::bfloat16 beta, + platform::bfloat16 **C, int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, + A[k], B[k], beta, C[k]); + } +} + template <> template void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, diff --git a/paddle/fluid/operators/math/blas_impl.hip.h b/paddle/fluid/operators/math/blas_impl.hip.h index 980caa9cfe68c64a1afd21a82d366b5228f8f026..9518da89edeb01a1dc35c2a6544ff2e55297a697 100644 --- a/paddle/fluid/operators/math/blas_impl.hip.h +++ b/paddle/fluid/operators/math/blas_impl.hip.h @@ -550,6 +550,84 @@ 
inline void Blas::GEMM(CBLAS_TRANSPOSE transA, rocblas_datatype_f16_r, N, rocblas_datatype_f32_r); } +template <> +template <> +inline void Blas::GEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, + platform::bfloat16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), 80, + platform::errors::InvalidArgument( + "rocblas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, + rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta, + C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N, + rocblas_datatype_f32_r, algo, 0, 0)); + }); +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, int M, int N, + int K, platform::bfloat16 alpha, + const platform::bfloat16 *A, + const platform::bfloat16 *B, + platform::bfloat16 beta, + platform::bfloat16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + // TODO(zhiqiu): 80 has the same meaning for rocm and cuda? + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), 80, + platform::errors::InvalidArgument( + "rocblas fp16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, + rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta, + C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N, + rocblas_datatype_f32_r, algo, 0, 0)); + }); +} + template <> template <> inline void Blas::GEMM( @@ -874,6 +952,39 @@ inline void Blas::GEMV(bool trans_a, int M, int N, } } +template <> +template <> +inline void Blas::GEMV( + bool trans_a, int M, int N, platform::bfloat16 alpha, + const platform::bfloat16 *A, const platform::bfloat16 *B, + platform::bfloat16 beta, platform::bfloat16 *C) const { + // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. 
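+  // A vector here is treated as a degenerate matrix, so both branches fall
+  // through to the bf16 GEMM specialization added above:
+  //   trans_a:  C(1 x N) = B(1 x M) * A(M x N)   (x^T * A, i.e. the transposed product)
+  //   !trans_a: C(M x 1) = A(M x N) * B(N x 1)   (A * x)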
+ if (trans_a) { + this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, + alpha, B, A, beta, C); + } else { + this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, + alpha, A, B, beta, C); + } +} +template <> +template <> +inline void Blas::GEMV(bool trans_a, int M, int N, + platform::bfloat16 alpha, + const platform::bfloat16 *A, + const platform::bfloat16 *B, + platform::bfloat16 beta, + platform::bfloat16 *C) const { + // Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it. + if (trans_a) { + this->template GEMM(CblasNoTrans, CblasNoTrans, 1, N, M, + alpha, B, A, beta, C); + } else { + this->template GEMM(CblasNoTrans, CblasNoTrans, M, 1, N, + alpha, A, B, beta, C); + } +} + template <> template void Blas::BatchedGEMM( @@ -898,6 +1009,7 @@ void Blas::BatchedGEMM( ldc, strideC, batchCount); }); } + template <> template void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, @@ -925,6 +1037,70 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, }); } +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, + int batchCount, int64_t strideA, int64_t strideB) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + const int64_t strideC = M * N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::rocblas_gemm_strided_batched_ex( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, + rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r, + lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C, + rocblas_datatype_bf16_r, ldc, strideC, batchCount, + rocblas_datatype_f32_r, algo, 0, 0)); + }); +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 *A, + const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C, + int batchCount, int64_t strideA, int64_t strideB) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + const int64_t strideC = M * N; + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_operation cuTransB = (transB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + + context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::rocblas_gemm_strided_batched_ex( + handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, + rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r, + lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C, + rocblas_datatype_bf16_r, ldc, strideC, batchCount, + rocblas_datatype_f32_r, algo, 0, 0)); + }); +} + template <> template void Blas::BatchedGEMM( @@ -935,6 +1111,7 @@ void Blas::BatchedGEMM( C[k]); } } + template <> template void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, @@ -973,6 +1150,32 @@ inline void Blas::BatchedGEMM( } } +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 **A, + const platform::bfloat16 **B, platform::bfloat16 beta, + platform::bfloat16 **C, int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, + A[k], B[k], beta, C[k]); + } +} + +template <> +template <> +inline void Blas::BatchedGEMM( + CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, + platform::bfloat16 alpha, const platform::bfloat16 **A, + const platform::bfloat16 **B, platform::bfloat16 beta, + platform::bfloat16 **C, int batchCount) const { + for (int k = 0; k < batchCount; ++k) { + this->template GEMM(transA, transB, M, N, K, alpha, + A[k], B[k], beta, C[k]); + } +} + template <> template void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h b/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h index cd78a89088cc612c3fb43e489cfb7ef2e07cfcf3..58a25ae8d0e565b649b29863637fa9d000d524d3 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h @@ -16,8 +16,10 @@ limitations under the License. */ // NOTE(): support float16 to half in header file. #define PADDLE_CUDA_FP16 +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/core/enforce.h" namespace paddle { namespace platform { @@ -61,6 +63,19 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, static_cast(delta), width)); } +template <> +__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask, + bfloat16 val, int delta, + int width) { +#if defined(PADDLE_CUDA_BF16) + return bfloat16(__shfl_down_sync(mask, static_cast(val), + static_cast(delta), width)); +#else + PADDLE_ENFORCE( + false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11."); +#endif +} + template <> __forceinline__ __device__ paddle::platform::complex CudaShuffleDownSync( unsigned mask, paddle::platform::complex val, int delta, int width) { diff --git a/paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h b/paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h index 13ffc2396946c5819c9276cf474d96a8057c4094..63897bd6717408bff4bd4db5e739b3ba64316350 100644 --- a/paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h +++ b/paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ // NOTE(): support float16 to half in header file. #define PADDLE_CUDA_FP16 +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -59,6 +60,14 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, static_cast(delta), width)); } +template <> +__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask, + bfloat16 val, int delta, + int width) { + return bfloat16(__shfl_down(static_cast(val), + static_cast(delta), width)); +} + template <> __forceinline__ __device__ paddle::platform::complex CudaShuffleDownSync( unsigned mask, paddle::platform::complex val, int delta, int width) { diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc index ecb058d35b909bc9455b019e55ab8f2277fd587b..e1a1788815ebfef75ac29e332da3e76f3d2a5d52 100644 --- a/paddle/pten/kernels/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -94,6 +94,7 @@ PT_REGISTER_KERNEL(empty_like, int64_t, bool, paddle::platform::float16, + paddle::platform::bfloat16, paddle::platform::complex, paddle::platform::complex) {} #endif diff --git a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu index 31c44673f94e737bd94882b2537ddf3fababf226..7df99260aa1614a29325ed1d0834400566e28139 100644 --- a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu +++ b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu @@ -26,6 +26,7 @@ PT_REGISTER_KERNEL(matmul_grad, float, double, paddle::platform::float16, + paddle::platform::bfloat16, paddle::platform::complex, paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/matmul_kernel.cu b/paddle/pten/kernels/gpu/matmul_kernel.cu index f9fdbd27bf94e4b236efe5a49e471e39c4c57dd5..b365581e949c103be511e4849a45b4fd9a024f77 100644 --- a/paddle/pten/kernels/gpu/matmul_kernel.cu +++ b/paddle/pten/kernels/gpu/matmul_kernel.cu @@ -27,5 +27,6 @@ PT_REGISTER_KERNEL(matmul, float, double, paddle::platform::float16, + paddle::platform::bfloat16, paddle::platform::complex, paddle::platform::complex) {} diff --git a/python/paddle/fluid/tests/unittests/npu/test_matmulv2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_matmulv2_op_npu.py index 882043ef6eb911f6163d516e9929658f38810ade..23ca0cf1f492fade05a81f0de1d6bc262458675c 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_matmulv2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_matmulv2_op_npu.py @@ -71,7 +71,7 @@ class TestMatMulV2Op(OpTest): self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') -class TestMatMuklOp2(TestMatMulV2Op): +class TestMatMulOp2(TestMatMulV2Op): """ case 2 """ @@ -83,7 +83,7 @@ class TestMatMuklOp2(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp3(TestMatMulV2Op): +class TestMatMulOp3(TestMatMulV2Op): """ case 3 """ @@ -95,7 +95,7 @@ class TestMatMuklOp3(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp4(TestMatMulV2Op): +class TestMatMulOp4(TestMatMulV2Op): """ case 4 """ @@ -107,7 +107,7 @@ class TestMatMuklOp4(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp5(TestMatMulV2Op): +class TestMatMulOp5(TestMatMulV2Op): """ case 5 """ @@ -119,7 +119,7 @@ class TestMatMuklOp5(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp6(TestMatMulV2Op): +class TestMatMulOp6(TestMatMulV2Op): """ case 6 """ @@ -131,7 +131,7 @@ class TestMatMuklOp6(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp7(TestMatMulV2Op): +class TestMatMulOp7(TestMatMulV2Op): """ case 7 """ @@ -143,7 +143,7 @@ class 
TestMatMuklOp7(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp8(TestMatMulV2Op): +class TestMatMulOp8(TestMatMulV2Op): """ case 8 """ @@ -155,7 +155,7 @@ class TestMatMuklOp8(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp9(TestMatMulV2Op): +class TestMatMulOp9(TestMatMulV2Op): """ case 9 """ @@ -167,7 +167,7 @@ class TestMatMuklOp9(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp10(TestMatMulV2Op): +class TestMatMulOp10(TestMatMulV2Op): """ case 10 """ @@ -179,7 +179,7 @@ class TestMatMuklOp10(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp11(TestMatMulV2Op): +class TestMatMulOp11(TestMatMulV2Op): """ case 11 """ @@ -191,7 +191,7 @@ class TestMatMuklOp11(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp12(TestMatMulV2Op): +class TestMatMulOp12(TestMatMulV2Op): """ case 12 """ @@ -203,7 +203,7 @@ class TestMatMuklOp12(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp13(TestMatMulV2Op): +class TestMatMulOp13(TestMatMulV2Op): """ case 13 """ @@ -215,7 +215,7 @@ class TestMatMuklOp13(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp14(TestMatMulV2Op): +class TestMatMulOp14(TestMatMulV2Op): """ case 14_1 """ @@ -227,7 +227,7 @@ class TestMatMuklOp14(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp15(TestMatMulV2Op): +class TestMatMulOp15(TestMatMulV2Op): """ case 14_2 """ @@ -239,7 +239,7 @@ class TestMatMuklOp15(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp16(TestMatMulV2Op): +class TestMatMulOp16(TestMatMulV2Op): """ case 16 : to check the gradient for special case """ @@ -251,7 +251,7 @@ class TestMatMuklOp16(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp17(TestMatMulV2Op): +class TestMatMulOp17(TestMatMulV2Op): """ case 17 : to check the gradient for special case """ @@ -263,7 +263,7 @@ class TestMatMuklOp17(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOpBroadcast1(TestMatMulV2Op): +class TestMatMulOpBroadcast1(TestMatMulV2Op): """ case 14_3 """ @@ -275,7 +275,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOpBroadcast2(TestMatMulV2Op): +class TestMatMulOpBroadcast2(TestMatMulV2Op): """ case 14_4 """ @@ -310,22 +310,22 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5): create_test_fp16_class(TestMatMulV2Op) -create_test_fp16_class(TestMatMuklOp2) -create_test_fp16_class(TestMatMuklOp3) -create_test_fp16_class(TestMatMuklOp4) -create_test_fp16_class(TestMatMuklOp5) -create_test_fp16_class(TestMatMuklOp6) -create_test_fp16_class(TestMatMuklOp7) -create_test_fp16_class(TestMatMuklOp8) -create_test_fp16_class(TestMatMuklOp9) -create_test_fp16_class(TestMatMuklOp10) -create_test_fp16_class(TestMatMuklOp11) -create_test_fp16_class(TestMatMuklOp12) -create_test_fp16_class(TestMatMuklOp13) -create_test_fp16_class(TestMatMuklOp14) -create_test_fp16_class(TestMatMuklOp15) -create_test_fp16_class(TestMatMuklOp16) -create_test_fp16_class(TestMatMuklOp17) +create_test_fp16_class(TestMatMulOp2) +create_test_fp16_class(TestMatMulOp3) +create_test_fp16_class(TestMatMulOp4) +create_test_fp16_class(TestMatMulOp5) +create_test_fp16_class(TestMatMulOp6) +create_test_fp16_class(TestMatMulOp7) +create_test_fp16_class(TestMatMulOp8) +create_test_fp16_class(TestMatMulOp9) +create_test_fp16_class(TestMatMulOp10) +create_test_fp16_class(TestMatMulOp11) +create_test_fp16_class(TestMatMulOp12) +create_test_fp16_class(TestMatMulOp13) +create_test_fp16_class(TestMatMulOp14) +create_test_fp16_class(TestMatMulOp15) 
+create_test_fp16_class(TestMatMulOp16) +create_test_fp16_class(TestMatMulOp17) class TestMatMulV2API(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 754d7bd54b9f817d73c2f5d705026c9a468f4008..85423df3d382831738c2c64ea845d0661f9cdbb7 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1658,7 +1658,7 @@ class OpTest(unittest.TestCase): for grad in analytic_grads: if grad.dtype == np.uint16: grad = convert_uint16_to_float(grad) - max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error + max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error fp32_analytic_grads.append(grad) analytic_grads = fp32_analytic_grads @@ -1666,7 +1666,7 @@ class OpTest(unittest.TestCase): for grad in numeric_grads: if grad.dtype == np.uint16: grad = convert_uint16_to_float(grad) - max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error + max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error fp32_numeric_grads.append(grad) numeric_grads = fp32_numeric_grads diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py index efcc0e4cfe323294df88167a6100f019cef67005..ed1495c6352bb979058d1dca015171f013fd38d9 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py @@ -16,7 +16,8 @@ from __future__ import print_function import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient +from paddle.fluid.tests.unittests.testsuite import create_op import paddle.fluid.core as core import paddle @@ -73,17 +74,32 @@ class TestMatMulV2Op(OpTest): self.init_kernel_type() self.config() self.op_type = "matmul_v2" - x = np.random.random(self.x_shape).astype(self.dtype) - y = np.random.random(self.y_shape).astype(self.dtype) - # -0.1 ~ 0.1 - x = -0.1 + 0.2 * x - y = -0.1 + 0.2 * y + if self.is_bfloat16_op(): + x = np.random.random(self.x_shape).astype(np.float32) + y = np.random.random(self.y_shape).astype(np.float32) + else: + x = np.random.random(self.x_shape).astype(self.dtype) + y = np.random.random(self.y_shape).astype(self.dtype) + # -0.1 ~ 0.1 + x = -0.1 + 0.2 * x + y = -0.1 + 0.2 * y result = reference_matmul(x, y, self.trans_x, self.trans_y) - result = result.astype(self.dtype) - self.inputs = { - 'X': x, - 'Y': y, - } + if self.is_bfloat16_op(): + result = result.astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + 'Y': convert_float_to_uint16(y), + } + self.inputs_fp32 = { + 'X': x, + 'Y': y, + } + else: + result = result.astype(self.dtype) + self.inputs = { + 'X': x, + 'Y': y, + } self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y} self.outputs = {'Out': result} @@ -97,7 +113,7 @@ class TestMatMulV2Op(OpTest): self.check_grad(['X', 'Y'], 'Out') -class TestMatMuklOp2(TestMatMulV2Op): +class TestMatMulOp2(TestMatMulV2Op): """ case 2 """ @@ -109,7 +125,7 @@ class TestMatMuklOp2(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp3(TestMatMulV2Op): +class TestMatMulOp3(TestMatMulV2Op): """ case 3 """ @@ -121,7 +137,7 @@ class TestMatMuklOp3(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp4(TestMatMulV2Op): +class TestMatMulOp4(TestMatMulV2Op): """ case 4 """ @@ -133,7 +149,7 @@ class TestMatMuklOp4(TestMatMulV2Op): 
self.trans_y = False -class TestMatMuklOp5(TestMatMulV2Op): +class TestMatMulOp5(TestMatMulV2Op): """ case 5 """ @@ -145,7 +161,7 @@ class TestMatMuklOp5(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp6(TestMatMulV2Op): +class TestMatMulOp6(TestMatMulV2Op): """ case 6 """ @@ -157,7 +173,7 @@ class TestMatMuklOp6(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp7(TestMatMulV2Op): +class TestMatMulOp7(TestMatMulV2Op): """ case 7 """ @@ -169,7 +185,7 @@ class TestMatMuklOp7(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp8(TestMatMulV2Op): +class TestMatMulOp8(TestMatMulV2Op): """ case 8 """ @@ -181,7 +197,7 @@ class TestMatMuklOp8(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp9(TestMatMulV2Op): +class TestMatMulOp9(TestMatMulV2Op): """ case 9 """ @@ -193,7 +209,7 @@ class TestMatMuklOp9(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp10(TestMatMulV2Op): +class TestMatMulOp10(TestMatMulV2Op): """ case 10 """ @@ -205,7 +221,7 @@ class TestMatMuklOp10(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp11(TestMatMulV2Op): +class TestMatMulOp11(TestMatMulV2Op): """ case 11 """ @@ -217,7 +233,7 @@ class TestMatMuklOp11(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp12(TestMatMulV2Op): +class TestMatMulOp12(TestMatMulV2Op): """ case 12 """ @@ -229,7 +245,7 @@ class TestMatMuklOp12(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp13(TestMatMulV2Op): +class TestMatMulOp13(TestMatMulV2Op): """ case 13 """ @@ -241,7 +257,7 @@ class TestMatMuklOp13(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp14(TestMatMulV2Op): +class TestMatMulOp14(TestMatMulV2Op): """ case 14_1 """ @@ -253,7 +269,7 @@ class TestMatMuklOp14(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp15(TestMatMulV2Op): +class TestMatMulOp15(TestMatMulV2Op): """ case 14_2 """ @@ -265,7 +281,7 @@ class TestMatMuklOp15(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp16(TestMatMulV2Op): +class TestMatMulOp16(TestMatMulV2Op): """ case 16 : to check the gradient for special case """ @@ -277,7 +293,7 @@ class TestMatMuklOp16(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp17(TestMatMulV2Op): +class TestMatMulOp17(TestMatMulV2Op): """ case 17 : to check the gradient for special case """ @@ -289,7 +305,7 @@ class TestMatMuklOp17(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOpBroadcast1(TestMatMulV2Op): +class TestMatMulOpBroadcast1(TestMatMulV2Op): """ case 14_3 """ @@ -301,7 +317,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOpBroadcast2(TestMatMulV2Op): +class TestMatMulOpBroadcast2(TestMatMulV2Op): """ case 14_4 """ @@ -343,22 +359,90 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0): create_test_fp16_class(TestMatMulV2Op) -create_test_fp16_class(TestMatMuklOp2) -create_test_fp16_class(TestMatMuklOp3) -create_test_fp16_class(TestMatMuklOp4) -create_test_fp16_class(TestMatMuklOp5) -create_test_fp16_class(TestMatMuklOp6) -create_test_fp16_class(TestMatMuklOp7) -create_test_fp16_class(TestMatMuklOp8) -create_test_fp16_class(TestMatMuklOp9) -create_test_fp16_class(TestMatMuklOp10) -create_test_fp16_class(TestMatMuklOp11) -create_test_fp16_class(TestMatMuklOp12) -create_test_fp16_class(TestMatMuklOp13) -create_test_fp16_class(TestMatMuklOp14) -create_test_fp16_class(TestMatMuklOp15) -create_test_fp16_class(TestMatMuklOp16) -create_test_fp16_class(TestMatMuklOp17) +create_test_fp16_class(TestMatMulOp2) +create_test_fp16_class(TestMatMulOp3) 
+create_test_fp16_class(TestMatMulOp4) +create_test_fp16_class(TestMatMulOp5) +create_test_fp16_class(TestMatMulOp6) +create_test_fp16_class(TestMatMulOp7) +create_test_fp16_class(TestMatMulOp8) +create_test_fp16_class(TestMatMulOp9) +create_test_fp16_class(TestMatMulOp10) +create_test_fp16_class(TestMatMulOp11) +create_test_fp16_class(TestMatMulOp12) +create_test_fp16_class(TestMatMulOp13) +create_test_fp16_class(TestMatMulOp14) +create_test_fp16_class(TestMatMulOp15) +create_test_fp16_class(TestMatMulOp16) +create_test_fp16_class(TestMatMulOp17) + +#--------------------test matmul bf16-------------------- + + +def create_test_bf16_class(parent, atol=0.01): + @unittest.skipIf( + not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, + "core is not compiled with CUDA and cudnn version need larger than 8.1.0" + ) + class TestMatMulOpBf16Case(parent): + def get_numeric_grad(self, place, check_name): + scope = core.Scope() + self._check_grad_helper() + op = create_op(scope, self.op_type, self.inputs, self.outputs, + self.attrs) + return get_numeric_gradient(place, scope, op, self.inputs_fp32, + check_name, ['Out']) + + def init_kernel_type(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=atol) + + def test_check_grad_x(self): + place = core.CUDAPlace(0) + numeric_grads = self.get_numeric_grad(place, 'X') + self.check_grad_with_place( + place, ['X'], + 'Out', + no_grad_set=set(['Y']), + user_defined_grads=[numeric_grads]) + + def test_check_grad_y(self): + place = core.CUDAPlace(0) + numeric_grads = self.get_numeric_grad(place, 'Y') + self.check_grad_with_place( + place, ['Y'], + 'Out', + no_grad_set=set(['X']), + user_defined_grads=[numeric_grads]) + + def test_check_grad(self): + pass + + cls_name = "{0}_{1}".format(parent.__name__, "Bf16") + TestMatMulOpBf16Case.__name__ = cls_name + globals()[cls_name] = TestMatMulOpBf16Case + + +create_test_bf16_class(TestMatMulV2Op) +create_test_bf16_class(TestMatMulOp2) +create_test_bf16_class(TestMatMulOp3) +create_test_bf16_class(TestMatMulOp4) +create_test_bf16_class(TestMatMulOp5) +create_test_bf16_class(TestMatMulOp6) +create_test_bf16_class(TestMatMulOp7) +create_test_bf16_class(TestMatMulOp8) +create_test_bf16_class(TestMatMulOp9) +create_test_bf16_class(TestMatMulOp10) +create_test_bf16_class(TestMatMulOp11) +create_test_bf16_class(TestMatMulOp12) +create_test_bf16_class(TestMatMulOp13) +create_test_bf16_class(TestMatMulOp14) +create_test_bf16_class(TestMatMulOp15) +create_test_bf16_class(TestMatMulOp16) +create_test_bf16_class(TestMatMulOp17) class TestMatMulV2API(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py index 435026220c2b59a0f8df73f071673dab044e8348..45d60c8538e092f4c5d97f6525870af33a6ad9d5 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py @@ -97,7 +97,7 @@ class TestMatMulV2Op(XPUOpTest): self.check_grad_with_place(place, ['X', 'Y'], 'Out') -class TestMatMuklOp2(TestMatMulV2Op): +class TestMatMulOp2(TestMatMulV2Op): """ case 2 """ @@ -109,7 +109,7 @@ class TestMatMuklOp2(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp3(TestMatMulV2Op): +class TestMatMulOp3(TestMatMulV2Op): """ case 3 """ @@ -121,7 +121,7 @@ class TestMatMuklOp3(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp4(TestMatMulV2Op): +class 
TestMatMulOp4(TestMatMulV2Op): """ case 4 """ @@ -133,7 +133,7 @@ class TestMatMuklOp4(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp5(TestMatMulV2Op): +class TestMatMulOp5(TestMatMulV2Op): """ case 5 """ @@ -145,7 +145,7 @@ class TestMatMuklOp5(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp6(TestMatMulV2Op): +class TestMatMulOp6(TestMatMulV2Op): """ case 6 """ @@ -157,7 +157,7 @@ class TestMatMuklOp6(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp7(TestMatMulV2Op): +class TestMatMulOp7(TestMatMulV2Op): """ case 7 """ @@ -169,7 +169,7 @@ class TestMatMuklOp7(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp8(TestMatMulV2Op): +class TestMatMulOp8(TestMatMulV2Op): """ case 8 """ @@ -181,7 +181,7 @@ class TestMatMuklOp8(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp9(TestMatMulV2Op): +class TestMatMulOp9(TestMatMulV2Op): """ case 9 """ @@ -193,7 +193,7 @@ class TestMatMuklOp9(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp10(TestMatMulV2Op): +class TestMatMulOp10(TestMatMulV2Op): """ case 10 """ @@ -205,7 +205,7 @@ class TestMatMuklOp10(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp11(TestMatMulV2Op): +class TestMatMulOp11(TestMatMulV2Op): """ case 11 """ @@ -217,7 +217,7 @@ class TestMatMuklOp11(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp12(TestMatMulV2Op): +class TestMatMulOp12(TestMatMulV2Op): """ case 12 """ @@ -229,7 +229,7 @@ class TestMatMuklOp12(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp13(TestMatMulV2Op): +class TestMatMulOp13(TestMatMulV2Op): """ case 13 """ @@ -241,7 +241,7 @@ class TestMatMuklOp13(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp14(TestMatMulV2Op): +class TestMatMulOp14(TestMatMulV2Op): """ case 14_1 """ @@ -253,7 +253,7 @@ class TestMatMuklOp14(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp15(TestMatMulV2Op): +class TestMatMulOp15(TestMatMulV2Op): """ case 14_2 """ @@ -265,7 +265,7 @@ class TestMatMuklOp15(TestMatMulV2Op): self.trans_y = True -class TestMatMuklOp16(TestMatMulV2Op): +class TestMatMulOp16(TestMatMulV2Op): """ case 16 : to check the big data """ @@ -277,7 +277,7 @@ class TestMatMuklOp16(TestMatMulV2Op): self.trans_y = False -class TestMatMuklOp17(TestMatMulV2Op): +class TestMatMulOp17(TestMatMulV2Op): """ case 17 : to check the gradient for special case """ @@ -289,7 +289,7 @@ class TestMatMuklOp17(TestMatMulV2Op): self.trans_y = False -# class TestMatMuklOpBroadcast1(TestMatMulV2Op): +# class TestMatMulOpBroadcast1(TestMatMulV2Op): # """ # case 14_3 # """ @@ -300,7 +300,7 @@ class TestMatMuklOp17(TestMatMulV2Op): # self.trans_x = True # self.trans_y = True -# class TestMatMuklOpBroadcast2(TestMatMulV2Op): +# class TestMatMulOpBroadcast2(TestMatMulV2Op): # """ # case 14_4 # """
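The cuBLAS side of this patch reduces every new bfloat16 entry point (GEMM, GEMV, BatchedGEMM) to one call pattern: a row-major GEMM expressed through column-major cublasGemmEx by passing B before A and swapping M and N, with CUDA_R_16BF storage and fp32 accumulation. Below is a minimal standalone sketch of that pattern for the no-transpose case; the function name GemmBf16RowMajor, the cuBLAS handle, and the device buffers d_A, d_B, d_C are illustrative and not part of the patch.

// Sketch only: C(MxN) = A(MxK) * B(KxN) for row-major bf16 matrices.
// Column-major cuBLAS is asked for B*A with M and N swapped, the same trick
// the patched GEMM specializations use (alpha/beta are fp32 because the
// compute type is fp32).
#include <cublas_v2.h>
#include <cuda_bf16.h>

cublasStatus_t GemmBf16RowMajor(cublasHandle_t handle, int M, int N, int K,
                                const __nv_bfloat16* d_A,
                                const __nv_bfloat16* d_B,
                                __nv_bfloat16* d_C) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                      /*m=*/N, /*n=*/M, /*k=*/K, &alpha,
                      /*A=*/d_B, CUDA_R_16BF, /*lda=*/N,
                      /*B=*/d_A, CUDA_R_16BF, /*ldb=*/K,
                      &beta, d_C, CUDA_R_16BF, /*ldc=*/N,
                      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}

This is exactly the argument order the patched code produces with lda = K and ldb = N when neither operand is transposed; requesting CUBLAS_GEMM_DFALT_TENSOR_OP instead of CUBLAS_GEMM_DEFAULT is what the tensor-core branch does when context_.tensor_core_available() is true.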
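On the test side, the new bf16 cases feed tensors stored as np.uint16 via convert_float_to_uint16, and the looser tolerances (atol=0.01 in create_test_bf16_class, max_relative_error clamped to 0.04 in op_test.py) follow from the format itself: bfloat16 is the upper 16 bits of an IEEE-754 float32, leaving only 8 significand bits. The helpers below are a rough bit-level illustration of that representation; their names are made up for this note, and simple truncation is assumed rather than the exact rounding behaviour of convert_float_to_uint16.

#include <cstdint>
#include <cstring>

// float32 -> bfloat16 bit pattern by truncation: keep the sign, exponent and
// top 7 mantissa bits, drop the low 16 bits.
uint16_t FloatToBf16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// bfloat16 bit pattern -> float32 by zero-filling the dropped mantissa bits.
float Bf16BitsToFloat(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}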