未验证 提交 d5a0d31a 编写于 作者: L Leo Chen 提交者: GitHub

[bf16] pten matmul cuda kernel support bf16 (#39485)

* pten matmul cuda kernel support bf16

* fix pten kernel name

* add matmul_grad bf16 kernel

* add emptylike bf16 kernel

* fix compile

* suppport rocm

* fix error

* fix rocm

* add bf16 header file

* fix compile
上级 f31c2426
......@@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(op_proto_->type(), GetInputArgsNames(),
GetAttrsArgsNames(), GetOutputArgsNames());
return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
}
std::once_flag kernel_sig_map_init_flag;
......
......@@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
......@@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve
// it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve
// it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
......@@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 9010
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
......@@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
......
......@@ -550,6 +550,84 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
rocblas_datatype_f16_r, N, rocblas_datatype_f32_r);
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(zhiqiu): 80 has the same meaning for rocm and cuda?
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta,
C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(zhiqiu): 80 has the same meaning for rocm and cuda?
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, A, rocblas_datatype_bf16_r, lda, &h_beta,
C, rocblas_datatype_bf16_r, N, C, rocblas_datatype_bf16_r, N,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
......@@ -874,6 +952,39 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because rocblas doesn't support bfloat16 gemv, we use gemmex to achieve it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
......@@ -898,6 +1009,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
ldc, strideC, batchCount);
});
}
template <>
template <typename T>
void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
......@@ -925,6 +1037,70 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
const int64_t strideC = M * N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::rocblas_gemm_strided_batched_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r,
lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C,
rocblas_datatype_bf16_r, ldc, strideC, batchCount,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
const int64_t strideC = M * N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::rocblas_gemm_strided_batched_ex(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_bf16_r, ldb, strideB, A, rocblas_datatype_bf16_r,
lda, strideA, &h_beta, C, rocblas_datatype_bf16_r, ldc, strideC, C,
rocblas_datatype_bf16_r, ldc, strideC, batchCount,
rocblas_datatype_f32_r, algo, 0, 0));
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
......@@ -935,6 +1111,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
C[k]);
}
}
template <>
template <typename T>
void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
......@@ -973,6 +1150,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
......
......@@ -16,8 +16,10 @@ limitations under the License. */
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/enforce.h"
namespace paddle {
namespace platform {
......@@ -61,6 +63,19 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask,
bfloat16 val, int delta,
int width) {
#if defined(PADDLE_CUDA_BF16)
return bfloat16(__shfl_down_sync(mask, static_cast<nv_bfloat16>(val),
static_cast<unsigned>(delta), width));
#else
PADDLE_ENFORCE(
false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
#endif
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
......@@ -59,6 +60,14 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ bfloat16 CudaShuffleDownSync(unsigned mask,
bfloat16 val, int delta,
int width) {
return bfloat16(__shfl_down(static_cast<float>(val),
static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
......
......@@ -94,6 +94,7 @@ PT_REGISTER_KERNEL(empty_like,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
#endif
......@@ -26,6 +26,7 @@ PT_REGISTER_KERNEL(matmul_grad,
float,
double,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......
......@@ -27,5 +27,6 @@ PT_REGISTER_KERNEL(matmul,
float,
double,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -71,7 +71,7 @@ class TestMatMulV2Op(OpTest):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op):
class TestMatMulOp2(TestMatMulV2Op):
"""
case 2
"""
......@@ -83,7 +83,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp3(TestMatMulV2Op):
class TestMatMulOp3(TestMatMulV2Op):
"""
case 3
"""
......@@ -95,7 +95,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op):
class TestMatMulOp4(TestMatMulV2Op):
"""
case 4
"""
......@@ -107,7 +107,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op):
class TestMatMulOp5(TestMatMulV2Op):
"""
case 5
"""
......@@ -119,7 +119,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op):
class TestMatMulOp6(TestMatMulV2Op):
"""
case 6
"""
......@@ -131,7 +131,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op):
class TestMatMulOp7(TestMatMulV2Op):
"""
case 7
"""
......@@ -143,7 +143,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op):
class TestMatMulOp8(TestMatMulV2Op):
"""
case 8
"""
......@@ -155,7 +155,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op):
class TestMatMulOp9(TestMatMulV2Op):
"""
case 9
"""
......@@ -167,7 +167,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op):
class TestMatMulOp10(TestMatMulV2Op):
"""
case 10
"""
......@@ -179,7 +179,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op):
class TestMatMulOp11(TestMatMulV2Op):
"""
case 11
"""
......@@ -191,7 +191,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op):
class TestMatMulOp12(TestMatMulV2Op):
"""
case 12
"""
......@@ -203,7 +203,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op):
class TestMatMulOp13(TestMatMulV2Op):
"""
case 13
"""
......@@ -215,7 +215,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op):
class TestMatMulOp14(TestMatMulV2Op):
"""
case 14_1
"""
......@@ -227,7 +227,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op):
class TestMatMulOp15(TestMatMulV2Op):
"""
case 14_2
"""
......@@ -239,7 +239,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp16(TestMatMulV2Op):
class TestMatMulOp16(TestMatMulV2Op):
"""
case 16 : to check the gradient for special case
"""
......@@ -251,7 +251,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op):
class TestMatMulOp17(TestMatMulV2Op):
"""
case 17 : to check the gradient for special case
"""
......@@ -263,7 +263,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
class TestMatMulOpBroadcast1(TestMatMulV2Op):
"""
case 14_3
"""
......@@ -275,7 +275,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op):
class TestMatMulOpBroadcast2(TestMatMulV2Op):
"""
case 14_4
"""
......@@ -310,22 +310,22 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2)
create_test_fp16_class(TestMatMuklOp3)
create_test_fp16_class(TestMatMuklOp4)
create_test_fp16_class(TestMatMuklOp5)
create_test_fp16_class(TestMatMuklOp6)
create_test_fp16_class(TestMatMuklOp7)
create_test_fp16_class(TestMatMuklOp8)
create_test_fp16_class(TestMatMuklOp9)
create_test_fp16_class(TestMatMuklOp10)
create_test_fp16_class(TestMatMuklOp11)
create_test_fp16_class(TestMatMuklOp12)
create_test_fp16_class(TestMatMuklOp13)
create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class(TestMatMuklOp15)
create_test_fp16_class(TestMatMuklOp16)
create_test_fp16_class(TestMatMuklOp17)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)
class TestMatMulV2API(unittest.TestCase):
......
......@@ -1658,7 +1658,7 @@ class OpTest(unittest.TestCase):
for grad in analytic_grads:
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error
max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error
fp32_analytic_grads.append(grad)
analytic_grads = fp32_analytic_grads
......@@ -1666,7 +1666,7 @@ class OpTest(unittest.TestCase):
for grad in numeric_grads:
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error
max_relative_error = 0.04 if max_relative_error < 0.04 else max_relative_error
fp32_numeric_grads.append(grad)
numeric_grads = fp32_numeric_grads
......
......@@ -16,7 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from paddle.fluid.tests.unittests.testsuite import create_op
import paddle.fluid.core as core
import paddle
......@@ -73,17 +74,32 @@ class TestMatMulV2Op(OpTest):
self.init_kernel_type()
self.config()
self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
x = -0.1 + 0.2 * x
y = -0.1 + 0.2 * y
if self.is_bfloat16_op():
x = np.random.random(self.x_shape).astype(np.float32)
y = np.random.random(self.y_shape).astype(np.float32)
else:
x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
x = -0.1 + 0.2 * x
y = -0.1 + 0.2 * y
result = reference_matmul(x, y, self.trans_x, self.trans_y)
result = result.astype(self.dtype)
self.inputs = {
'X': x,
'Y': y,
}
if self.is_bfloat16_op():
result = result.astype(np.float32)
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y),
}
self.inputs_fp32 = {
'X': x,
'Y': y,
}
else:
result = result.astype(self.dtype)
self.inputs = {
'X': x,
'Y': y,
}
self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
self.outputs = {'Out': result}
......@@ -97,7 +113,7 @@ class TestMatMulV2Op(OpTest):
self.check_grad(['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op):
class TestMatMulOp2(TestMatMulV2Op):
"""
case 2
"""
......@@ -109,7 +125,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp3(TestMatMulV2Op):
class TestMatMulOp3(TestMatMulV2Op):
"""
case 3
"""
......@@ -121,7 +137,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op):
class TestMatMulOp4(TestMatMulV2Op):
"""
case 4
"""
......@@ -133,7 +149,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op):
class TestMatMulOp5(TestMatMulV2Op):
"""
case 5
"""
......@@ -145,7 +161,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op):
class TestMatMulOp6(TestMatMulV2Op):
"""
case 6
"""
......@@ -157,7 +173,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op):
class TestMatMulOp7(TestMatMulV2Op):
"""
case 7
"""
......@@ -169,7 +185,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op):
class TestMatMulOp8(TestMatMulV2Op):
"""
case 8
"""
......@@ -181,7 +197,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op):
class TestMatMulOp9(TestMatMulV2Op):
"""
case 9
"""
......@@ -193,7 +209,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op):
class TestMatMulOp10(TestMatMulV2Op):
"""
case 10
"""
......@@ -205,7 +221,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op):
class TestMatMulOp11(TestMatMulV2Op):
"""
case 11
"""
......@@ -217,7 +233,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op):
class TestMatMulOp12(TestMatMulV2Op):
"""
case 12
"""
......@@ -229,7 +245,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op):
class TestMatMulOp13(TestMatMulV2Op):
"""
case 13
"""
......@@ -241,7 +257,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op):
class TestMatMulOp14(TestMatMulV2Op):
"""
case 14_1
"""
......@@ -253,7 +269,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op):
class TestMatMulOp15(TestMatMulV2Op):
"""
case 14_2
"""
......@@ -265,7 +281,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp16(TestMatMulV2Op):
class TestMatMulOp16(TestMatMulV2Op):
"""
case 16 : to check the gradient for special case
"""
......@@ -277,7 +293,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op):
class TestMatMulOp17(TestMatMulV2Op):
"""
case 17 : to check the gradient for special case
"""
......@@ -289,7 +305,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
class TestMatMulOpBroadcast1(TestMatMulV2Op):
"""
case 14_3
"""
......@@ -301,7 +317,7 @@ class TestMatMuklOpBroadcast1(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOpBroadcast2(TestMatMulV2Op):
class TestMatMulOpBroadcast2(TestMatMulV2Op):
"""
case 14_4
"""
......@@ -343,22 +359,90 @@ def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2)
create_test_fp16_class(TestMatMuklOp3)
create_test_fp16_class(TestMatMuklOp4)
create_test_fp16_class(TestMatMuklOp5)
create_test_fp16_class(TestMatMuklOp6)
create_test_fp16_class(TestMatMuklOp7)
create_test_fp16_class(TestMatMuklOp8)
create_test_fp16_class(TestMatMuklOp9)
create_test_fp16_class(TestMatMuklOp10)
create_test_fp16_class(TestMatMuklOp11)
create_test_fp16_class(TestMatMuklOp12)
create_test_fp16_class(TestMatMuklOp13)
create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class(TestMatMuklOp15)
create_test_fp16_class(TestMatMuklOp16)
create_test_fp16_class(TestMatMuklOp17)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp7)
create_test_fp16_class(TestMatMulOp8)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
create_test_fp16_class(TestMatMulOp14)
create_test_fp16_class(TestMatMulOp15)
create_test_fp16_class(TestMatMulOp16)
create_test_fp16_class(TestMatMulOp17)
#--------------------test matmul bf16--------------------
def create_test_bf16_class(parent, atol=0.01):
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0"
)
class TestMatMulOpBf16Case(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(scope, self.op_type, self.inputs, self.outputs,
self.attrs)
return get_numeric_gradient(place, scope, op, self.inputs_fp32,
check_name, ['Out'])
def init_kernel_type(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad_x(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'X')
self.check_grad_with_place(
place, ['X'],
'Out',
no_grad_set=set(['Y']),
user_defined_grads=[numeric_grads])
def test_check_grad_y(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Y')
self.check_grad_with_place(
place, ['Y'],
'Out',
no_grad_set=set(['X']),
user_defined_grads=[numeric_grads])
def test_check_grad(self):
pass
cls_name = "{0}_{1}".format(parent.__name__, "Bf16")
TestMatMulOpBf16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpBf16Case
create_test_bf16_class(TestMatMulV2Op)
create_test_bf16_class(TestMatMulOp2)
create_test_bf16_class(TestMatMulOp3)
create_test_bf16_class(TestMatMulOp4)
create_test_bf16_class(TestMatMulOp5)
create_test_bf16_class(TestMatMulOp6)
create_test_bf16_class(TestMatMulOp7)
create_test_bf16_class(TestMatMulOp8)
create_test_bf16_class(TestMatMulOp9)
create_test_bf16_class(TestMatMulOp10)
create_test_bf16_class(TestMatMulOp11)
create_test_bf16_class(TestMatMulOp12)
create_test_bf16_class(TestMatMulOp13)
create_test_bf16_class(TestMatMulOp14)
create_test_bf16_class(TestMatMulOp15)
create_test_bf16_class(TestMatMulOp16)
create_test_bf16_class(TestMatMulOp17)
class TestMatMulV2API(unittest.TestCase):
......
......@@ -97,7 +97,7 @@ class TestMatMulV2Op(XPUOpTest):
self.check_grad_with_place(place, ['X', 'Y'], 'Out')
class TestMatMuklOp2(TestMatMulV2Op):
class TestMatMulOp2(TestMatMulV2Op):
"""
case 2
"""
......@@ -109,7 +109,7 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp3(TestMatMulV2Op):
class TestMatMulOp3(TestMatMulV2Op):
"""
case 3
"""
......@@ -121,7 +121,7 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp4(TestMatMulV2Op):
class TestMatMulOp4(TestMatMulV2Op):
"""
case 4
"""
......@@ -133,7 +133,7 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp5(TestMatMulV2Op):
class TestMatMulOp5(TestMatMulV2Op):
"""
case 5
"""
......@@ -145,7 +145,7 @@ class TestMatMuklOp5(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp6(TestMatMulV2Op):
class TestMatMulOp6(TestMatMulV2Op):
"""
case 6
"""
......@@ -157,7 +157,7 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp7(TestMatMulV2Op):
class TestMatMulOp7(TestMatMulV2Op):
"""
case 7
"""
......@@ -169,7 +169,7 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp8(TestMatMulV2Op):
class TestMatMulOp8(TestMatMulV2Op):
"""
case 8
"""
......@@ -181,7 +181,7 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp9(TestMatMulV2Op):
class TestMatMulOp9(TestMatMulV2Op):
"""
case 9
"""
......@@ -193,7 +193,7 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp10(TestMatMulV2Op):
class TestMatMulOp10(TestMatMulV2Op):
"""
case 10
"""
......@@ -205,7 +205,7 @@ class TestMatMuklOp10(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp11(TestMatMulV2Op):
class TestMatMulOp11(TestMatMulV2Op):
"""
case 11
"""
......@@ -217,7 +217,7 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp12(TestMatMulV2Op):
class TestMatMulOp12(TestMatMulV2Op):
"""
case 12
"""
......@@ -229,7 +229,7 @@ class TestMatMuklOp12(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp13(TestMatMulV2Op):
class TestMatMulOp13(TestMatMulV2Op):
"""
case 13
"""
......@@ -241,7 +241,7 @@ class TestMatMuklOp13(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp14(TestMatMulV2Op):
class TestMatMulOp14(TestMatMulV2Op):
"""
case 14_1
"""
......@@ -253,7 +253,7 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op):
class TestMatMulOp15(TestMatMulV2Op):
"""
case 14_2
"""
......@@ -265,7 +265,7 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.trans_y = True
class TestMatMuklOp16(TestMatMulV2Op):
class TestMatMulOp16(TestMatMulV2Op):
"""
case 16 : to check the big data
"""
......@@ -277,7 +277,7 @@ class TestMatMuklOp16(TestMatMulV2Op):
self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op):
class TestMatMulOp17(TestMatMulV2Op):
"""
case 17 : to check the gradient for special case
"""
......@@ -289,7 +289,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.trans_y = False
# class TestMatMuklOpBroadcast1(TestMatMulV2Op):
# class TestMatMulOpBroadcast1(TestMatMulV2Op):
# """
# case 14_3
# """
......@@ -300,7 +300,7 @@ class TestMatMuklOp17(TestMatMulV2Op):
# self.trans_x = True
# self.trans_y = True
# class TestMatMuklOpBroadcast2(TestMatMulV2Op):
# class TestMatMulOpBroadcast2(TestMatMulV2Op):
# """
# case 14_4
# """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册