Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d69b5903
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d69b5903
编写于
9月 06, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add an get_all_algorithms_safe interface
GitOrigin-RevId: e3734e4531c2f91205814ec6b490f5d6392f3fa3
上级
103d7f33
变更
67
隐藏空白更改
内联
并排
Showing
67 changed file
with
648 addition
and
35 deletion
+648
-35
dnn/include/megdnn/oprs/base.h
dnn/include/megdnn/oprs/base.h
+75
-4
dnn/src/arm_common/pooling/opr_impl.cpp
dnn/src/arm_common/pooling/opr_impl.cpp
+6
-1
dnn/src/arm_common/pooling/opr_impl.h
dnn/src/arm_common/pooling/opr_impl.h
+2
-0
dnn/src/common/algo_chooser.h
dnn/src/common/algo_chooser.h
+8
-2
dnn/src/cuda/batch_conv_bias/opr_impl.cpp
dnn/src/cuda/batch_conv_bias/opr_impl.cpp
+9
-0
dnn/src/cuda/batch_conv_bias/opr_impl.h
dnn/src/cuda/batch_conv_bias/opr_impl.h
+4
-0
dnn/src/cuda/batched_matrix_mul/opr_impl.cpp
dnn/src/cuda/batched_matrix_mul/opr_impl.cpp
+6
-0
dnn/src/cuda/batched_matrix_mul/opr_impl.h
dnn/src/cuda/batched_matrix_mul/opr_impl.h
+3
-0
dnn/src/cuda/conv_bias/opr_impl.cpp
dnn/src/cuda/conv_bias/opr_impl.cpp
+10
-0
dnn/src/cuda/conv_bias/opr_impl.h
dnn/src/cuda/conv_bias/opr_impl.h
+4
-0
dnn/src/cuda/convolution/opr_impl.cpp
dnn/src/cuda/convolution/opr_impl.cpp
+24
-0
dnn/src/cuda/convolution/opr_impl.h
dnn/src/cuda/convolution/opr_impl.h
+12
-0
dnn/src/cuda/convolution3d/opr_impl.cpp
dnn/src/cuda/convolution3d/opr_impl.cpp
+24
-0
dnn/src/cuda/convolution3d/opr_impl.h
dnn/src/cuda/convolution3d/opr_impl.h
+9
-0
dnn/src/cuda/deformable_conv/opr_impl.cpp
dnn/src/cuda/deformable_conv/opr_impl.cpp
+25
-0
dnn/src/cuda/deformable_conv/opr_impl.h
dnn/src/cuda/deformable_conv/opr_impl.h
+15
-0
dnn/src/cuda/local_share/opr_impl.cpp
dnn/src/cuda/local_share/opr_impl.cpp
+24
-1
dnn/src/cuda/local_share/opr_impl.h
dnn/src/cuda/local_share/opr_impl.h
+9
-0
dnn/src/cuda/matrix_mul/opr_impl.cpp
dnn/src/cuda/matrix_mul/opr_impl.cpp
+8
-0
dnn/src/cuda/matrix_mul/opr_impl.h
dnn/src/cuda/matrix_mul/opr_impl.h
+4
-0
dnn/src/cuda/pooling/opr_impl.cpp
dnn/src/cuda/pooling/opr_impl.cpp
+14
-0
dnn/src/cuda/pooling/opr_impl.h
dnn/src/cuda/pooling/opr_impl.h
+5
-0
dnn/src/fallback/batched_matrix_mul/opr_impl.cpp
dnn/src/fallback/batched_matrix_mul/opr_impl.cpp
+7
-0
dnn/src/fallback/batched_matrix_mul/opr_impl.h
dnn/src/fallback/batched_matrix_mul/opr_impl.h
+3
-0
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+8
-1
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+4
-0
dnn/src/fallback/convolution/opr_impl.cpp
dnn/src/fallback/convolution/opr_impl.cpp
+18
-2
dnn/src/fallback/convolution/opr_impl.h
dnn/src/fallback/convolution/opr_impl.h
+7
-0
dnn/src/fallback/matrix_mul/opr_impl.cpp
dnn/src/fallback/matrix_mul/opr_impl.cpp
+7
-0
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+4
-0
dnn/src/naive/batch_conv_bias/opr_impl.cpp
dnn/src/naive/batch_conv_bias/opr_impl.cpp
+10
-0
dnn/src/naive/batch_conv_bias/opr_impl.h
dnn/src/naive/batch_conv_bias/opr_impl.h
+5
-0
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
+7
-1
dnn/src/naive/batched_matrix_mul/opr_impl.h
dnn/src/naive/batched_matrix_mul/opr_impl.h
+3
-0
dnn/src/naive/conv_bias/opr_impl.cpp
dnn/src/naive/conv_bias/opr_impl.cpp
+9
-0
dnn/src/naive/conv_bias/opr_impl.h
dnn/src/naive/conv_bias/opr_impl.h
+5
-0
dnn/src/naive/convolution/convolution.cpp
dnn/src/naive/convolution/convolution.cpp
+21
-0
dnn/src/naive/convolution/opr_impl.h
dnn/src/naive/convolution/opr_impl.h
+9
-0
dnn/src/naive/convolution3d/convolution3d.cpp
dnn/src/naive/convolution3d/convolution3d.cpp
+21
-1
dnn/src/naive/convolution3d/opr_impl.h
dnn/src/naive/convolution3d/opr_impl.h
+9
-0
dnn/src/naive/deformable_conv/opr_impl.h
dnn/src/naive/deformable_conv/opr_impl.h
+23
-0
dnn/src/naive/local_share/opr_impl.cpp
dnn/src/naive/local_share/opr_impl.cpp
+23
-0
dnn/src/naive/local_share/opr_impl.h
dnn/src/naive/local_share/opr_impl.h
+12
-1
dnn/src/naive/matrix_mul/opr_impl.cpp
dnn/src/naive/matrix_mul/opr_impl.cpp
+7
-0
dnn/src/naive/matrix_mul/opr_impl.h
dnn/src/naive/matrix_mul/opr_impl.h
+4
-0
dnn/src/naive/pooling/opr_impl.cpp
dnn/src/naive/pooling/opr_impl.cpp
+9
-0
dnn/src/naive/pooling/opr_impl.h
dnn/src/naive/pooling/opr_impl.h
+5
-0
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
+8
-0
dnn/src/rocm/batched_matrix_mul/opr_impl.h
dnn/src/rocm/batched_matrix_mul/opr_impl.h
+3
-0
dnn/src/rocm/convolution/opr_impl.cpp
dnn/src/rocm/convolution/opr_impl.cpp
+24
-0
dnn/src/rocm/convolution/opr_impl.h
dnn/src/rocm/convolution/opr_impl.h
+9
-0
dnn/src/rocm/matrix_mul/opr_impl.cpp
dnn/src/rocm/matrix_mul/opr_impl.cpp
+8
-0
dnn/src/rocm/matrix_mul/opr_impl.h
dnn/src/rocm/matrix_mul/opr_impl.h
+4
-0
dnn/src/rocm/pooling/opr_impl.cpp
dnn/src/rocm/pooling/opr_impl.cpp
+12
-1
dnn/src/rocm/pooling/opr_impl.h
dnn/src/rocm/pooling/opr_impl.h
+5
-0
dnn/src/x86/pooling/opr_impl.cpp
dnn/src/x86/pooling/opr_impl.cpp
+4
-1
dnn/src/x86/pooling/opr_impl.h
dnn/src/x86/pooling/opr_impl.h
+2
-0
dnn/test/common/accuracy_shake_checker.h
dnn/test/common/accuracy_shake_checker.h
+1
-1
dnn/test/common/benchmarker.h
dnn/test/common/benchmarker.h
+1
-1
dnn/test/common/checker.h
dnn/test/common/checker.h
+2
-2
dnn/test/common/convolution.cpp
dnn/test/common/convolution.cpp
+3
-3
dnn/test/common/opr_algo_proxy.h
dnn/test/common/opr_algo_proxy.h
+4
-4
dnn/test/common/opr_proxy.h
dnn/test/common/opr_proxy.h
+3
-3
dnn/test/cuda/cutlass_matmul.cpp
dnn/test/cuda/cutlass_matmul.cpp
+2
-2
src/core/test/graph/misc.cpp
src/core/test/graph/misc.cpp
+2
-2
src/opr/impl/search_policy/algo_chooser.cpp
src/opr/impl/search_policy/algo_chooser.cpp
+1
-1
src/opr/test/dnn/convolution.cpp
src/opr/test/dnn/convolution.cpp
+10
-0
未找到文件。
dnn/include/megdnn/oprs/base.h
浏览文件 @
d69b5903
...
...
@@ -315,7 +315,7 @@ public:
/*!
* \brief get a string representation for current algorithm set;
*
* get_all_algorithms() may return different algorithms only if
* get_all_algorithms
_safe
() may return different algorithms only if
* algorithm set name differs. This is used for checking cache
* validity.
*/
...
...
@@ -354,6 +354,15 @@ public:
return
ret
;
}
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
)
{
std
::
vector
<
AlgorithmInfo
>
ret
;
for
(
auto
&&
algo
:
get_all_algorithms_safe
(
p0
,
p1
))
{
ret
.
emplace_back
(
algo
->
info
());
}
return
ret
;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
...
...
@@ -378,6 +387,8 @@ protected:
//! get all possible algorithms for the specified layouts
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
)
=
0
;
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
)
=
0
;
/**
* \brief Returns the best algorithm by heuristic.
...
...
@@ -412,6 +423,16 @@ public:
return
ret
;
}
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
)
{
std
::
vector
<
AlgorithmInfo
>
ret
;
for
(
auto
&&
algo
:
get_all_algorithms_safe
(
p0
,
p1
,
p2
))
{
ret
.
emplace_back
(
algo
->
info
());
}
return
ret
;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
...
...
@@ -438,6 +459,9 @@ protected:
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
)
=
0
;
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
)
=
0
;
/**
* \brief Returns the best algorithm by heuristic.
...
...
@@ -463,7 +487,7 @@ public:
using
AlgoAttribute
=
detail
::
Algorithm
::
Attribute
;
//! get all possible algorithm decriptions for the specified layouts
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info
(
const
TensorLayout
&
p0
,
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
)
{
...
...
@@ -474,6 +498,17 @@ public:
return
ret
;
}
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
)
{
std
::
vector
<
AlgorithmInfo
>
ret
;
for
(
auto
&&
algo
:
get_all_algorithms_safe
(
p0
,
p1
,
p2
,
p3
))
{
ret
.
emplace_back
(
algo
->
info
());
}
return
ret
;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
...
...
@@ -500,6 +535,9 @@ protected:
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
)
=
0
;
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
)
=
0
;
/**
* \brief Returns the best algorithm by heuristic.
...
...
@@ -537,6 +575,18 @@ public:
return
ret
;
}
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
)
{
std
::
vector
<
AlgorithmInfo
>
ret
;
for
(
auto
&&
algo
:
get_all_algorithms_safe
(
p0
,
p1
,
p2
,
p3
,
p4
))
{
ret
.
emplace_back
(
algo
->
info
());
}
return
ret
;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
...
...
@@ -562,7 +612,11 @@ protected:
~
MultiAlgoOpr
()
=
default
;
//! get all possible algorithms for the specified layouts
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
)
=
0
;
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
)
=
0
;
...
...
@@ -604,6 +658,18 @@ public:
return
ret
;
}
std
::
vector
<
AlgorithmInfo
>
get_all_algorithms_info_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
,
const
TensorLayout
&
p5
,
const
TensorLayout
&
p6
,
const
TensorLayout
&
p7
)
{
std
::
vector
<
AlgorithmInfo
>
ret
;
for
(
auto
&&
algo
:
get_all_algorithms_safe
(
p0
,
p1
,
p2
,
p3
,
p4
,
p5
,
p6
,
p7
))
{
ret
.
emplace_back
(
algo
->
info
());
}
return
ret
;
}
/**
* \brief Returns the best algorithm information which indicate the
* algorithm by heuristic.
...
...
@@ -629,7 +695,12 @@ protected:
~
MultiAlgoOpr
()
=
default
;
//! get all possible algorithms for the specified layouts
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
,
const
TensorLayout
&
p5
,
const
TensorLayout
&
p6
,
const
TensorLayout
&
p7
)
=
0
;
virtual
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
const
TensorLayout
&
p3
,
const
TensorLayout
&
p4
,
const
TensorLayout
&
p5
,
...
...
dnn/src/arm_common/pooling/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -172,9 +172,14 @@ std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
ret
.
push_back
(
i
);
}
}
megdnn_assert
(
!
ret
.
empty
(),
"no usable pooling fwd algorithm"
);
return
ret
;
}
std
::
vector
<
Algorithm
*>
PoolingImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
auto
ret_safe
=
get_all_algorithms
(
src
,
dst
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable pooling fwd algorithm"
);
return
ret_safe
;
}
Algorithm
*
PoolingImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
dnn/src/arm_common/pooling/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -131,6 +131,8 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
dnn/src/common/algo_chooser.h
浏览文件 @
d69b5903
...
...
@@ -100,10 +100,16 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
ret
.
push_back
(
i
);
}
}
megdnn_assert
(
!
ret
.
empty
(),
"no algorithm for %s"
,
args
.
to_string
().
c_str
());
return
ret
;
}
template
<
class
Opr
>
std
::
vector
<
typename
Opr
::
Algorithm
*>
get_all_algorithms_safe
(
const
typename
Opr
::
AlgoBase
::
SizeArgs
&
args
)
{
auto
ret_safe
=
get_all_algorithms
<
Opr
>
(
args
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no algorithm for %s"
,
args
.
to_string
().
c_str
());
return
ret_safe
;
}
/*!
* \brief a helper function to get an algorithm match attribute. If require a
...
...
dnn/src/cuda/batch_conv_bias/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src,
AlgoBase
::
SizeArgs
args
{
this
,
src
,
filter
,
bias
,
z
,
dst
};
return
megdnn
::
get_all_algorithms
<
BatchConvBiasForwardImpl
>
(
args
);
}
std
::
vector
<
BatchConvBiasForwardImpl
::
Algorithm
*>
BatchConvBiasForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
{
AlgoBase
::
SizeArgs
args
{
this
,
src
,
filter
,
bias
,
z
,
dst
};
return
megdnn
::
get_all_algorithms_safe
<
BatchConvBiasForwardImpl
>
(
args
);
}
size_t
BatchConvBiasForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
dnn/src/cuda/batch_conv_bias/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -42,6 +42,10 @@ protected:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
...
...
dnn/src/cuda/batched_matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -51,6 +51,12 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
}
return
ret
;
}
std
::
vector
<
Algorithm
*>
BatchedMatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
auto
ret_safe
=
get_all_algorithms
(
A
,
B
,
C
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable batchedmatrixmulForward fwd algorithm"
);
return
ret_safe
;
}
Algorithm
*
BatchedMatrixMulForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
...
...
dnn/src/cuda/batched_matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -45,6 +45,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/src/cuda/conv_bias/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
filter
,
bias
,
z
,
dst
});
}
std
::
vector
<
ConvBiasForward
::
Algorithm
*>
ConvBiasForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvBiasForwardImpl
>
(
{
this
,
src
,
filter
,
bias
,
z
,
dst
});
}
ConvBiasForward
::
Algorithm
*
ConvBiasForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
...
...
dnn/src/cuda/conv_bias/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -84,6 +84,10 @@ public:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
...
...
dnn/src/cuda/convolution/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
return
megdnn
::
get_all_algorithms
<
ConvolutionForwardImpl
>
(
args
);
}
std
::
vector
<
ConvolutionForwardImpl
::
Algorithm
*>
ConvolutionForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
AlgoBase
::
SizeArgs
args
{
this
,
src
,
filter
,
dst
};
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionForwardImpl
>
(
args
);
}
size_t
ConvolutionForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
...
...
@@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{
this
,
filter
,
diff
,
grad
});
}
std
::
vector
<
ConvolutionBackwardDataImpl
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionBackwardDataImpl
>
(
{
this
,
filter
,
diff
,
grad
});
}
ConvolutionBackwardDataImpl
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
...
...
@@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
diff
,
grad
});
}
std
::
vector
<
ConvolutionBackwardFilterImpl
::
Algorithm
*>
ConvolutionBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionBackwardFilterImpl
>
(
{
this
,
src
,
diff
,
grad
});
}
ConvolutionBackwardFilterImpl
::
Algorithm
*
ConvolutionBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
...
...
dnn/src/cuda/convolution/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -59,6 +59,10 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -111,6 +115,10 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -159,6 +167,10 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/cuda/convolution3d/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
filter
,
dst
});
}
std
::
vector
<
Convolution3DForwardImpl
::
Algorithm
*>
Convolution3DForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms_safe
<
Convolution3DForwardImpl
>
(
{
this
,
src
,
filter
,
dst
});
}
size_t
Convolution3DForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
...
...
@@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{
this
,
filter
,
diff
,
grad
});
}
std
::
vector
<
Convolution3DBackwardDataImpl
::
Algorithm
*>
Convolution3DBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
Convolution3DBackwardDataImpl
>
(
{
this
,
filter
,
diff
,
grad
});
}
Convolution3DBackwardDataImpl
::
Algorithm
*
Convolution3DBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
...
...
@@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
diff
,
grad
});
}
std
::
vector
<
Convolution3DBackwardFilterImpl
::
Algorithm
*>
Convolution3DBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
Convolution3DBackwardFilterImpl
>
(
{
this
,
src
,
diff
,
grad
});
}
Convolution3DBackwardFilterImpl
::
Algorithm
*
Convolution3DBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
...
...
dnn/src/cuda/convolution3d/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -39,6 +39,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -72,6 +75,9 @@ public:
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
...
...
@@ -109,6 +115,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/cuda/deformable_conv/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -51,6 +51,15 @@ std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */,
return
algos
;
}
std
::
vector
<
AlgoFwd
*>
Fwd
::
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
dst
)
{
auto
ret_safe
=
Fwd
::
get_all_algorithms
(
im
,
filter
,
offset
,
mask
,
dst
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable deformable_conv fwd algorithm"
);
return
ret_safe
;
}
AlgoFwd
*
Fwd
::
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
...
...
@@ -115,6 +124,14 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */
return
algos
;
}
std
::
vector
<
AlgoBwdFlt
*>
BwdFlt
::
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
filter_grad
)
{
auto
ret_safe
=
BwdFlt
::
get_all_algorithms
(
im
,
offset
,
mask
,
out_grad
,
filter_grad
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable deformable_conv bwd filter algorithm"
);
return
ret_safe
;
}
AlgoBwdFlt
*
BwdFlt
::
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
...
...
@@ -181,6 +198,14 @@ std::vector<AlgoBwdData*> BwdData::get_all_algorithms(
algos
.
push_back
(
static_cast
<
AlgoBwdData
*>
(
i
));
return
algos
;
}
std
::
vector
<
AlgoBwdData
*>
BwdData
::
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
im_grad
,
const
TensorLayout
&
offset_grad
,
const
TensorLayout
&
mask_grad
)
{
auto
ret_safe
=
BwdData
::
get_all_algorithms
(
im
,
filter
,
offset
,
mask
,
out_grad
,
im_grad
,
offset_grad
,
mask_grad
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable deformable_conv bwd data algorithm"
);
return
ret_safe
;
}
AlgoBwdData
*
BwdData
::
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
...
...
dnn/src/cuda/deformable_conv/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -54,6 +54,10 @@ protected:
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
...
...
@@ -105,6 +109,10 @@ protected:
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
filter_grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
filter_grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
...
...
@@ -161,6 +169,13 @@ protected:
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
im_grad
,
const
TensorLayout
&
offset_grad
,
const
TensorLayout
&
mask_grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
mask
,
const
TensorLayout
&
out_grad
,
const
TensorLayout
&
im_grad
,
const
TensorLayout
&
offset_grad
,
const
TensorLayout
&
mask_grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
...
...
dnn/src/cuda/local_share/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic(
Algorithm
::
attribute_str
(
positive_attr
).
c_str
(),
args
.
to_string
().
c_str
(),
workspace_limit_in_bytes
));
}
std
::
vector
<
LocalShareForwardImpl
::
Algorithm
*>
LocalShareForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
@@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src,
return
megdnn
::
get_all_algorithms
<
LocalShareForwardImpl
>
(
args
);
}
std
::
vector
<
LocalShareForwardImpl
::
Algorithm
*>
LocalShareForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
AlgoBase
::
SizeArgs
args
{
this
,
src
,
filter
,
dst
};
return
megdnn
::
get_all_algorithms_safe
<
LocalShareForwardImpl
>
(
args
);
}
size_t
LocalShareForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
...
...
@@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
return
megdnn
::
get_all_algorithms
<
LocalShareBackwardDataImpl
>
(
args
);
}
std
::
vector
<
LocalShareBackwardDataImpl
::
Algorithm
*>
LocalShareBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
AlgoBase
::
SizeArgs
args
{
this
,
filter
,
diff
,
grad
};
return
megdnn
::
get_all_algorithms_safe
<
LocalShareBackwardDataImpl
>
(
args
);
}
size_t
LocalShareBackwardDataImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
...
...
@@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
return
megdnn
::
get_all_algorithms
<
LocalShareBackwardFilterImpl
>
(
args
);
}
std
::
vector
<
LocalShareBackwardFilterImpl
::
Algorithm
*>
LocalShareBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
AlgoBase
::
SizeArgs
args
{
this
,
src
,
diff
,
grad
};
return
megdnn
::
get_all_algorithms_safe
<
LocalShareBackwardFilterImpl
>
(
args
);
}
size_t
LocalShareBackwardFilterImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
...
...
dnn/src/cuda/local_share/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -37,6 +37,9 @@ public:
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
...
...
@@ -72,6 +75,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -105,6 +111,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/cuda/matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return
megdnn
::
get_all_algorithms
<
MatrixMulForwardImpl
>
(
args
);
}
std
::
vector
<
MatrixMulForwardImpl
::
Algorithm
*>
MatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
AlgoBase
::
SizeArgs
args
{
this
,
A
,
B
,
C
};
return
megdnn
::
get_all_algorithms_safe
<
MatrixMulForwardImpl
>
(
args
);
}
MatrixMulForwardImpl
::
Algorithm
*
MatrixMulForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/src/cuda/matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -60,6 +60,10 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/src/cuda/pooling/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms
<
PoolingForwardImpl
>
({
this
,
src
,
dst
});
}
std
::
vector
<
PoolingForwardImpl
::
Algorithm
*>
PoolingForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms_safe
<
PoolingForwardImpl
>
({
this
,
src
,
dst
});
}
PoolingForwardImpl
::
Algorithm
*
PoolingForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
@@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
dst
,
diff
,
grad
});
}
std
::
vector
<
PoolingBackwardImpl
::
Algorithm
*>
PoolingBackwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
PoolingBackwardImpl
>
(
{
this
,
src
,
dst
,
diff
,
grad
});
}
PoolingBackwardImpl
::
Algorithm
*
PoolingBackwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
...
...
dnn/src/cuda/pooling/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -55,6 +55,8 @@ public:
protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
@@ -99,6 +101,9 @@ protected:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
...
...
dnn/src/fallback/batched_matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
AlgoBase
::
SizeArgs
args
{
this
,
A
,
B
,
C
};
return
megdnn
::
get_all_algorithms
<
BatchedMatrixMulForwardImpl
>
(
args
);
}
std
::
vector
<
BatchedMatrixMulForwardImpl
::
Algorithm
*>
BatchedMatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
AlgoBase
::
SizeArgs
args
{
this
,
A
,
B
,
C
};
return
megdnn
::
get_all_algorithms_safe
<
BatchedMatrixMulForwardImpl
>
(
args
);
}
BatchedMatrixMulForwardImpl
::
Algorithm
*
BatchedMatrixMulForwardImpl
::
get_algorithm_heuristic
(
...
...
dnn/src/fallback/batched_matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -35,6 +35,9 @@ private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -279,11 +279,18 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms(
auto
fparam
=
make_ncb_kern_size_param
(
src
,
filter
,
bias
,
dst
,
nullptr
);
auto
ret
=
get_all_algorithms_with_ncb
(
fparam
);
if
(
ret
.
empty
())
{
return
naive
::
ConvBiasForwardImpl
::
get_all_algorithms
(
src
,
filter
,
bias
,
return
naive
::
ConvBiasForwardImpl
::
get_all_algorithms
_safe
(
src
,
filter
,
bias
,
z
,
dst
);
}
return
ret
;
}
std
::
vector
<
ConvBiasImpl
::
Algorithm
*>
ConvBiasImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
{
auto
ret_safe
=
ConvBiasImpl
::
get_all_algorithms
(
src
,
filter
,
bias
,
z
,
dst
);
return
ret_safe
;
}
ConvBiasImpl
::
Algorithm
*
ConvBiasImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -87,6 +87,10 @@ public:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
//! implemented by get_algorithm_heuristic_with_ncb()
Algorithm
*
get_algorithm_heuristic
(
...
...
dnn/src/fallback/convolution/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -198,12 +198,19 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
auto
fparam
=
make_ncb_kern_size_param
(
src
,
filter
,
dst
,
nullptr
);
auto
ret
=
get_all_algorithms_with_ncb
(
fparam
);
if
(
ret
.
empty
())
{
return
naive
::
ConvolutionForwardImpl
::
get_all_algorithms
(
src
,
filter
,
return
naive
::
ConvolutionForwardImpl
::
get_all_algorithms
_safe
(
src
,
filter
,
dst
);
}
return
ret
;
}
std
::
vector
<
ConvolutionImpl
::
Algorithm
*>
ConvolutionImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
auto
ret_safe
=
ConvolutionImpl
::
get_all_algorithms
(
src
,
filter
,
dst
);
return
ret_safe
;
}
ConvolutionImpl
::
Algorithm
*
ConvolutionImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
}
auto
fparam
=
make_ncb_kern_size_param
(
filter
,
diff
,
grad
);
auto
ret
=
get_all_algorithms_with_ncb
(
fparam
);
megdnn_assert
(
!
ret
.
empty
(),
"no usable conv fwd algorithm"
);
return
ret
;
}
std
::
vector
<
ConvolutionBackwardDataImpl
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
auto
ret_safe
=
ConvolutionBackwardDataImpl
::
get_all_algorithms
(
filter
,
diff
,
grad
);
megdnn_assert
(
!
ret_safe
.
empty
(),
"no usable conv bwd algorithm"
);
return
ret_safe
;
}
ConvolutionBackwardDataImpl
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
...
...
dnn/src/fallback/convolution/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -85,6 +85,10 @@ public:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
//! implemented by get_algorithm_heuristic_with_ncb()
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
@@ -326,6 +330,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/fallback/matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -96,6 +96,13 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
return
gemv_algos
;
}
std
::
vector
<
MatrixMul
::
Algorithm
*>
MatrixMulImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
auto
gemv_algos_safe
=
get_all_algorithms
(
A
,
B
,
C
);
megdnn_assert
(
!
gemv_algos_safe
.
empty
(),
"no usable MatrixMul fwd algorithm"
);
return
gemv_algos_safe
;
}
MatrixMulImpl
::
Algorithm
*
MatrixMulImpl
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
if
(
!
desc
.
valid
())
{
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -270,6 +270,10 @@ protected:
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/src/naive/batch_conv_bias/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&,
->
default_batch_conv_bias_fwd_algo
()};
}
std
::
vector
<
BatchConvBiasForward
::
Algorithm
*>
BatchConvBiasForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_batch_conv_bias_fwd_algo
()};
}
BatchConvBiasForward
::
Algorithm
*
BatchConvBiasForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
...
...
dnn/src/naive/batch_conv_bias/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -30,6 +30,11 @@ public:
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
...
...
dnn/src/naive/batched_matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,
}
}
std
::
vector
<
BatchedMatrixMulForward
::
Algorithm
*>
BatchedMatrixMulForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
...
...
@@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_batched_matmul_fwd_algo
()};
}
std
::
vector
<
BatchedMatrixMulForward
::
Algorithm
*>
BatchedMatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_batched_matmul_fwd_algo
()};
}
BatchedMatrixMulForward
::
Algorithm
*
BatchedMatrixMulForwardImpl
::
get_algorithm_heuristic
(
...
...
dnn/src/naive/batched_matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -27,6 +27,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
...
...
dnn/src/naive/conv_bias/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bias_fwd_algo
()};
}
std
::
vector
<
ConvBiasForward
::
Algorithm
*>
ConvBiasForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bias_fwd_algo
()};
}
ConvBiasForward
::
Algorithm
*
ConvBiasForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* bias */
,
const
TensorLayout
&
/* z */
,
...
...
dnn/src/naive/conv_bias/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -31,6 +31,11 @@ public:
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
...
...
dnn/src/naive/convolution/convolution.cpp
浏览文件 @
d69b5903
...
...
@@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_fwd_algo
()};
}
std
::
vector
<
ConvolutionForward
::
Algorithm
*>
ConvolutionForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_fwd_algo
()};
}
ConvolutionForward
::
Algorithm
*
ConvolutionForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* dst */
,
size_t
/* workspace_limit_in_bytes */
,
...
...
@@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_data_algo
()};
}
std
::
vector
<
ConvolutionBackwardData
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_data_algo
()};
}
ConvolutionBackwardData
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* diff */
,
...
...
@@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_filter_algo
()};
}
std
::
vector
<
ConvolutionBackwardFilter
::
Algorithm
*>
ConvolutionBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv_bwd_filter_algo
()};
}
ConvolutionBackwardFilter
::
Algorithm
*
ConvolutionBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
...
...
dnn/src/naive/convolution/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward {
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/naive/convolution3d/convolution3d.cpp
浏览文件 @
d69b5903
...
...
@@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src,
megdnn_assert_internal
(
0
);
}
std
::
vector
<
Convolution3DForward
::
Algorithm
*>
Convolution3DForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_fwd_algo
()};
}
std
::
vector
<
Convolution3DForward
::
Algorithm
*>
Convolution3DForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_fwd_algo
()};
}
Convolution3DForward
::
Algorithm
*
Convolution3DForwardImpl
::
get_algorithm_heuristic
(
...
...
@@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_data_algo
()};
}
std
::
vector
<
Convolution3DBackwardData
::
Algorithm
*>
Convolution3DBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_data_algo
()};
}
Convolution3DBackwardData
::
Algorithm
*
Convolution3DBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* diff */
,
...
...
@@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
->
default_conv3d_bwd_filter_algo
()};
}
std
::
vector
<
Convolution3DBackwardFilter
::
Algorithm
*>
Convolution3DBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_conv3d_bwd_filter_algo
()};
}
Convolution3DBackwardFilter
::
Algorithm
*
Convolution3DBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
...
...
dnn/src/naive/convolution3d/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -22,6 +22,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -44,6 +47,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -66,6 +72,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/naive/deformable_conv/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -25,6 +25,12 @@ public:
const
TensorLayout
&
/* dst */
)
override
{
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
override
{
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
...
...
@@ -67,6 +73,13 @@ public:
const
TensorLayout
&
/* filter_grad */
)
override
{
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* filter_grad */
)
override
{
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
...
...
@@ -112,6 +125,16 @@ public:
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
)
override
{
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
...
...
dnn/src/naive/local_share/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_fwd_algo
()};
}
std
::
vector
<
LocalShareForward
::
Algorithm
*>
LocalShareForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_fwd_algo
()};
}
LocalShareForward
::
Algorithm
*
LocalShareForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
const
TensorLayout
&
/* grad */
,
size_t
/* workspace_limit_in_bytes */
,
...
...
@@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
->
default_local_share_bwd_data_algo
()};
}
std
::
vector
<
LocalShareBackwardData
::
Algorithm
*>
LocalShareBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_data_algo
()};
}
LocalShareBackwardData
::
Algorithm
*
LocalShareBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* diff */
,
...
...
@@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
->
default_local_share_bwd_filter_algo
()};
}
std
::
vector
<
LocalShareBackwardFilter
::
Algorithm
*>
LocalShareBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_local_share_bwd_filter_algo
()};
}
LocalShareBackwardFilter
::
Algorithm
*
LocalShareBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* diff */
,
...
...
dnn/src/naive/local_share/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -30,6 +30,10 @@ public:
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*dst*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*dst*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*dst*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
@@ -55,6 +59,10 @@ public:
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*filter*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
@@ -75,11 +83,14 @@ public:
const
TensorLayout
&
)
override
{
return
0
;
}
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
dnn/src/naive/matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_matmul_fwd_algo
()};
}
std
::
vector
<
MatrixMulForward
::
Algorithm
*>
MatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_matmul_fwd_algo
()};
}
MatrixMulForward
::
Algorithm
*
MatrixMulForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
dnn/src/naive/matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -29,6 +29,10 @@ public:
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
dnn/src/naive/pooling/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -603,6 +603,10 @@ std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms(
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_pooling_fwd_algo
()};
}
std
::
vector
<
Algorithm
*>
PoolingForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
,
const
TensorLayout
&
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_pooling_fwd_algo
()};
}
Algorithm
*
PoolingForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*dst*/
,
...
...
@@ -626,6 +630,11 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_pooling_bwd_algo
()};
}
std
::
vector
<
Algorithm
*>
PoolingBackwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*dst*/
,
const
TensorLayout
&
/*diff*/
,
const
TensorLayout
&
/*grad*/
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_pooling_bwd_algo
()};
}
Algorithm
*
PoolingBackwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
/*src*/
,
const
TensorLayout
&
/*dst*/
,
...
...
dnn/src/naive/pooling/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward {
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
@@ -60,6 +62,9 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return
megdnn
::
get_all_algorithms
<
BatchedMatrixMulForwardImpl
>
(
args
);
}
std
::
vector
<
BatchedMatrixMulForwardImpl
::
Algorithm
*>
BatchedMatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
AlgoBase
::
SizeArgs
args
{
this
,
A
,
B
,
C
};
return
megdnn
::
get_all_algorithms_safe
<
BatchedMatrixMulForwardImpl
>
(
args
);
}
BatchedMatrixMulForwardImpl
::
Algorithm
*
BatchedMatrixMulForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
...
...
dnn/src/rocm/batched_matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -35,6 +35,9 @@ private:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
...
...
dnn/src/rocm/convolution/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
filter
,
dst
});
}
std
::
vector
<
ConvolutionForwardImpl
::
Algorithm
*>
ConvolutionForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionForwardImpl
>
(
{
this
,
src
,
filter
,
dst
});
}
size_t
ConvolutionForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
const
PreprocessedFilter
*
)
{
...
...
@@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{
this
,
filter
,
diff
,
grad
});
}
std
::
vector
<
ConvolutionBackwardDataImpl
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionBackwardDataImpl
>
(
{
this
,
filter
,
diff
,
grad
});
}
ConvolutionBackwardDataImpl
::
Algorithm
*
ConvolutionBackwardDataImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
...
...
@@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{
this
,
src
,
diff
,
grad
});
}
std
::
vector
<
ConvolutionBackwardFilterImpl
::
Algorithm
*>
ConvolutionBackwardFilterImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
)
{
return
megdnn
::
get_all_algorithms_safe
<
ConvolutionBackwardFilterImpl
>
(
{
this
,
src
,
diff
,
grad
});
}
ConvolutionBackwardFilterImpl
::
Algorithm
*
ConvolutionBackwardFilterImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
...
...
dnn/src/rocm/convolution/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -74,6 +74,9 @@ private:
//! Override returning a vector of Algorithm* for (src, filter, dst).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) override;
//! Override returning a vector of Algorithm* for (src, filter, dst).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -123,6 +126,9 @@ private:
//! Override returning a vector of Algorithm* for (filter, diff, grad).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) override;
//! Override returning a vector of Algorithm* for (filter, diff, grad).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
@@ -172,6 +178,9 @@ private:
//! Override returning a vector of Algorithm* for (src, diff, grad).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad) override;
//! Override returning a vector of Algorithm* for (src, diff, grad).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_limit_in_bytes
,
...
...
dnn/src/rocm/matrix_mul/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return
megdnn
::
get_all_algorithms
<
MatrixMulForwardImpl
>
(
args
);
}
std
::
vector
<
MatrixMulForwardImpl
::
Algorithm
*>
MatrixMulForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
AlgoBase
::
SizeArgs
args
{
this
,
A
,
B
,
C
};
return
megdnn
::
get_all_algorithms_safe
<
MatrixMulForwardImpl
>
(
args
);
}
MatrixMulForwardImpl
::
Algorithm
*
MatrixMulForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/src/rocm/matrix_mul/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -36,6 +36,10 @@ private:
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
)
override
;
//! Override returning a vector of Algorithm* for the layouts A, B and C
//! (parameters are left unnamed in this declaration).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& /*A*/, const TensorLayout& /*B*/,
        const TensorLayout& /*C*/) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/*A*/
,
const
TensorLayout
&
/*B*/
,
const
TensorLayout
&
/*C*/
,
size_t
/*workspace_limit_in_bytes*/
,
...
...
dnn/src/rocm/pooling/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
//! Name of the algorithm set of this operator; always the fixed literal
//! "ROCM_POOLING_FORWARD".
auto PoolingForwardImpl::get_algorithm_set_name() const -> const char* {
    return "ROCM_POOLING_FORWARD";
}
std
::
vector
<
PoolingForwardImpl
::
Algorithm
*>
PoolingForwardImpl
::
get_all_algorithms
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms
<
PoolingForwardImpl
>
({
this
,
src
,
dst
});
}
std
::
vector
<
PoolingForwardImpl
::
Algorithm
*>
PoolingForwardImpl
::
get_all_algorithms_safe
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
return
megdnn
::
get_all_algorithms_safe
<
PoolingForwardImpl
>
({
this
,
src
,
dst
});
}
PoolingForwardImpl
::
Algorithm
*
PoolingForwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
@@ -82,6 +86,13 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
{
this
,
src
,
dst
,
diff
,
grad
});
}
//! List the backward-pooling algorithms for (src, dst, diff, grad) by
//! forwarding {this, src, dst, diff, grad} to
//! megdnn::get_all_algorithms_safe.
//! NOTE(review): the semantics of the "_safe" helper are defined outside
//! this file -- confirm before relying on them.
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) {
    return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>(
            {this, src, dst, diff, grad});
}
Algorithm
*
PoolingBackwardImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
...
...
dnn/src/rocm/pooling/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward {
protected:
//! Override returning a vector of Algorithm* for (src, dst).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) override;
//! Override returning a vector of Algorithm* for (src, dst).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
@@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward {
//! Override returning a vector of Algorithm* for (src, dst, diff, grad).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) override;
//! Override returning a vector of Algorithm* for (src, dst, diff, grad).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
...
...
dnn/src/x86/pooling/opr_impl.cpp
浏览文件 @
d69b5903
...
...
@@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
return
fallback_worksapce
;
}
}
//! List the pooling algorithms for (src, dst) by forwarding
//! {this, src, dst} to megdnn::get_all_algorithms.
std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) {
    return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst});
}
//! List the pooling algorithms for (src, dst) by forwarding
//! {this, src, dst} to megdnn::get_all_algorithms_safe.
//! NOTE(review): the semantics of the "_safe" helper are defined outside
//! this file -- confirm before relying on them.
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst) {
    return megdnn::get_all_algorithms_safe<PoolingImpl>({this, src, dst});
}
Algorithm
*
PoolingImpl
::
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
...
...
dnn/src/x86/pooling/opr_impl.h
浏览文件 @
d69b5903
...
...
@@ -63,6 +63,8 @@ public:
protected:
//! Override returning a vector of Algorithm* for (src, dst).
std::vector<Algorithm*> get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) override;
//! Override returning a vector of Algorithm* for (src, dst).
//! NOTE(review): presumably the checked counterpart of
//! get_all_algorithms(); the contract is set by the base class -- confirm.
std::vector<Algorithm*> get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst) override;
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
size_t
workspace_limit_in_bytes
,
const
AlgoAttribute
&
positive_attr
,
...
...
dnn/test/common/accuracy_shake_checker.h
浏览文件 @
d69b5903
...
...
@@ -164,7 +164,7 @@ public:
}
std
::
vector
<
Algorithm
::
Info
::
Desc
>
ret
;
megdnn_assert
(
layouts
.
size
()
==
OprTrait
<
Opr
>::
arity
);
auto
vec
=
AlgoProxy
<
Opr
,
OprTrait
<
Opr
>::
arity
>::
get_all_algorithms_info
(
auto
vec
=
AlgoProxy
<
Opr
,
OprTrait
<
Opr
>::
arity
>::
get_all_algorithms_info
_safe
(
opr
,
layouts
);
for
(
auto
algo_info
:
vec
)
{
if
(
!
(
algo_info
.
attribute
&
...
...
dnn/test/common/benchmarker.h
浏览文件 @
d69b5903
...
...
@@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts,
auto
opr
=
benchmark
.
opr
();
opr
->
param
()
=
benchmark
.
param
();
proxy
.
deduce_layout
(
opr
,
layouts
);
auto
algos
=
OprAlgoProxy
<
Opr
>::
get_all_algorithms_info
(
opr
,
layouts
);
auto
algos
=
OprAlgoProxy
<
Opr
>::
get_all_algorithms_info
_safe
(
opr
,
layouts
);
float
min_used
=
std
::
numeric_limits
<
float
>::
max
();
bool
execed
=
false
;
for
(
auto
i
:
algos
)
{
...
...
dnn/test/common/checker.h
浏览文件 @
d69b5903
...
...
@@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName {
* \brief a callable to check that given algorithm is used for heuristic
* \param require_algo if its value is true, then requires
* get_algorithm_heuristic() to return the expected algo; otherwise the
* expected algo must exist in get_all_algorithms() and it would be set to
* expected algo must exist in get_all_algorithms
_safe
() and it would be set to
* be used
*/
template
<
class
Opr
,
typename
OprAlgoProxy
=
OprAlgoProxy
<
Opr
>
>
...
...
@@ -536,7 +536,7 @@ public:
opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
Opr
::
Param
>
(
param
);
for
(
auto
algo_info
:
AlgoProxy
<
Opr
,
OprTrait
<
Opr
>::
arity
>::
get_all_algorithms_info
(
AlgoProxy
<
Opr
,
OprTrait
<
Opr
>::
arity
>::
get_all_algorithms_info
_safe
(
opr
.
get
(),
layouts
))
{
if
(
std
::
regex_match
(
algo_info
.
desc
.
name
,
...
...
dnn/test/common/convolution.cpp
浏览文件 @
d69b5903
...
...
@@ -695,7 +695,7 @@ Checker<Convolution> checker(handle);
float
scale
=
1.0
f
/
sqrt
(
fshp
[
channel_start
]
*
FH
*
FW
);
UniformFloatRNG
rng
(
scale
,
2
*
scale
);
checker
.
set_rng
(
0
,
&
rng
).
set_rng
(
1
,
&
rng
);
for
(
auto
algo
:
opr
->
get_all_algorithms_info
(
ily
,
fly
,
oly
))
{
for
(
auto
algo
:
opr
->
get_all_algorithms_info
_safe
(
ily
,
fly
,
oly
))
{
used_algos
.
insert
(
algo
.
desc
);
opr
->
execution_policy
().
algo
=
algo
.
desc
;
...
...
@@ -720,7 +720,7 @@ Checker<Convolution> checker(handle);
opr
->
param
()
=
param
;
std
::
string
param_str
;
Algorithm
::
serialize_write_pod
(
opr
->
param
(),
param_str
);
for
(
auto
algo
:
opr
->
get_all_algorithms_info
(
fly
,
oly
,
ily
))
{
for
(
auto
algo
:
opr
->
get_all_algorithms_info
_safe
(
fly
,
oly
,
ily
))
{
used_algos_bwd_data
.
insert
(
algo
.
desc
);
opr
->
execution_policy
().
algo
=
algo
.
desc
;
construct_sub_execution_policy_heuristic
<
...
...
@@ -747,7 +747,7 @@ Checker<Convolution> checker(handle);
opr
->
param
()
=
param
;
std
::
string
param_str
;
Algorithm
::
serialize_write_pod
(
opr
->
param
(),
param_str
);
for
(
auto
algo
:
opr
->
get_all_algorithms_info
(
ily
,
oly
,
fly
))
{
for
(
auto
algo
:
opr
->
get_all_algorithms_info
_safe
(
ily
,
oly
,
fly
))
{
used_algos_bwd_flt
.
insert
(
algo
.
desc
);
opr
->
execution_policy
().
algo
=
algo
.
desc
;
construct_sub_execution_policy_heuristic
<
...
...
dnn/test/common/opr_algo_proxy.h
浏览文件 @
d69b5903
...
...
@@ -25,9 +25,9 @@ struct AlgoProxy;
template <typename Opr> \
struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
get_all_algorithms_info
_safe
(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
return opr->get_all_algorithms_info
_safe
(LAYOUTS); \
} \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \
...
...
@@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8);
template <> \
struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
get_all_algorithms_info
_safe
(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
return opr->get_all_algorithms_info
_safe
(LAYOUTS); \
} \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \
...
...
dnn/test/common/opr_proxy.h
浏览文件 @
d69b5903
...
...
@@ -288,7 +288,7 @@ struct OprProxyProfilingBase
Algorithm
::
deserialize_read_pod
<
typename
Opr
::
Param
>
(
param
);
std
::
vector
<
Algorithm
::
SearchItem
>
ret
;
for
(
auto
algo_info
:
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
(
for
(
auto
algo_info
:
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
_safe
(
opr
.
get
(),
layouts
))
{
Algorithm
*
algo
=
opr
->
get_algorithm_from_desc
(
algo_info
.
desc
);
std
::
vector
<
Algorithm
::
SearchItem
>&&
sub_items
=
...
...
@@ -367,7 +367,7 @@ struct OprProxyProfilingBase
megdnn_log
(
"Find best algo %s in cache"
,
algo
->
name
());
return
;
}
for
(
auto
algo
:
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
(
for
(
auto
algo
:
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
_safe
(
opr
.
get
(),
layouts
))
{
//! construct execution_policy
opr
->
execution_policy
().
algo
=
algo
.
desc
;
...
...
@@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
if
(
Base
::
m_profiling
&&
!
Base
::
target_execution_policy
.
algo
.
valid
())
{
size_t
min_time
=
std
::
numeric_limits
<
size_t
>::
max
();
for
(
auto
algo
:
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
(
opr
,
layouts
))
{
AlgoProxy
<
Opr
,
arity
>::
get_all_algorithms_info
_safe
(
opr
,
layouts
))
{
opr
->
execution_policy
().
algo
=
algo
.
desc
;
auto
preprocess_tensors
=
...
...
dnn/test/cuda/cutlass_matmul.cpp
浏览文件 @
d69b5903
...
...
@@ -84,7 +84,7 @@ void test_multibatchsize(
auto
opr_reference
=
handle_cuda
->
create_operator
<
MatrixMulForward
>
();
{
opr_reference
->
execution_policy
().
algo
.
reset
();
for
(
auto
i
:
opr_reference
->
get_all_algorithms_info
(
for
(
auto
i
:
opr_reference
->
get_all_algorithms_info
_safe
(
A_tensor
.
layout
(),
B_tensor
.
layout
(),
C_tensor
.
layout
()))
{
if
(
std
::
regex_match
(
...
...
@@ -113,7 +113,7 @@ void test_multibatchsize(
{{},
{},
C_tensor_prime
.
tensornd_host
()});
{
opr_reference
->
execution_policy
().
algo
.
reset
();
for
(
auto
i
:
opr_reference
->
get_all_algorithms_info
(
for
(
auto
i
:
opr_reference
->
get_all_algorithms_info
_safe
(
A_tensor_prime
.
layout
(),
B_tensor
.
layout
(),
C_tensor_batch
.
layout
()))
{
if
(
std
::
regex_match
(
...
...
src/core/test/graph/misc.cpp
浏览文件 @
d69b5903
...
...
@@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo(
return
{};
}
}
for
(
auto
&&
algo
:
dnn_op
->
get_all_algorithms_info
(
for
(
auto
&&
algo
:
dnn_op
->
get_all_algorithms_info
_safe
(
std
::
forward
<
Args
>
(
args
)...))
{
dnn_op
->
execution_policy
().
algo
=
algo
.
desc
;
auto
layouts
=
dnn_op
->
deduce_preprocessed_filter_layout
(
...
...
@@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo(
return
{};
}
}
for
(
auto
&&
algo
:
dnn_op
->
get_all_algorithms_info
(
for
(
auto
&&
algo
:
dnn_op
->
get_all_algorithms_info
_safe
(
std
::
forward
<
Args
>
(
args
)...))
{
dnn_op
->
execution_policy
().
algo
=
algo
.
desc
;
auto
layouts
=
dnn_op
->
deduce_preprocessed_filter_layout
(
...
...
src/opr/impl/search_policy/algo_chooser.cpp
浏览文件 @
d69b5903
...
...
@@ -805,7 +805,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo>
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
get_all_candidates
()
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"get_all_candidates"
)))
auto
heu
=
choose_by_heuristic
(
m_execution_policy
.
strategy
);
auto
&&
ret
=
APPLY
(
m_dnn_opr
->
get_all_algorithms_info
(
args
...),
auto
&&
ret
=
APPLY
(
m_dnn_opr
->
get_all_algorithms_info
_safe
(
args
...),
m_fastrun_layouts
);
bool
found
=
false
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
...
...
src/opr/test/dnn/convolution.cpp
浏览文件 @
d69b5903
...
...
@@ -2473,6 +2473,11 @@ public:
std
::
vector
<
AlgorithmInfo
>
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
));
//! gmock stub for get_all_algorithms_info_safe: takes three layouts and
//! yields a vector of AlgorithmInfo.
MOCK_METHOD3(
        get_all_algorithms_info_safe,
        std::vector<AlgorithmInfo>(
                const TensorLayout& p0, const TensorLayout& p1,
                const TensorLayout& p2));
MOCK_METHOD6
(
get_algorithm_info_heuristic
,
AlgorithmInfo
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
...
...
@@ -2484,6 +2489,11 @@ public:
std
::
vector
<
Algorithm
*>
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
));
//! gmock stub for get_all_algorithms_safe: takes three layouts and yields
//! a vector of Algorithm*.
MOCK_METHOD3(
        get_all_algorithms_safe,
        std::vector<Algorithm*>(
                const TensorLayout& p0, const TensorLayout& p1,
                const TensorLayout& p2));
MOCK_METHOD6
(
get_algorithm_heuristic
,
Algorithm
*
(
const
TensorLayout
&
p0
,
const
TensorLayout
&
p1
,
const
TensorLayout
&
p2
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录