提交 d400b419 编写于 作者: K Kexin Zhao

fix math function arch mismatch for older GPU

上级 ccc54188
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include <iostream>
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
const std::vector<float>& data) { const std::vector<float>& data) {
PADDLE_ENFORCE_EQ(size, data.size()); PADDLE_ENFORCE_EQ(size, data.size());
...@@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, ...@@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
} }
} }
bool is_fp16_supported(int device_id) {
cudaDeviceProp device_prop;
cudaDeviceProperties(&device_prop, device_id);
PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess);
int compute_capability = device_prop.major * 10 + device_prop.minor;
std::cout << "compute_capability is " << compute_capability << std::endl;
return compute_capability >= 53;
}
TEST(math_function, notrans_mul_trans_fp32) { TEST(math_function, notrans_mul_trans_fp32) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
...@@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) { ...@@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
if (!is_fp16_supported(0)) {
return;
}
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
Tensor input2_gpu; Tensor input2_gpu;
...@@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) { ...@@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
if (!is_fp16_supported(0)) {
return;
}
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
Tensor input2_gpu; Tensor input2_gpu;
...@@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
if (!is_fp16_supported(0)) {
return;
}
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
Tensor input3; Tensor input3;
...@@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
if (!is_fp16_supported(0)) {
return;
}
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
Tensor input3; Tensor input3;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册