提交 85fa9883 编写于 作者: M Megvii Engine Team

refactor(dnn): add get_algorithm_from_desc interface

GitOrigin-RevId: 6d211ca1676d43b8b4eeed751d30468097a08b5c
上级 43b4d4a4
......@@ -188,6 +188,7 @@ public:
using AlgorithmInfo = detail::Algorithm::Info;
using AlgorithmDesc = detail::Algorithm::Info::Desc;
using Algorithm = detail::Algorithm;
/*!
* \brief get a string representation for current algorithm set;
*
......@@ -209,6 +210,8 @@ public:
return m_execution_policy;
}
virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0;
protected:
~MultiAlgoOpr() = default;
......
......@@ -38,11 +38,12 @@ namespace megdnn {
return algo_pack().all_algos_map().at(desc); \
}
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::Algorithm* _opr::get_algorithm_from_desc( \
const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}
/**
......
......@@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
false);
}
return opr->get_algo_from_desc(ret.desc);
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret.desc));
}
/*!
......@@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
*/
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
typename Opr::AlgorithmInfo ret;
auto set = opr->execution_policy().algo;
if (set.valid()) {
return opr->algo_pack().construct_and_get_algo(set.desc);
......
......@@ -35,7 +35,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -39,7 +39,7 @@ public:
bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
......
......@@ -69,7 +69,7 @@ public:
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
......
......@@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
workspace_limit_in_bytes, reproducible);
}
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_from_desc(
const ConvolutionForward::AlgorithmDesc& desc) {
auto conv_param = param();
auto convbias_opr = this->handle()->create_operator<ConvBiasForward>();
convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY,
conv_param.mode,
conv_param.sparse,
conv_param.format,
conv_param.pad_h,
conv_param.pad_w,
conv_param.stride_h,
conv_param.stride_w,
conv_param.dilate_h,
conv_param.dilate_w,
conv_param.compute_mode};
convbias_opr->execution_policy() = {this->execution_policy().algo};
return static_cast<ConvBiasForwardImpl*>(convbias_opr.get())
->get_algorithm_from_desc(desc);
}
std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& filter,
......
......@@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
megdnn_throw("cuda exec_preprocess has not implemeted yet");
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
struct ConvBiasExtraData{
std::unique_ptr<ConvBiasForward> convbias_opr;
......@@ -98,7 +100,7 @@ public:
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -152,7 +154,7 @@ public:
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -42,7 +42,7 @@ public:
class AlgoGroupConvGeneral;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -92,7 +92,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -143,7 +143,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -46,7 +46,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -97,7 +97,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -151,7 +151,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -33,7 +33,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -65,7 +65,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......@@ -98,7 +98,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -46,7 +46,7 @@ public:
static const AlgoPack& algo_pack() {
return sm_algo_pack;
}
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
......
......@@ -29,8 +29,7 @@ public:
class AlgoDefault;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
private:
std::vector<Algorithm*> get_all_algorithms(
......
......@@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
return algos;
}
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) {
return nullptr;
} else {
......@@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo;
}
if (!m_prev_selected_algo ||
......
......@@ -381,7 +381,7 @@ private:
bool is_naive_algo(ConvBiasImpl::Algorithm* algo);
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
//! get algorithm set by user or by heuristic
Algorithm* get_algorithm(
......
......@@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
return ret;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) {
return nullptr;
} else {
......@@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc(
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo;
}
if (!m_prev_selected_algo ||
......@@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) {
return nullptr;
} else {
......@@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc(
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo;
}
if (!m_prev_selected_algo ||
......
......@@ -284,7 +284,7 @@ private:
NCBKernSizeParam m_prev_selected_algo_sizep;
Algorithm* m_prev_selected_algo = nullptr;
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
bool is_naive_algo(ConvolutionImpl::Algorithm* algo);
Algorithm* get_algorithm(
const NCBKernSizeParam& param,
......@@ -493,7 +493,7 @@ private:
class AlgoDirect;
class AlgoMatrixMul;
class AlgoPack;
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
public:
//! maintain all the algos of in the opr of fallback
......
......@@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
return gemv_algos;
}
MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc(
MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) {
return nullptr;
......@@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) {
auto kern_size_param = make_kern_size_param(A, B, C);
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = static_cast<AlgoBase*>(
get_algorithm_from_desc(execution_policy().algo.desc))) {
megdnn_assert(algo->get_workspace(kern_size_param) <
workspace_limit_in_bytes);
auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo,
......
......@@ -238,7 +238,8 @@ private:
class AlgoPack;
//! maintain all the algos of in the opr of fallback
static const AlgoPack& algo_pack();
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
public:
/**
......
......@@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
return algo;
}
BatchConvBiasForward::Algorithm*
BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_batch_conv_bias_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
// vim: syntax=cpp.doxygen
......@@ -39,6 +39,8 @@ public:
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
......
......@@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
->default_batched_matmul_fwd_algo();
}
BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
} // namespace naive
} // namespace megdnn
......
......@@ -34,6 +34,8 @@ public:
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
private:
......
......@@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return algo;
}
ConvBiasForward::Algorithm*
ConvBiasForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
const char* ConvBiasForwardImpl::get_algorithm_set_name() const {
return "DEFAULT";
}
......
......@@ -64,6 +64,8 @@ public:
_megdnn_workspace) override {}
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
};
void handle_z_inp_and_activation_naive(
......
......@@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
return algo;
}
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<ConvolutionBackwardData::Algorithm *>
ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
......@@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
return algo;
}
ConvolutionBackwardData::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<ConvolutionBackwardFilter::Algorithm *>
ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
......@@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
return algo;
}
ConvolutionBackwardFilter::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
return "DEFAULT";
}
......
......@@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
return {};
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
};
......@@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout&) override;
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
};
class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
......@@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
const TensorLayout&) override;
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
};
} // namespace naive
......
......@@ -6,15 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "./helper.h"
#include "./opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cstring>
......@@ -25,93 +25,95 @@ using namespace megdnn;
using namespace naive;
void Convolution3DForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace)
{
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) {
auto filter_meta = check_exec(
src.layout, filter.layout, dst.layout, workspace.size);
switch (param().data_type) {
case Param::DataType::FLOAT:
#define cb(dt) do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::forward< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, filter, dst, filter_meta); \
); \
return; \
} \
} while(0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
auto filter_meta = check_exec(src.layout, filter.layout, dst.layout,
workspace.size);
switch (param().data_type) {
case Param::DataType::FLOAT:
#define cb(dt) \
do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::forward< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, filter, dst, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
break;
case Param::DataType::FLOAT_IO16xC32:
MEGDNN_INC_FLOAT16(
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()),
convolution3d::forward<
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>(
src, filter, dst, filter_meta);));
return;
break;
case Param::DataType::FLOAT_IO16xC32:
MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN(
static_cast<HandleImpl*>(handle()),
convolution3d::forward<
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA
dt_float32>(src, filter, dst,
filter_meta);));
return;
}
megdnn_assert_internal(0);
}
megdnn_assert_internal(0);
} MIDOUT_END();
MIDOUT_END();
}
void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
auto filter_meta = check_exec(
filter.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
auto filter_meta =
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) \
do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
megdnn_assert_internal(0);
}
void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
auto filter_meta = check_exec(
src.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
auto filter_meta =
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) \
do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
megdnn_assert_internal(0);
}
std::vector<Convolution3DForward::Algorithm *>
Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv3d_fwd_algo()};
std::vector<Convolution3DForward::Algorithm*>
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()};
}
Convolution3DForward::Algorithm*
......@@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
return algo;
}
std::vector<Convolution3DBackwardData::Algorithm *>
Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv3d_bwd_data_algo()};
Convolution3DForward::Algorithm*
Convolution3DForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<Convolution3DBackwardData::Algorithm*>
Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()};
}
Convolution3DBackwardData::Algorithm*
......@@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
return algo;
}
std::vector<Convolution3DBackwardFilter::Algorithm *>
Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_filter_algo()};
Convolution3DBackwardData::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<Convolution3DBackwardFilter::Algorithm*>
Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo()};
}
Convolution3DBackwardFilter::Algorithm*
......@@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
return algo;
}
Convolution3DBackwardFilter::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
const char* Convolution3DForwardImpl::get_algorithm_set_name() const {
return "DEFAULT";
}
......
......@@ -6,81 +6,79 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class Convolution3DForwardImpl: public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const char* get_algorithm_set_name() const override;
class Convolution3DForwardImpl : public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
};
class Convolution3DBackwardDataImpl: public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
};
class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const char* get_algorithm_set_name() const override;
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -48,6 +48,10 @@ public:
return "DEFORMABLE_CONV2_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override {
return {};
}
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
......@@ -84,6 +88,10 @@ public:
return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override {
return {};
}
void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
_megdnn_tensor_out filter_grad,
......@@ -130,6 +138,10 @@ public:
return "DEFORMABLE_CONV2_BWD_DATA_NAIVE";
};
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override {
return {};
}
void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
......
......@@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
return algo;
}
LocalShareForward::Algorithm*
LocalShareForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<LocalShareBackwardData::Algorithm*>
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
......@@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic(
return algo;
}
LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
std::vector<LocalShareBackwardFilter::Algorithm*>
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
......@@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic(
return algo;
}
LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
......@@ -35,6 +36,7 @@ public:
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......@@ -59,6 +61,7 @@ public:
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......@@ -83,6 +86,7 @@ public:
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
};
......
......@@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
} // namespace naive
} // namespace megdnn
......
......@@ -35,6 +35,8 @@ public:
size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; }
private:
......
......@@ -29,8 +29,8 @@ public:
class AlgoBlas;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
private:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
......
......@@ -66,7 +66,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
private:
std::vector<Algorithm*> get_all_algorithms(
......@@ -112,7 +112,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
private:
std::vector<Algorithm*> get_all_algorithms(
......@@ -158,7 +158,7 @@ public:
class AlgoPack;
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
private:
......
......@@ -29,7 +29,7 @@ public:
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
private:
std::vector<Algorithm*> get_all_algorithms(
......@@ -41,6 +41,7 @@ private:
const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override;
const char* get_algorithm_set_name() const override {
return "ROCM MATMUL";
}
......
......@@ -2204,6 +2204,10 @@ public:
const TensorLayout& p2,
size_t workspace_limit_in_bytes,
bool reproducible));
MOCK_METHOD1(get_algorithm_from_desc,
Algorithm*(const AlgorithmDesc&));
protected:
const char* get_algorithm_set_name() const override {
return m_algorithm_set_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册