diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp index 88e6bc76981d6a28ddec3d6b09f1b062ba62b790..073a3ecb4c46a02bc5d6b6765bc31cb593a26219 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp @@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, } -} // namespace naive -} // namespace megdnn +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/) { + return {static_cast(handle()) + ->default_batched_matmul_fwd_algo()}; +} -// vim: syntax=cpp.doxygen +BatchedMatrixMulForward::Algorithm* +BatchedMatrixMulForwardImpl::get_algorithm_heuristic( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, + bool /* reproducible */) { + return static_cast(handle()) + ->default_batched_matmul_fwd_algo(); +} +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.h b/dnn/src/naive/batched_matrix_mul/opr_impl.h index 35f52aeb2bb9a38c59b48d9431632c1e141b810e..58dc718ffcac01cce6e1d781759c771baad9a709 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.h +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -25,17 +26,13 @@ public: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, - const TensorLayout& /*C*/) override { - return {}; - } + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, - bool /* reproducible */) override { - return nullptr; - } + bool /* reproducible */) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 5e93508b86caa384580c30ead399a14788d3f38f..a2c210f3895ee446717d752959f47d30c099c757 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm DefaultLocalShareBackwardFilterAlgorithm HandleImpl::m_default_local_share_bwd_filter_algo; +DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo; +DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo; + HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, HandleType type) : HandleImplHelper(computing_handle, type), diff --git a/dnn/src/naive/handle.h b/dnn/src/naive/handle.h index 15d3c63edd01ca3f23c73391124bfa0bee33b19f..27bc57ec6c8b8cf036721bc7e4805cbb6b004da8 100644 --- a/dnn/src/naive/handle.h +++ b/dnn/src/naive/handle.h @@ -13,6 +13,7 @@ #include "src/common/handle_impl.h" #include "src/naive/convolution/algorithms.h" +#include "src/naive/matrix_mul/algorithms.h" #include "src/naive/local_share/algorithms.h" #include "src/naive/convolution3d/algorithms.h" @@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper { static DefaultLocalShareBackwardFilterAlgorithm m_default_local_share_bwd_filter_algo; + static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo; + static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo; + //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch template void move_kern_func_to_new_kern_and_dispatch(T& func) { @@ -109,6 +113,14 @@ public: return &m_default_local_share_bwd_filter_algo; } + MatrixMulForward::Algorithm* default_matmul_fwd_algo() { + return &m_default_matmul_fwd_algo; + } + + BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() { + return &m_default_batched_matmul_fwd_algo; + } + Relayout* relayout_opr() override { return get_helper_opr(this); } diff --git a/dnn/src/naive/matrix_mul/algorithms.h b/dnn/src/naive/matrix_mul/algorithms.h new file mode 100644 index 0000000000000000000000000000000000000000..88fed2352ae8764c7f847caf58271bd584edd50a --- /dev/null +++ b/dnn/src/naive/matrix_mul/algorithms.h @@ -0,0 +1,35 @@ +/** + * \file dnn/src/naive/matrix_mul/algorithms.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/oprs/linalg.h" + +namespace megdnn { +namespace naive { + +class DefaultMatrixMulAlgorithm final + : public megdnn::MatrixMulForward::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } +}; + +class DefaultBatchedMatrixMulAlgorithm final + : public megdnn::BatchedMatrixMulForward::Algorithm { + bool is_reproducible() const override { return true; } + const char* name() const override { return "DEFAULT"; } + uint32_t type() const override { return 0; } +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index 2ba27d1f3c2b3fe313e1443ff44ae55ec8bf48f7..5141ea35a04b151671864303b37cb8367eaf2475 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, MIDOUT_END(); } +std::vector +MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/) { + return {static_cast(handle())->default_matmul_fwd_algo()}; +} + +MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, + bool /* reproducible */) { + return static_cast(handle())->default_matmul_fwd_algo(); +} + } // namespace naive } // namespace megdnn diff --git a/dnn/src/naive/matrix_mul/opr_impl.h b/dnn/src/naive/matrix_mul/opr_impl.h index 61449f01d1d2c97aeb5a330862e9d87ca5a77e17..6ed0c72f220e53d40ab3c11627d000720a3309d2 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.h +++ b/dnn/src/naive/matrix_mul/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -26,17 +27,13 @@ public: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, - const TensorLayout& /*C*/) override { - return {}; - } + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, - bool /* reproducible */) override { - return nullptr; - } + bool /* reproducible */) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; }