提交 87ff58f7 编写于 作者: M Megvii Engine Team

fix(megdnn): add algo for matmul/batchedmatrixmul of naive and opencl

GitOrigin-RevId: 2409b6ba164b40158da87e905aa53a1b1722a1c2
上级 169fa53d
...@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, ...@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,
} }
std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo()};
}
BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) {
return static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo();
}
} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -25,17 +26,13 @@ public: ...@@ -25,17 +26,13 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override { const TensorLayout& /*C*/) override;
return {};
}
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override { bool /* reproducible */) override;
return nullptr;
}
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
...@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm ...@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm
DefaultLocalShareBackwardFilterAlgorithm DefaultLocalShareBackwardFilterAlgorithm
HandleImpl::m_default_local_share_bwd_filter_algo; HandleImpl::m_default_local_share_bwd_filter_algo;
DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo;
DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type) HandleType type)
: HandleImplHelper(computing_handle, type), : HandleImplHelper(computing_handle, type),
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "src/common/handle_impl.h" #include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h" #include "src/naive/convolution/algorithms.h"
#include "src/naive/matrix_mul/algorithms.h"
#include "src/naive/local_share/algorithms.h" #include "src/naive/local_share/algorithms.h"
#include "src/naive/convolution3d/algorithms.h" #include "src/naive/convolution3d/algorithms.h"
...@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper { ...@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper {
static DefaultLocalShareBackwardFilterAlgorithm static DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo; m_default_local_share_bwd_filter_algo;
static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo;
static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template <typename T> template <typename T>
void move_kern_func_to_new_kern_and_dispatch(T& func) { void move_kern_func_to_new_kern_and_dispatch(T& func) {
...@@ -109,6 +113,14 @@ public: ...@@ -109,6 +113,14 @@ public:
return &m_default_local_share_bwd_filter_algo; return &m_default_local_share_bwd_filter_algo;
} }
MatrixMulForward::Algorithm* default_matmul_fwd_algo() {
return &m_default_matmul_fwd_algo;
}
BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() {
return &m_default_batched_matmul_fwd_algo;
}
Relayout* relayout_opr() override { Relayout* relayout_opr() override {
return get_helper_opr<Relayout, 2>(this); return get_helper_opr<Relayout, 2>(this);
} }
......
/**
* \file dnn/src/naive/matrix_mul/algorithms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs/linalg.h"
namespace megdnn {
namespace naive {
class DefaultMatrixMulAlgorithm final
: public megdnn::MatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
class DefaultBatchedMatrixMulAlgorithm final
: public megdnn::BatchedMatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, ...@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
MIDOUT_END(); MIDOUT_END();
} }
std::vector<MatrixMulForward::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) {
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -26,17 +27,13 @@ public: ...@@ -26,17 +27,13 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override { const TensorLayout& /*C*/) override;
return {};
}
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override { bool /* reproducible */) override;
return nullptr;
}
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册