diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index 57d8941b43c36afcc4f96c06d7b7470328c6996f..c9c01692edebd29fad62deb4de11be8e7e112363 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -29,7 +29,6 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { #if CUDA_VERSION >= 10010 all_algos.push_back(&cublas_lt); #endif - all_algos.push_back(&naive); #if !MEGDNN_DISABLE_FLOAT16 all_algos.push_back(&bfloat16); #endif @@ -45,6 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&algo); } #endif + all_algos.push_back(&naive); for (auto&& algo : all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index b783cf8d77f5ff492bd3348eed256f2ab2661f49..5bbb924542a847b97b6495bfaa3dcaa098ae3511 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -157,7 +157,7 @@ public: void exec(const ExecArgs& args) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; } };