Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a9c3515c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a9c3515c
编写于
3月 20, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(opr/naive): add DeformableConv algorithms interface
GitOrigin-RevId: adccb05f1a85552f7ab74aba5b2556675d5e2685
上级
d4bb54d4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
199 addition
and
30 deletion
+199
-30
dnn/src/naive/convolution/algorithms.h
dnn/src/naive/convolution/algorithms.h
+26
-0
dnn/src/naive/deformable_conv/opr_impl.cpp
dnn/src/naive/deformable_conv/opr_impl.cpp
+106
-0
dnn/src/naive/deformable_conv/opr_impl.h
dnn/src/naive/deformable_conv/opr_impl.h
+12
-30
dnn/src/naive/handle.cpp
dnn/src/naive/handle.cpp
+4
-0
dnn/src/naive/handle.h
dnn/src/naive/handle.h
+16
-0
dnn/test/naive/deformable_conv.cpp
dnn/test/naive/deformable_conv.cpp
+35
-0
未找到文件。
dnn/src/naive/convolution/algorithms.h
浏览文件 @
a9c3515c
...
@@ -63,6 +63,32 @@ class DefaultPoolingBackwardAlgorithm final
...
@@ -63,6 +63,32 @@ class DefaultPoolingBackwardAlgorithm final
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
};
class
DeformableConvForwardAlgorithm
final
:
public
megdnn
::
DeformableConvForward
::
Algorithm
{
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
|
AlgoAttribute
::
NAIVE
;
}
uint32_t
type
()
const
override
{
return
0
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
class
DeformableConvBackwardFilterAlgorithm
final
:
public
megdnn
::
DeformableConvBackwardFilter
::
Algorithm
{
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
|
AlgoAttribute
::
NAIVE
;
}
uint32_t
type
()
const
override
{
return
0
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
class
DeformableConvBackwardDataAlgorithm
final
:
public
megdnn
::
DeformableConvBackwardData
::
Algorithm
{
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
|
AlgoAttribute
::
NAIVE
;
}
uint32_t
type
()
const
override
{
return
0
;
}
const
char
*
name
()
const
override
{
return
"DEFAULT"
;
}
};
}
// namespace naive
}
// namespace naive
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/naive/deformable_conv/opr_impl.cpp
浏览文件 @
a9c3515c
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include <vector>
#include "src/common/utils.h"
#include "src/common/utils.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"
#include "src/naive/handle.h"
...
@@ -123,6 +124,38 @@ void Fwd::exec(
...
@@ -123,6 +124,38 @@ void Fwd::exec(
return
;
return
;
}
}
std
::
vector
<
DeformableConvForward
::
Algorithm
*>
Fwd
::
get_all_algorithms
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_fwd_algo
()};
}
std
::
vector
<
DeformableConvForward
::
Algorithm
*>
Fwd
::
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_fwd_algo
()};
}
DeformableConvForward
::
Algorithm
*
Fwd
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
positive_attr
,
const
AlgoAttribute
&
negative_attr
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_fwd_algo
();
algo
->
check_attribute
(
positive_attr
,
negative_attr
);
return
algo
;
}
DeformableConvForward
::
Algorithm
*
Fwd
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_fwd_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
/* ============== Bwd Implementation ============== */
/* ============== Bwd Implementation ============== */
static
void
deformable_conv_backward_weight
(
static
void
deformable_conv_backward_weight
(
...
@@ -388,6 +421,41 @@ void BwdFlt::exec(
...
@@ -388,6 +421,41 @@ void BwdFlt::exec(
out_grad
.
ptr
<
float
>
(),
filter_grad
.
ptr
<
float
>
(),
OC
,
IC
,
N
,
FH
,
FW
,
IH
,
IW
,
out_grad
.
ptr
<
float
>
(),
filter_grad
.
ptr
<
float
>
(),
OC
,
IC
,
N
,
FH
,
FW
,
IH
,
IW
,
PH
,
PW
,
DH
,
DW
,
SH
,
SW
,
OH
,
OW
,
group
,
deformable_group
));
PH
,
PW
,
DH
,
DW
,
SH
,
SW
,
OH
,
OW
,
group
,
deformable_group
));
}
}
std
::
vector
<
BwdFlt
::
Algorithm
*>
BwdFlt
::
get_all_algorithms
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_filter_algo
()};
}
std
::
vector
<
BwdFlt
::
Algorithm
*>
BwdFlt
::
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_filter_algo
()};
}
BwdFlt
::
Algorithm
*
BwdFlt
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
positive_attr
,
const
AlgoAttribute
&
negative_attr
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_filter_algo
();
algo
->
check_attribute
(
positive_attr
,
negative_attr
);
return
algo
;
}
BwdFlt
::
Algorithm
*
BwdFlt
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_filter_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
size_t
BwdData
::
get_workspace_in_bytes
(
size_t
BwdData
::
get_workspace_in_bytes
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
...
@@ -417,4 +485,42 @@ void BwdData::exec(
...
@@ -417,4 +485,42 @@ void BwdData::exec(
PH
,
PW
,
SH
,
SW
,
DH
,
DW
,
OH
,
OW
,
group
,
deformable_group
));
PH
,
PW
,
SH
,
SW
,
DH
,
DW
,
OH
,
OW
,
group
,
deformable_group
));
}
}
std
::
vector
<
BwdData
::
Algorithm
*>
BwdData
::
get_all_algorithms
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_data_algo
()};
}
std
::
vector
<
BwdData
::
Algorithm
*>
BwdData
::
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
)
{
return
{
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_data_algo
()};
}
BwdData
::
Algorithm
*
BwdData
::
get_algorithm_heuristic
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
positive_attr
,
const
AlgoAttribute
&
negative_attr
)
{
auto
algo
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_data_algo
();
algo
->
check_attribute
(
positive_attr
,
negative_attr
);
return
algo
;
}
BwdData
::
Algorithm
*
BwdData
::
get_algorithm_from_desc
(
const
AlgorithmDesc
&
desc
)
{
Algorithm
*
ret
=
static_cast
<
HandleImpl
*>
(
handle
())
->
default_deformable_conv_bwd_data_algo
();
megdnn_assert
(
desc
==
ret
->
info
().
desc
);
return
ret
;
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/naive/deformable_conv/opr_impl.h
浏览文件 @
a9c3515c
...
@@ -12,24 +12,18 @@ public:
...
@@ -12,24 +12,18 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
override
{
const
TensorLayout
&
/* dst */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
)
override
{
const
TensorLayout
&
/* dst */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* dst */
,
size_t
/* workspace_limit_in_bytes */
,
const
TensorLayout
&
/* dst */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*negative_attr*/
)
override
{
const
AlgoAttribute
&
/*negative_attr*/
)
override
;
return
nullptr
;
};
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* src */
,
const
TensorLayout
&
/* filter */
,
...
@@ -42,7 +36,7 @@ public:
...
@@ -42,7 +36,7 @@ public:
return
"DEFORMABLE_CONV2_NAIVE"
;
return
"DEFORMABLE_CONV2_NAIVE"
;
};
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
void
exec
(
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
...
@@ -57,16 +51,12 @@ public:
...
@@ -57,16 +51,12 @@ public:
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
std
::
vector
<
Algorithm
*>
get_all_algorithms
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* filter_grad */
)
override
{
const
TensorLayout
&
/* filter_grad */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* filter_grad */
)
override
{
const
TensorLayout
&
/* filter_grad */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* offset */
,
...
@@ -74,9 +64,7 @@ public:
...
@@ -74,9 +64,7 @@ public:
const
TensorLayout
&
/* filter_grad */
,
const
TensorLayout
&
/* filter_grad */
,
size_t
/* workspace_limit_in_bytes */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*negative_attr*/
)
override
{
const
AlgoAttribute
&
/*negative_attr*/
)
override
;
return
nullptr
;
};
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
const
TensorLayout
&
im
,
const
TensorLayout
&
offset
,
...
@@ -87,7 +75,7 @@ public:
...
@@ -87,7 +75,7 @@ public:
return
"DEFORMABLE_CONV2_BWD_FILTER_NAIVE"
;
return
"DEFORMABLE_CONV2_BWD_FILTER_NAIVE"
;
};
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
void
exec
(
void
exec
(
_megdnn_tensor_in
im
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
mask
,
_megdnn_tensor_in
im
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
mask
,
...
@@ -104,18 +92,14 @@ public:
...
@@ -104,18 +92,14 @@ public:
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
)
override
{
const
TensorLayout
&
/* mask_grad */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
std
::
vector
<
Algorithm
*>
get_all_algorithms_safe
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* offset */
,
const
TensorLayout
&
/* mask */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* out_grad */
,
const
TensorLayout
&
/* im_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
)
override
{
const
TensorLayout
&
/* mask_grad */
)
override
;
return
std
::
vector
<
Algorithm
*>
();
};
Algorithm
*
get_algorithm_heuristic
(
Algorithm
*
get_algorithm_heuristic
(
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
const
TensorLayout
&
/* im */
,
const
TensorLayout
&
/* filter */
,
...
@@ -124,9 +108,7 @@ public:
...
@@ -124,9 +108,7 @@ public:
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
,
const
TensorLayout
&
/* offset_grad */
,
const
TensorLayout
&
/* mask_grad */
,
size_t
/* workspace_limit_in_bytes */
,
size_t
/* workspace_limit_in_bytes */
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*positive_attr*/
,
const
AlgoAttribute
&
/*negative_attr*/
)
override
{
const
AlgoAttribute
&
/*negative_attr*/
)
override
;
return
nullptr
;
};
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
im
,
const
TensorLayout
&
filter
,
...
@@ -138,7 +120,7 @@ public:
...
@@ -138,7 +120,7 @@ public:
return
"DEFORMABLE_CONV2_BWD_DATA_NAIVE"
;
return
"DEFORMABLE_CONV2_BWD_DATA_NAIVE"
;
};
};
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
{
return
{};
}
Algorithm
*
get_algorithm_from_desc
(
const
AlgorithmDesc
&
)
override
;
void
exec
(
void
exec
(
_megdnn_tensor_in
im
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
_megdnn_tensor_in
im
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
offset
,
...
...
dnn/src/naive/handle.cpp
浏览文件 @
a9c3515c
...
@@ -115,6 +115,10 @@ DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
...
@@ -115,6 +115,10 @@ DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo;
DefaultPoolingForwardAlgorithm
HandleImpl
::
m_default_pooling_fwd_algo
;
DefaultPoolingForwardAlgorithm
HandleImpl
::
m_default_pooling_fwd_algo
;
DefaultPoolingBackwardAlgorithm
HandleImpl
::
m_default_pooling_bwd_algo
;
DefaultPoolingBackwardAlgorithm
HandleImpl
::
m_default_pooling_bwd_algo
;
DeformableConvForwardAlgorithm
HandleImpl
::
m_default_deformable_conv_fwd_algo
;
DeformableConvBackwardDataAlgorithm
HandleImpl
::
m_default_deformable_conv_bwd_data_algo
;
DeformableConvBackwardFilterAlgorithm
HandleImpl
::
m_default_deformable_conv_bwd_filter_algo
;
HandleImpl
::
HandleImpl
(
megcoreComputingHandle_t
computing_handle
,
HandleType
type
)
HandleImpl
::
HandleImpl
(
megcoreComputingHandle_t
computing_handle
,
HandleType
type
)
:
HandleImplHelper
(
computing_handle
,
type
),
:
HandleImplHelper
(
computing_handle
,
type
),
...
...
dnn/src/naive/handle.h
浏览文件 @
a9c3515c
...
@@ -38,6 +38,10 @@ class HandleImpl : public HandleImplHelper {
...
@@ -38,6 +38,10 @@ class HandleImpl : public HandleImplHelper {
static
DefaultPoolingForwardAlgorithm
m_default_pooling_fwd_algo
;
static
DefaultPoolingForwardAlgorithm
m_default_pooling_fwd_algo
;
static
DefaultPoolingBackwardAlgorithm
m_default_pooling_bwd_algo
;
static
DefaultPoolingBackwardAlgorithm
m_default_pooling_bwd_algo
;
static
DeformableConvForwardAlgorithm
m_default_deformable_conv_fwd_algo
;
static
DeformableConvBackwardDataAlgorithm
m_default_deformable_conv_bwd_data_algo
;
static
DeformableConvBackwardFilterAlgorithm
m_default_deformable_conv_bwd_filter_algo
;
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
template
<
typename
T
>
template
<
typename
T
>
...
@@ -119,6 +123,18 @@ public:
...
@@ -119,6 +123,18 @@ public:
return
&
m_default_pooling_bwd_algo
;
return
&
m_default_pooling_bwd_algo
;
}
}
DeformableConvForward
::
Algorithm
*
default_deformable_conv_fwd_algo
()
{
return
&
m_default_deformable_conv_fwd_algo
;
}
DeformableConvBackwardData
::
Algorithm
*
default_deformable_conv_bwd_data_algo
()
{
return
&
m_default_deformable_conv_bwd_data_algo
;
}
DeformableConvBackwardFilter
::
Algorithm
*
default_deformable_conv_bwd_filter_algo
()
{
return
&
m_default_deformable_conv_bwd_filter_algo
;
}
Relayout
*
relayout_opr
()
override
{
return
get_helper_opr
<
Relayout
,
2
>
(
this
);
}
Relayout
*
relayout_opr
()
override
{
return
get_helper_opr
<
Relayout
,
2
>
(
this
);
}
/*!
/*!
* \brief pass a kernel to the dispatcher associated with the megcore
* \brief pass a kernel to the dispatcher associated with the megcore
...
...
dnn/test/naive/deformable_conv.cpp
浏览文件 @
a9c3515c
#include "megdnn/dtype.h"
#include "test/naive/fixture.h"
#include "test/naive/fixture.h"
#include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn.h"
...
@@ -52,6 +53,15 @@ TEST_F(NAIVE, DEFORMABLE_CONV_FWD) {
...
@@ -52,6 +53,15 @@ TEST_F(NAIVE, DEFORMABLE_CONV_FWD) {
{
1
,
2
*
2
*
3
*
3
,
5
,
5
},
{
1
,
2
*
2
*
3
*
3
,
5
,
5
},
{
1
,
2
*
3
*
3
,
5
,
5
},
{
1
,
2
*
3
*
3
,
5
,
5
},
{}});
{}});
//! check algo interface
auto
opr
=
handle
()
->
create_operator
<
DeformableConv
>
();
auto
i0
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i1
=
megdnn
::
TensorLayout
({
2
,
1
,
1
,
3
,
3
},
megdnn
::
dtype
::
Float32
());
auto
i2
=
megdnn
::
TensorLayout
({
1
,
2
*
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i3
=
megdnn
::
TensorLayout
({
1
,
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
o
=
opr
->
get_algorithm_info_heuristic
(
i0
,
i1
,
i2
,
i3
,
{});
auto
kk
=
o
.
desc
.
name
;
printf
(
"%s
\n
"
,
kk
.
c_str
());
}
}
TEST_F
(
NAIVE
,
DEFORMABLE_CONV_BWD_FILTER
)
{
TEST_F
(
NAIVE
,
DEFORMABLE_CONV_BWD_FILTER
)
{
...
@@ -82,6 +92,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER) {
...
@@ -82,6 +92,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_FILTER) {
{
1
,
2
*
3
*
3
,
5
,
5
},
{
1
,
2
*
3
*
3
,
5
,
5
},
{
1
,
2
,
5
,
5
},
{
1
,
2
,
5
,
5
},
{
2
,
1
,
1
,
3
,
3
}});
{
2
,
1
,
1
,
3
,
3
}});
//! check algo interface
auto
opr
=
handle
()
->
create_operator
<
DeformableConvBackwardFilter
>
();
auto
i0
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i1
=
megdnn
::
TensorLayout
({
1
,
2
*
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i2
=
megdnn
::
TensorLayout
({
1
,
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i3
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i4
=
megdnn
::
TensorLayout
({
2
,
1
,
1
,
3
,
3
},
megdnn
::
dtype
::
Float32
());
auto
o
=
opr
->
get_algorithm_info_heuristic
(
i0
,
i1
,
i2
,
i3
,
i4
);
auto
kk
=
o
.
desc
.
name
;
printf
(
"%s
\n
"
,
kk
.
c_str
());
}
}
TEST_F
(
NAIVE
,
DEFORMABLE_CONV_BWD_DATA
)
{
TEST_F
(
NAIVE
,
DEFORMABLE_CONV_BWD_DATA
)
{
...
@@ -118,5 +140,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA) {
...
@@ -118,5 +140,18 @@ TEST_F(NAIVE, DEFORMABLE_CONV_BWD_DATA) {
{
1
,
2
,
5
,
5
},
{
1
,
2
,
5
,
5
},
{
1
,
1
*
2
*
3
*
3
,
5
,
5
},
{
1
,
1
*
2
*
3
*
3
,
5
,
5
},
{
1
,
1
*
3
*
3
,
5
,
5
}});
{
1
,
1
*
3
*
3
,
5
,
5
}});
//! check algo interface
auto
opr
=
handle
()
->
create_operator
<
DeformableConvBackwardData
>
();
auto
i0
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i1
=
megdnn
::
TensorLayout
({
2
,
1
,
1
,
3
,
3
},
megdnn
::
dtype
::
Float32
());
auto
i2
=
megdnn
::
TensorLayout
({
1
,
1
*
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i3
=
megdnn
::
TensorLayout
({
1
,
1
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i4
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i5
=
megdnn
::
TensorLayout
({
1
,
2
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i6
=
megdnn
::
TensorLayout
({
1
,
1
*
2
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
i7
=
megdnn
::
TensorLayout
({
1
,
1
*
3
*
3
,
5
,
5
},
megdnn
::
dtype
::
Float32
());
auto
o
=
opr
->
get_algorithm_info_heuristic
(
i0
,
i1
,
i2
,
i3
,
i4
,
i5
,
i6
,
i7
);
auto
kk
=
o
.
desc
.
name
;
printf
(
"%s
\n
"
,
kk
.
c_str
());
}
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录