Commit f5833a52 authored by Megvii Engine Team, committed by Xinran Xu

fix(dnn/cuda): fix cublas matmul on sm60

GitOrigin-RevId: 3fc0c30a23f1dfe35d6629595b2cb1a8c2f379c5
Parent 6bd09b38
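The guards changed below replace raw comparisons on current_device_prop().major with is_compute_capability_required(6, 1): the int8 (8x8x32) cuBLAS paths rely on dp4a instructions that are only available from compute capability 6.1, so they must be skipped on sm_60 even though its major version is 6. As a minimal sketch of what such a helper might do, written against the plain CUDA runtime API rather than MegDNN's internal headers (the function name is taken from the diff; the body shown is an assumption, not MegDNN's implementation):

    // Sketch only: returns true when the current CUDA device is at least
    // compute capability <major>.<minor>; e.g. (6, 1) accepts sm_61/sm_70
    // but rejects sm_60, which lacks the dp4a int8 instructions.
    #include <cuda_runtime.h>

    static bool is_compute_capability_required(int major, int minor) {
        int dev = 0;
        cudaDeviceProp prop;
        if (cudaGetDevice(&dev) != cudaSuccess ||
            cudaGetDeviceProperties(&prop, dev) != cudaSuccess)
            return false;  // be conservative if the device cannot be queried
        return prop.major > major ||
               (prop.major == major && prop.minor >= minor);
    }

With a check like this, each 8x8x32 algorithm and test below asks for (6, 1) directly instead of hand-comparing prop.major, which is what previously let sm_60 slip through.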
......@@ -21,7 +21,7 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
const SizeArgs& args) const {
if (args.z_layout->ndim > 0)
return false;
- if (cuda::current_device_prop().major < 6)
+ if (!is_compute_capability_required(6, 1))
return false;
auto dst_layout = *args.dst_layout;
......
......@@ -42,7 +42,7 @@ bool MatrixMulForwardImpl::AlgoCuBlas::is_available(
*/
return args.layout_a.stride[0] % 4 == 0 &&
args.layout_b.stride[0] % 4 == 0 &&
- current_device_prop().major > 5;
+ is_compute_capability_required(6, 1);
}
return false;
}
......
......@@ -24,7 +24,7 @@ namespace test {
TEST_F(CUDA, BENCHMARK_CONVOLUTION_8X8X32)
{
- if (cuda::current_device_prop().major < 6) {
+ if (!cuda::is_compute_capability_required(6, 1)) {
printf("Skip CUDA.BENCHMARK_CONVOLUTION_8X8X32 test as current device"
"doesn't support\n");
return;
......
......@@ -325,7 +325,7 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_SMALL) {
}
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) {
- require_compute_capability(6, 0);
+ require_compute_capability(6, 1);
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ConvBiasForward::algo_name<ConvBias::DirectParam>(
......@@ -472,7 +472,7 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL) {
}
TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) {
- require_compute_capability(6, 0);
+ require_compute_capability(6, 1);
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
......@@ -517,7 +517,7 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_8x8x32) {
}
TEST_F(CUDA, CONV_BIAS_FORWARD_MATMUL_NCHW4) {
- require_compute_capability(6, 0);
+ require_compute_capability(6, 1);
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ConvBiasForward::algo_name<ConvBiasForward::MatmulParam>(
......
......@@ -30,7 +30,7 @@ namespace test {
TEST_F(CUDA, CONVOLUTION_8X8X32)
{
- if (cuda::current_device_prop().major < 6) {
+ if (!cuda::is_compute_capability_required(6, 1)) {
printf("Skip CUDA.CONVOLUTION_8X8X32 test as current device"
"doesn't support\n");
return;
......@@ -112,7 +112,7 @@ TEST_F(CUDA, CONVOLUTION_FORWARD)
}
TEST_F(CUDA, CONV_FORWARD_MATMUL_NCHW4) {
- if (cuda::current_device_prop().major < 6)
+ if (!cuda::is_compute_capability_required(6, 1))
return;
using namespace convolution;
Checker<Convolution> checker(handle_cuda());
......
......@@ -24,7 +24,7 @@ namespace test {
#if 0
TEST_F(CUDA, CONVOLUTION3D_8X8X32) {
- if (cuda::current_device_prop().major < 6) {
+ if (!cuda::is_compute_capability_required(6, 1)) {
printf("Skip CUDA.CONVOLUTION_8X8X32 test as current device"
"doesn't support\n");
return;
......
......@@ -23,7 +23,7 @@ namespace test {
TEST_F(CUDA, GROUP_CONV_FORWARD)
{
- bool is_int_available = (cuda::current_device_prop().major >= 6);
+ bool is_int_available = cuda::is_compute_capability_required(6, 1);
auto run = [&](size_t N, size_t IC, size_t IH, size_t IW,
size_t FH, size_t FW,
size_t OC, size_t /* OH */, size_t /* OW */,
......
......@@ -21,7 +21,7 @@ namespace megdnn {
namespace test {
TEST_F(CUDA, GROUP_CONVOLUTION3D_FORWARD) {
- bool is_int_available = (cuda::current_device_prop().major >= 6);
+ bool is_int_available = cuda::is_compute_capability_required(6, 1);
static_cast<void>(is_int_available);
auto run = [&](size_t N, size_t IC, size_t ID, size_t IH, size_t IW,
size_t FD, size_t FH, size_t FW, size_t OC, size_t PD,
......
......@@ -193,8 +193,15 @@ TEST_F(CUDA, MATRIX_MUL)
Checker<MatrixMul> checker(handle_cuda());
using Param = MatrixMul::Param;
size_t m = 12, n = 16, k = 20;
- for (DType dtype: std::array<DType, 3>{
-         {dtype::Float32(), dtype::Float16(), dtype::Int32()}}) {
+ bool is_int_available = cuda::is_compute_capability_required(6, 1);
+ std::vector<DType> dtype_array;
+ dtype_array.push_back(dtype::Float32());
+ dtype_array.push_back(dtype::Float16());
+ if (is_int_available)
+     dtype_array.push_back(dtype::Int32());
+ for (DType dtype : dtype_array) {
for (unsigned mask = 0; mask < 4; ++mask) {
Param param;
param.transposeA = mask & 1;
......