Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
85fa9883
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
85fa9883
编写于
12月 08, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn): add get_algorithm_from_desc interface
GitOrigin-RevId: 6d211ca1676d43b8b4eeed751d30468097a08b5c
上级
43b4d4a4
变更
38
隐藏空白更改
内联
并排
Showing
38 changed file
with
373 addition
and
194 deletion
+373
-194
dnn/include/megdnn/oprs/base.h
dnn/include/megdnn/oprs/base.h
+3
-0
dnn/src/common/algo_base.h
dnn/src/common/algo_base.h
+6
-5
dnn/src/common/algo_chooser.h
dnn/src/common/algo_chooser.h
+2
-2
dnn/src/cuda/batch_conv_bias/opr_impl.h
dnn/src/cuda/batch_conv_bias/opr_impl.h
+1
-1
dnn/src/cuda/batched_matrix_mul/opr_impl.h
dnn/src/cuda/batched_matrix_mul/opr_impl.h
+1
-1
dnn/src/cuda/conv_bias/opr_impl.h
dnn/src/cuda/conv_bias/opr_impl.h
+1
-1
dnn/src/cuda/convolution/opr_impl.cpp
dnn/src/cuda/convolution/opr_impl.cpp
+22
-0
dnn/src/cuda/convolution/opr_impl.h
dnn/src/cuda/convolution/opr_impl.h
+4
-2
dnn/src/cuda/convolution3d/opr_impl.h
dnn/src/cuda/convolution3d/opr_impl.h
+3
-3
dnn/src/cuda/deformable_conv/opr_impl.h
dnn/src/cuda/deformable_conv/opr_impl.h
+3
-3
dnn/src/cuda/local_share/opr_impl.h
dnn/src/cuda/local_share/opr_impl.h
+3
-3
dnn/src/cuda/matrix_mul/opr_impl.h
dnn/src/cuda/matrix_mul/opr_impl.h
+1
-1
dnn/src/fallback/batched_matrix_mul/opr_impl.h
dnn/src/fallback/batched_matrix_mul/opr_impl.h
+1
-2
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+3
-3
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+1
-1
dnn/src/fallback/convolution/opr_impl.cpp
dnn/src/fallback/convolution/opr_impl.cpp
+6
-6
dnn/src/fallback/convolution/opr_impl.h
dnn/src/fallback/convolution/opr_impl.h
+2
-2
dnn/src/fallback/matrix_mul/opr_impl.cpp
dnn/src/fallback/matrix_mul/opr_impl.cpp
+3
-2
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+2
-1
dnn/src/naive/batch_conv_bias/opr_impl.cpp
dnn/src/naive/batch_conv_bias/opr_impl.cpp
+8
-0
dnn/src/naive/batch_conv_bias/opr_impl.h
dnn/src/naive/batch_conv_bias/opr_impl.h
+2
-0
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
+9
-0
dnn/src/naive/batched_matrix_mul/opr_impl.h
dnn/src/naive/batched_matrix_mul/opr_impl.h
+2
-0
dnn/src/naive/conv_bias/opr_impl.cpp
dnn/src/naive/conv_bias/opr_impl.cpp
+9
-0
dnn/src/naive/conv_bias/opr_impl.h
dnn/src/naive/conv_bias/opr_impl.h
+2
-0
dnn/src/naive/convolution/convolution.cpp
dnn/src/naive/convolution/convolution.cpp
+26
-0
dnn/src/naive/convolution/opr_impl.h
dnn/src/naive/convolution/opr_impl.h
+6
-0
dnn/src/naive/convolution3d/convolution3d.cpp
dnn/src/naive/convolution3d/convolution3d.cpp
+114
-84
dnn/src/naive/convolution3d/opr_impl.h
dnn/src/naive/convolution3d/opr_impl.h
+63
-65
dnn/src/naive/deformable_conv/opr_impl.h
dnn/src/naive/deformable_conv/opr_impl.h
+12
-0
dnn/src/naive/local_share/opr_impl.cpp
dnn/src/naive/local_share/opr_impl.cpp
+27
-0
dnn/src/naive/local_share/opr_impl.h
dnn/src/naive/local_share/opr_impl.h
+5
-1
dnn/src/naive/matrix_mul/opr_impl.cpp
dnn/src/naive/matrix_mul/opr_impl.cpp
+8
-0
dnn/src/naive/matrix_mul/opr_impl.h
dnn/src/naive/matrix_mul/opr_impl.h
+2
-0
dnn/src/rocm/batched_matrix_mul/opr_impl.h
dnn/src/rocm/batched_matrix_mul/opr_impl.h
+1
-1
dnn/src/rocm/convolution/opr_impl.h
dnn/src/rocm/convolution/opr_impl.h
+3
-3
dnn/src/rocm/matrix_mul/opr_impl.h
dnn/src/rocm/matrix_mul/opr_impl.h
+2
-1
src/opr/test/dnn/convolution.cpp
src/opr/test/dnn/convolution.cpp
+4
-0
未找到文件。
dnn/include/megdnn/oprs/base.h
浏览文件 @
85fa9883
...
...
@@ -188,6 +188,7 @@ public:
using
AlgorithmInfo
=
detail
::
Algorithm
::
Info
;
using
AlgorithmDesc
=
detail
::
Algorithm
::
Info
::
Desc
;
using
Algorithm
=
detail
::
Algorithm
;
/*!
* \brief get a string representation for current algorithm set;
*
...
...
@@ -209,6 +210,8 @@ public:
return
m_execution_policy
;
}
virtual
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
=
0
;
protected:
~
MultiAlgoOpr
()
=
default
;
...
...
dnn/src/common/algo_base.h
浏览文件 @
85fa9883
...
...
@@ -38,11 +38,12 @@ namespace megdnn {
return algo_pack().all_algos_map().at(desc); \
}
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::Algorithm* _opr::get_algorithm_from_desc( \
const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}
/**
...
...
dnn/src/common/algo_chooser.h
浏览文件 @
85fa9883
...
...
@@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
std
::
forward
<
Args
>
(
args
)...,
std
::
numeric_limits
<
size_t
>::
max
(),
false
);
}
return
opr
->
get_algo_from_desc
(
ret
.
desc
);
return
static_cast
<
typename
Opr
::
AlgoBase
*>
(
opr
->
get_algorithm_from_desc
(
ret
.
desc
));
}
/*!
...
...
@@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
*/
template
<
class
Opr
,
typename
...
Args
>
typename
Opr
::
AlgoBase
*
get_algorithm_or_construct
(
Opr
*
opr
,
Args
&&
...
args
)
{
typename
Opr
::
AlgorithmInfo
ret
;
auto
set
=
opr
->
execution_policy
().
algo
;
if
(
set
.
valid
())
{
return
opr
->
algo_pack
().
construct_and_get_algo
(
set
.
desc
);
...
...
dnn/src/cuda/batch_conv_bias/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -35,7 +35,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/cuda/batched_matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -39,7 +39,7 @@ public:
bool
is_thread_safe
()
const
override
{
return
true
;
}
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
A
,
...
...
dnn/src/cuda/conv_bias/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -69,7 +69,7 @@ public:
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
dnn/src/cuda/convolution/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
workspace_limit_in_bytes
,
reproducible
);
}
ConvolutionForwardImpl
::
Algorithm
*
ConvolutionForwardImpl
::
get_algorithm_from_desc
(
const
ConvolutionForward
::
AlgorithmDesc
&
desc
)
{
auto
conv_param
=
param
();
auto
convbias_opr
=
this
->
handle
()
->
create_operator
<
ConvBiasForward
>
();
convbias_opr
->
param
()
=
{
param
::
ConvBias
::
NonlineMode
::
IDENTITY
,
conv_param
.
mode
,
conv_param
.
sparse
,
conv_param
.
format
,
conv_param
.
pad_h
,
conv_param
.
pad_w
,
conv_param
.
stride_h
,
conv_param
.
stride_w
,
conv_param
.
dilate_h
,
conv_param
.
dilate_w
,
conv_param
.
compute_mode
};
convbias_opr
->
execution_policy
()
=
{
this
->
execution_policy
().
algo
};
return
static_cast
<
ConvBiasForwardImpl
*>
(
convbias_opr
.
get
())
->
get_algorithm_from_desc
(
desc
);
}
std
::
vector
<
ConvolutionForwardImpl
::
Algorithm
*>
ConvolutionForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
dnn/src/cuda/convolution/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
megdnn_throw
(
"cuda exec_preprocess has not implemeted yet"
);
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
struct
ConvBiasExtraData
{
std
::
unique_ptr
<
ConvBiasForward
>
convbias_opr
;
...
...
@@ -98,7 +100,7 @@ public:
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -152,7 +154,7 @@ public:
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/cuda/convolution3d/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -42,7 +42,7 @@ public:
class
AlgoGroupConvGeneral
;
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -92,7 +92,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -143,7 +143,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/cuda/deformable_conv/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -46,7 +46,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -97,7 +97,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -151,7 +151,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/cuda/local_share/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -33,7 +33,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -65,7 +65,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -98,7 +98,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/cuda/matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -46,7 +46,7 @@ public:
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
A
,
...
...
dnn/src/fallback/batched_matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -29,8 +29,7 @@ public:
class
AlgoDefault
;
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
);
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
return
algos
;
}
ConvBiasImpl
::
Algorithm
*
ConvBiasImpl
::
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
const
{
ConvBiasImpl
::
Algorithm
*
ConvBiasImpl
::
get_algo
rithm
_from_desc
(
const
AlgorithmDesc
&
desc
)
{
if
(
!
desc
.
valid
())
{
return
nullptr
;
}
else
{
...
...
@@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(
ConvBiasImpl
::
Algorithm
*
ConvBiasImpl
::
get_algorithm
(
const
NCBKernSizeParam
&
param
,
size_t
workspace_size
)
{
if
(
auto
algo
=
get_algo_from_desc
(
execution_policy
().
algo
.
desc
))
{
if
(
auto
algo
=
get_algo
rithm
_from_desc
(
execution_policy
().
algo
.
desc
))
{
return
algo
;
}
if
(
!
m_prev_selected_algo
||
...
...
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -381,7 +381,7 @@ private:
bool
is_naive_algo
(
ConvBiasImpl
::
Algorithm
*
algo
);
Algorithm
*
get_algo
_from_desc
(
const
AlgorithmDesc
&
desc
)
const
;
Algorithm
*
get_algo
rithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
//! get algorithm set by user or by heuristic
Algorithm
*
get_algorithm
(
...
...
dnn/src/fallback/convolution/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
return
ret
;
}
ConvolutionImpl
::
Algorithm
*
ConvolutionImpl
::
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
const
{
ConvolutionImpl
::
Algorithm
*
ConvolutionImpl
::
get_algo
rithm
_from_desc
(
const
AlgorithmDesc
&
desc
)
{
if
(
!
desc
.
valid
())
{
return
nullptr
;
}
else
{
...
...
@@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc(
ConvolutionImpl
::
Algorithm
*
ConvolutionImpl
::
get_algorithm
(
const
NCBKernSizeParam
&
param
,
size_t
workspace_size
)
{
if
(
auto
algo
=
get_algo_from_desc
(
execution_policy
().
algo
.
desc
))
{
if
(
auto
algo
=
get_algo
rithm
_from_desc
(
execution_policy
().
algo
.
desc
))
{
return
algo
;
}
if
(
!
m_prev_selected_algo
||
...
...
@@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
}
ConvolutionBackwardDataImpl
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
const
{
ConvolutionBackwardDataImpl
::
get_algo
rithm
_from_desc
(
const
AlgorithmDesc
&
desc
)
{
if
(
!
desc
.
valid
())
{
return
nullptr
;
}
else
{
...
...
@@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc(
ConvolutionBackwardDataImpl
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm
(
const
NCBKernSizeParam
&
param
)
{
if
(
auto
algo
=
get_algo_from_desc
(
execution_policy
().
algo
.
desc
))
{
if
(
auto
algo
=
get_algo
rithm
_from_desc
(
execution_policy
().
algo
.
desc
))
{
return
algo
;
}
if
(
!
m_prev_selected_algo
||
...
...
dnn/src/fallback/convolution/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -284,7 +284,7 @@ private:
NCBKernSizeParam
m_prev_selected_algo_sizep
;
Algorithm
*
m_prev_selected_algo
=
nullptr
;
Algorithm
*
get_algo
_from_desc
(
const
AlgorithmDesc
&
desc
)
const
;
Algorithm
*
get_algo
rithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
bool
is_naive_algo
(
ConvolutionImpl
::
Algorithm
*
algo
);
Algorithm
*
get_algorithm
(
const
NCBKernSizeParam
&
param
,
...
...
@@ -493,7 +493,7 @@ private:
class
AlgoDirect
;
class
AlgoMatrixMul
;
class
AlgoPack
;
Algorithm
*
get_algo
_from_desc
(
const
AlgorithmDesc
&
desc
)
const
;
Algorithm
*
get_algo
rithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
public:
//! maintain all the algos of in the opr of fallback
...
...
dnn/src/fallback/matrix_mul/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
return
gemv_algos
;
}
MatrixMulImpl
::
Algo
Base
*
MatrixMulImpl
::
get_algo
_from_desc
(
MatrixMulImpl
::
Algo
rithm
*
MatrixMulImpl
::
get_algorithm
_from_desc
(
const
AlgorithmDesc
&
desc
)
{
if
(
!
desc
.
valid
())
{
return
nullptr
;
...
...
@@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
{
auto
kern_size_param
=
make_kern_size_param
(
A
,
B
,
C
);
if
(
auto
algo
=
get_algo_from_desc
(
execution_policy
().
algo
.
desc
))
{
if
(
auto
algo
=
static_cast
<
AlgoBase
*>
(
get_algorithm_from_desc
(
execution_policy
().
algo
.
desc
)))
{
megdnn_assert
(
algo
->
get_workspace
(
kern_size_param
)
<
workspace_limit_in_bytes
);
auto
cur
=
megdnn
::
get_reproducible_algo
<
MatrixMulImpl
>
(
algo
,
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -238,7 +238,8 @@ private:
class
AlgoPack
;
//! maintain all the algos of in the opr of fallback
static
const
AlgoPack
&
algo_pack
();
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
);
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
public:
/**
...
...
dnn/src/naive/batch_conv_bias/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
return
algo
;
}
BatchConvBiasForward
::
Algorithm
*
BatchConvBiasForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_batch_conv_bias_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
// vim: syntax=cpp.doxygen
dnn/src/naive/batch_conv_bias/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -39,6 +39,8 @@ public:
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
private:
WorkspaceBundle
get_workspace_bundle
(
dt_byte
*
raw_ptr
,
...
...
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
->
default_batched_matmul_fwd_algo
();
}
BatchedMatrixMulForward
::
Algorithm
*
BatchedMatrixMulForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_batched_matmul_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
}
// namespace naive
}
// namespace megdnn
...
...
dnn/src/naive/batched_matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -34,6 +34,8 @@ public:
size_t
/*workspace_limit_in_bytes*/
,
bool
/* reproducible */
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
private:
...
...
dnn/src/naive/conv_bias/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return
algo
;
}
ConvBiasForward
::
Algorithm
*
ConvBiasForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bias_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
const
char
*
ConvBiasForwardImpl
::
get_algorithm_set_name
()
const
{
return
"DEFAULT"
;
}
...
...
dnn/src/naive/conv_bias/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -64,6 +64,8 @@ public:
_megdnn_workspace
)
override
{}
const
char
*
get_algorithm_set_name
()
const
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
};
void
handle_z_inp_and_activation_naive
(
...
...
dnn/src/naive/convolution/convolution.cpp
浏览文件 @
85fa9883
...
...
@@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
return
algo
;
}
ConvolutionForward
::
Algorithm
*
ConvolutionForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
ConvolutionBackwardData
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
...
...
@@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
return
algo
;
}
ConvolutionBackwardData
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_data_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
ConvolutionBackwardFilter
::
Algorithm
*>
ConvolutionBackwardFilterImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
...
...
@@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
return
algo
;
}
ConvolutionBackwardFilter
::
Algorithm
*
ConvolutionBackwardFilterImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_filter_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
const
char
*
ConvolutionForwardImpl
::
get_algorithm_set_name
()
const
{
return
"DEFAULT"
;
}
...
...
dnn/src/naive/convolution/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
return
{};
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
};
...
...
@@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const
TensorLayout
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
};
class
ConvolutionBackwardFilterImpl
:
public
ConvolutionBackwardFilter
{
...
...
@@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
const
TensorLayout
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
};
}
// namespace naive
...
...
dnn/src/naive/convolution3d/convolution3d.cpp
浏览文件 @
85fa9883
...
...
@@ -6,15 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "./helper.h"
#include "./opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cstring>
...
...
@@ -25,93 +25,95 @@ using namespace megdnn;
using
namespace
naive
;
void
Convolution3DForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_in
filter
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
MIDOUT_BEGIN
(
megdnn_naive_conv3d_fwd
)
{
auto
filter_meta
=
check_exec
(
src
.
layout
,
filter
.
layout
,
dst
.
layout
,
workspace
.
size
);
switch
(
param
().
data_type
)
{
case
Param
::
DataType
::
FLOAT
:
#define cb(dt) do {
\
if (src.layout.dtype == dt()) {
\
using ctype = DTypeTrait<dt>::ctype;
\
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()),
\
convolution3d::forward<
\
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>(
\
src, filter, dst, filter_meta);
\
);
\
return;
\
}
\
} while
(0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
auto
filter_meta
=
check_exec
(
src
.
layout
,
filter
.
layout
,
dst
.
layout
,
workspace
.
size
);
switch
(
param
().
data_type
)
{
case
Param
::
DataType
::
FLOAT
:
#define cb(dt) \
do {
\
if (src.layout.dtype == dt()) {
\
using ctype = DTypeTrait<dt>::ctype;
\
MEGDNN_DISPATCH_CPU_KERN(
\
static_cast<HandleImpl*>(handle()),
\
convolution3d::forward<
\
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>(
\
src, filter, dst, filter_meta););
\
return;
\
}
\
} while
(0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
break
;
case
Param
::
DataType
::
FLOAT_IO16xC32
:
MEGDNN_INC_FLOAT16
(
MEGDNN_DISPATCH_CPU_KERN
(
static_cast
<
HandleImpl
*>
(
handle
()),
convolution3d
::
forward
<
dt_float16
MEGDNN_COMMA
dt_float16
MEGDNN_COMMA
dt_float32
>
(
src
,
filter
,
dst
,
filter_meta
);));
return
;
break
;
case
Param
::
DataType
::
FLOAT_IO16xC32
:
MEGDNN_INC_FLOAT16
(
MEGDNN_DISPATCH_CPU_KERN
(
static_cast
<
HandleImpl
*>
(
handle
()),
convolution3d
::
forward
<
dt_float16
MEGDNN_COMMA
dt_float16
MEGDNN_COMMA
dt_float32
>
(
src
,
filter
,
dst
,
filter_meta
);));
return
;
}
megdnn_assert_internal
(
0
);
}
megdnn_assert_internal
(
0
);
}
MIDOUT_END
();
MIDOUT_END
();
}
void
Convolution3DBackwardDataImpl
::
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
auto
filter_meta
=
check_exec
(
filter
.
layout
,
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
#define cb(dt) do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
auto
filter_meta
=
check_exec
(
filter
.
layout
,
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
#define cb(dt) \
do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
megdnn_assert_internal
(
0
);
}
void
Convolution3DBackwardFilterImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
auto
filter_meta
=
check_exec
(
src
.
layout
,
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
#define cb(dt) do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
auto
filter_meta
=
check_exec
(
src
.
layout
,
diff
.
layout
,
grad
.
layout
,
workspace
.
size
);
#define cb(dt) \
do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
megdnn_assert_internal
(
0
);
}
std
::
vector
<
Convolution3DForward
::
Algorithm
*>
Convolution3DForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_fwd_algo
()};
std
::
vector
<
Convolution3DForward
::
Algorithm
*>
Convolution3DForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_fwd_algo
()};
}
Convolution3DForward
::
Algorithm
*
...
...
@@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
return
algo
;
}
std
::
vector
<
Convolution3DBackwardData
::
Algorithm
*>
Convolution3DBackwardDataImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_data_algo
()};
Convolution3DForward
::
Algorithm
*
Convolution3DForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
Convolution3DBackwardData
::
Algorithm
*>
Convolution3DBackwardDataImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_data_algo
()};
}
Convolution3DBackwardData
::
Algorithm
*
...
...
@@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
return
algo
;
}
std
::
vector
<
Convolution3DBackwardFilter
::
Algorithm
*>
Convolution3DBackwardFilterImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_filter_algo
()};
Convolution3DBackwardData
::
Algorithm
*
Convolution3DBackwardDataImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_data_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
Convolution3DBackwardFilter
::
Algorithm
*>
Convolution3DBackwardFilterImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_filter_algo
()};
}
Convolution3DBackwardFilter
::
Algorithm
*
...
...
@@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
return
algo
;
}
Convolution3DBackwardFilter
::
Algorithm
*
Convolution3DBackwardFilterImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_filter_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
const
char
*
Convolution3DForwardImpl
::
get_algorithm_set_name
()
const
{
return
"DEFAULT"
;
}
...
...
dnn/src/naive/convolution3d/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -6,81 +6,79 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace
megdnn
{
namespace
naive
{
class
Convolution3DForwardImpl
:
public
Convolution3DForward
{
public:
using
Convolution3DForward
::
Convolution3DForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
const
char
*
get_algorithm_set_name
()
const
override
;
class
Convolution3DForwardImpl
:
public
Convolution3DForward
{
public:
using
Convolution3DForward
::
Convolution3DForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
};
class
Convolution3DBackwardDataImpl
:
public
Convolution3DBackwardData
{
public:
using
Convolution3DBackwardData
::
Convolution3DBackwardData
;
void
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
class
Convolution3DBackwardDataImpl
:
public
Convolution3DBackwardData
{
public:
using
Convolution3DBackwardData
::
Convolution3DBackwardData
;
void
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
const
char
*
get_algorithm_set_name
()
const
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
};
class
Convolution3DBackwardFilterImpl
:
public
Convolution3DBackwardFilter
{
public:
using
Convolution3DBackwardFilter
::
Convolution3DBackwardFilter
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
const
char
*
get_algorithm_set_name
()
const
override
;
class
Convolution3DBackwardFilterImpl
:
public
Convolution3DBackwardFilter
{
public:
using
Convolution3DBackwardFilter
::
Convolution3DBackwardFilter
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
;
};
}
// namespace naive
}
// namespace megdnn
// vim: syntax=cpp.doxygen
}
// namespace naive
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/naive/deformable_conv/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -48,6 +48,10 @@ public:
return
"DEFORMABLE_CONV2_NAIVE"
;
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
mask
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
...
...
@@ -84,6 +88,10 @@ public:
return
"DEFORMABLE_CONV2_BWD_FILTER_NAIVE"
;
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
void
exec
(
_megdnn_tensor_in
im
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
mask
,
_megdnn_tensor_in
out_grad
,
_megdnn_tensor_out
filter_grad
,
...
...
@@ -130,6 +138,10 @@ public:
return
"DEFORMABLE_CONV2_BWD_DATA_NAIVE"
;
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
void
exec
(
_megdnn_tensor_in
im
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
mask
,
_megdnn_tensor_in
out_grad
,
_megdnn_tensor_out
im_grad
,
...
...
dnn/src/naive/local_share/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
return
algo
;
}
LocalShareForward
::
Algorithm
*
LocalShareForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
LocalShareBackwardData
::
Algorithm
*>
LocalShareBackwardDataImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
...
...
@@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic(
return
algo
;
}
LocalShareBackwardData
::
Algorithm
*
LocalShareBackwardDataImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_data_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
std
::
vector
<
LocalShareBackwardFilter
::
Algorithm
*>
LocalShareBackwardFilterImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
...
...
@@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic(
return
algo
;
}
LocalShareBackwardFilter
::
Algorithm
*
LocalShareBackwardFilterImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_filter_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
// vim: syntax=cpp.doxygen
dnn/src/naive/local_share/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
...
...
@@ -35,6 +36,7 @@ public:
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
@@ -59,6 +61,7 @@ public:
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
@@ -83,6 +86,7 @@ public:
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
};
...
...
dnn/src/naive/matrix_mul/opr_impl.cpp
浏览文件 @
85fa9883
...
...
@@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
return
static_cast
<
HandleImpl
*>
(
handle
())
->
default_matmul_fwd_algo
();
}
MatrixMulForward
::
Algorithm
*
MatrixMulForwardImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_matmul_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
}
// namespace naive
}
// namespace megdnn
...
...
dnn/src/naive/matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -35,6 +35,8 @@ public:
size_t
/*workspace_limit_in_bytes*/
,
bool
/* reproducible */
)
override
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"DEFAULT"
;
}
private:
...
...
dnn/src/rocm/batched_matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -29,8 +29,8 @@ public:
class
AlgoBlas
;
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
);
private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
...
...
dnn/src/rocm/convolution/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -66,7 +66,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -112,7 +112,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -158,7 +158,7 @@ public:
class
AlgoPack
;
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
override
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
private:
...
...
dnn/src/rocm/matrix_mul/opr_impl.h
浏览文件 @
85fa9883
...
...
@@ -29,7 +29,7 @@ public:
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
return
sm_algo_pack
;
}
static
AlgoBase
*
get_algo_from_desc
(
const
AlgorithmDesc
&
desc
)
;
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
...
...
@@ -41,6 +41,7 @@ private:
const
TensorLayout
&
/*C*/
,
size_t
/*workspace_limit_in_bytes*/
,
bool
/*reproducible*/
)
override
;
const
char
*
get_algorithm_set_name
()
const
override
{
return
"ROCM MATMUL"
;
}
...
...
src/opr/test/dnn/convolution.cpp
浏览文件 @
85fa9883
...
...
@@ -2204,6 +2204,10 @@ public:
const
TensorLayout
&
p2
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
));
MOCK_METHOD1
(
get_algorithm_from_desc
,
Algorithm
*
(
const
AlgorithmDesc
&
));
protected:
const
char
*
get_algorithm_set_name
()
const
override
{
return
m_algorithm_set_name
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录