Unverified commit f36a9464, authored by Leo Chen, committed by GitHub

use fp32 compute type for cublasGemmStridedBatchedEx with fp16 input/output (#42851)

* use fp32 compute type for cublasGemmStridedBatchedEx with fp16 input/output

* add flags to control compute type

* default to false

* add unit test

* default to true
Parent 4a48e3d1
@@ -88,6 +88,21 @@ PADDLE_DEFINE_EXPORTED_bool(
     "input and output must be half precision) and recurrent neural networks "
     "(RNNs).");
+/**
+ * CUDA related FLAG
+ * Name: FLAGS_gemm_use_half_precision_compute_type
+ * Since Version: 2.4
+ * Value Range: bool, default=true
+ * Example:
+ * Note: whether to use the fp16 compute type when the input and output are
+ * fp16; faster, but it may lose precision.
+ */
+PADDLE_DEFINE_EXPORTED_bool(
+    gemm_use_half_precision_compute_type, true,
+    "Whether to use the fp16 compute type when the input and output are fp16; "
+    "faster, but it may lose precision in some cases. If false, the compute "
+    "type will be fp32. Default is true.");
+
 /**
  * CUDA related FLAG
  * Name: FLAGS_selected_gpus
...
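The new flag can be flipped at runtime from Python. A minimal sketch of opting back into fp32 accumulation for fp16 matmuls (mirroring the paddle.set_flags usage in the unit test below; a CUDA build and an fp16-capable GPU are assumed, and the random inputs are only an illustrative workload):

    import numpy as np
    import paddle

    # Ask cuBLAS to accumulate fp16 GEMMs with an fp32 compute type
    # (more accurate, possibly slower than the fp16 default).
    paddle.set_flags({'FLAGS_gemm_use_half_precision_compute_type': False})

    x = paddle.to_tensor(np.random.random([2, 8, 16]).astype("float16"))
    y = paddle.to_tensor(np.random.random([2, 16, 8]).astype("float16"))
    out = paddle.matmul(x, y)  # batched fp16 GEMM, accumulated in fp32

Setting the flag back to True restores the fp16 compute type, which is the default.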
@@ -21,6 +21,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"

 DECLARE_bool(enable_cublas_tensor_op_math);
+DECLARE_bool(gemm_use_half_precision_compute_type);

 namespace phi {
 namespace funcs {
@@ -2255,8 +2256,25 @@ void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
     }
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");
+    VLOG(4) << "use_half_precision_compute_type: "
+            << FLAGS_gemm_use_half_precision_compute_type;

     auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
+    cudaDataType_t compute_type = fp;
+
+    float h_alpha = static_cast<float>(alpha);
+    float h_beta = static_cast<float>(beta);
+    void *a = static_cast<void *>(&h_alpha);
+    void *b = static_cast<void *>(&h_beta);
+    // set ComputeType as CUDA_R_32F for fp16, for better accuracy
+    if (FLAGS_gemm_use_half_precision_compute_type == true &&
+        std::is_same<T, phi::dtype::float16>::value) {
+      a = static_cast<void *>(&alpha);
+      b = static_cast<void *>(&beta);
+      compute_type = CUDA_R_16F;
+    }
+
+    // set ComputeType as CUDA_R_32F for fp16 and fp32, for better accuracy
     context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           paddle::platform::dynload::cublasGemmStridedBatchedEx(handle,
@@ -2265,7 +2283,7 @@ void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
               N,
               M,
               K,
-              &alpha,
+              a,
               B,
               fp,
               ldb,
@@ -2274,13 +2292,13 @@ void Blas<paddle::platform::CUDADeviceContext>::BatchedGEMM(
               fp,
               lda,
               strideA,
-              &beta,
+              b,
               C,
               fp,
               ldc,
               strideC,
               batchCount,
-              fp,
+              compute_type,
               algo));
     });
   } else {
@@ -2348,8 +2366,24 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
     }
     VLOG(5) << "use_tensor_op_math: "
             << (use_tensor_op_math ? "True" : "False");
+    VLOG(4) << "use_half_precision_compute_type: "
+            << FLAGS_gemm_use_half_precision_compute_type;

     auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
+    cudaDataType_t compute_type = CUDA_R_32F;
+
+    float h_alpha = static_cast<float>(alpha);
+    float h_beta = static_cast<float>(beta);
+    void *a = static_cast<void *>(&h_alpha);
+    void *b = static_cast<void *>(&h_beta);
+    // set ComputeType as CUDA_R_32F for fp16, for better accuracy
+    if (FLAGS_gemm_use_half_precision_compute_type == true &&
+        std::is_same<T, phi::dtype::float16>::value) {
+      a = static_cast<void *>(&alpha);
+      b = static_cast<void *>(&beta);
+      compute_type = CUDA_R_16F;
+    }
+
     context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
       PADDLE_ENFORCE_GPU_SUCCESS(
           paddle::platform::dynload::cublasGemmStridedBatchedEx(handle,
@@ -2358,7 +2392,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
               N,
               M,
               K,
-              &alpha,
+              a,
               B,
               fp,
               ldb,
@@ -2367,13 +2401,13 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
               fp,
               lda,
               strideA,
-              &beta,
+              b,
               C,
               fp,
               ldc,
               strideC,
               batchCount,
-              fp,
+              compute_type,
               algo));
     });
   } else {
...
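For context on why the compute type matters: the unit tests below build fp16 inputs whose partial products exceed the fp16 range (about 65504) but cancel to a small value, so the matmul result stays finite only if accumulation happens in fp32. A minimal NumPy-only sketch of that effect (not Paddle code; plain Python loops stand in for the GEMM accumulator):

    import numpy as np

    # One dot product shaped like the test data: +/-60000 elements times 1.5.
    x = np.array([60000.0, -60000.0], dtype=np.float16)
    y = np.array([1.5, 1.5], dtype=np.float16)

    # fp16 accumulation: each partial product (90000) overflows fp16 (max ~65504),
    # so the products become +/-inf and the running sum degrades to nan.
    acc16 = np.float16(0)
    for xi, yi in zip(x, y):
        acc16 = np.float16(acc16 + xi * yi)

    # fp32 accumulation of the same fp16 inputs: 90000 - 90000 = 0, finite and exact.
    acc32 = np.float32(0)
    for xi, yi in zip(x, y):
        acc32 += np.float32(xi) * np.float32(yi)

    print(acc16, acc32)  # nan 0.0 (NumPy may warn about the fp16 overflow)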
@@ -495,6 +495,58 @@ class TestMatMulV2API(unittest.TestCase):
                     x = paddle.to_tensor(input_x)
                     y = paddle.to_tensor(input_y)
                     result = paddle.matmul(x, y)
+
+    def test_compute_type_fp32(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                with fluid.dygraph.guard(place):
+                    paddle.set_flags({
+                        'FLAGS_gemm_use_half_precision_compute_type': False
+                    })
+                    input_x = np.random.random([2, 8, 16]).astype("float16")
+                    input_y = np.random.random([2, 16, 8]).astype("float16")
+                    for i in range(0, 16, 2):
+                        input_x[:, :, i] += 60000
+                        input_x[:, :, i + 1] -= 60000
+                    input_y[:, :, :] = 1.5
+                    x = paddle.to_tensor(input_x)
+                    y = paddle.to_tensor(input_y)
+                    result = paddle.matmul(x, y)
+                    result_np = np.matmul(input_x, input_y)
+                    self.assertTrue(paddle.isfinite(result)[0, 0, 0])
+                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])
+                    self.assertTrue(np.array_equal(result_np, result.numpy()))
+                    paddle.set_flags({
+                        'FLAGS_gemm_use_half_precision_compute_type': True
+                    })
+
+    def test_compute_type_fp16_nan(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                with fluid.dygraph.guard(place):
+                    paddle.set_flags({
+                        'FLAGS_gemm_use_half_precision_compute_type': True
+                    })
+                    input_x = np.random.random([2, 8, 16]).astype("float16")
+                    input_y = np.random.random([2, 16, 8]).astype("float16")
+                    for i in range(0, 16, 2):
+                        input_x[:, :, i] += 60000
+                        input_x[:, :, i + 1] -= 60000
+                    input_y[:, :, :] = 1.5
+                    x = paddle.to_tensor(input_x)
+                    y = paddle.to_tensor(input_y)
+                    result = paddle.matmul(x, y)
+                    result_np = np.matmul(input_x, input_y)
+                    self.assertFalse(
+                        paddle.isfinite(result)[0, 0, 0])  # contains nan/inf
+                    self.assertTrue(np.isfinite(result_np)[0, 0, 0])
+                    paddle.set_flags({
+                        'FLAGS_gemm_use_half_precision_compute_type': False
+                    })
+
     def test_api_eager_dygraph(self):
         with _test_eager_guard():
             self.test_dygraph()
...