提交 87ff58f7 编写于 作者: M Megvii Engine Team

fix(megdnn): add algo for matmul/batchedmatrixmul of naive and opencl

GitOrigin-RevId: 2409b6ba164b40158da87e905aa53a1b1722a1c2
上级 169fa53d
......@@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,
}
} // namespace naive
} // namespace megdnn
std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo()};
}
// vim: syntax=cpp.doxygen
BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) {
return static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo();
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -25,17 +26,13 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
bool /* reproducible */) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
......@@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm
DefaultLocalShareBackwardFilterAlgorithm
HandleImpl::m_default_local_share_bwd_filter_algo;
DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo;
DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type)
: HandleImplHelper(computing_handle, type),
......
......@@ -13,6 +13,7 @@
#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/matrix_mul/algorithms.h"
#include "src/naive/local_share/algorithms.h"
#include "src/naive/convolution3d/algorithms.h"
......@@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper {
static DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo;
static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo;
static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template <typename T>
void move_kern_func_to_new_kern_and_dispatch(T& func) {
......@@ -109,6 +113,14 @@ public:
return &m_default_local_share_bwd_filter_algo;
}
MatrixMulForward::Algorithm* default_matmul_fwd_algo() {
return &m_default_matmul_fwd_algo;
}
BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() {
return &m_default_batched_matmul_fwd_algo;
}
Relayout* relayout_opr() override {
return get_helper_opr<Relayout, 2>(this);
}
......
/**
* \file dnn/src/naive/matrix_mul/algorithms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs/linalg.h"
namespace megdnn {
namespace naive {
class DefaultMatrixMulAlgorithm final
: public megdnn::MatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
class DefaultBatchedMatrixMulAlgorithm final
: public megdnn::BatchedMatrixMulForward::Algorithm {
bool is_reproducible() const override { return true; }
const char* name() const override { return "DEFAULT"; }
uint32_t type() const override { return 0; }
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
MIDOUT_END();
}
std::vector<MatrixMulForward::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) {
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
} // namespace naive
} // namespace megdnn
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -26,17 +27,13 @@ public:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override {
return {};
}
const TensorLayout& /*C*/) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override {
return nullptr;
}
bool /* reproducible */) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册