Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
543c9b77
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
543c9b77
编写于
7月 04, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add RegionRestrictedConv cuda
GitOrigin-RevId: b9f2d34a136d19590ca08ddcf7d167944d24aa6a
上级
fdec82ec
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
2077 addition
and
191 deletion
+2077
-191
dnn/src/common/region_restricted_convolution.cpp
dnn/src/common/region_restricted_convolution.cpp
+10
-1
dnn/src/cuda/handle_create.cpp
dnn/src/cuda/handle_create.cpp
+4
-0
dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu
...egion_restricted_convolution/chanwise/bwd_large_filter.cu
+39
-0
dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh
...estricted_convolution/chanwise/depthwise_large_filter.cuh
+136
-0
dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh
...cted_convolution/chanwise/depthwise_large_filter_algo.cuh
+1186
-0
dnn/src/cuda/region_restricted_convolution/chanwise/fwd_large_filter.cu
...egion_restricted_convolution/chanwise/fwd_large_filter.cu
+41
-0
dnn/src/cuda/region_restricted_convolution/chanwise/kern.cuh
dnn/src/cuda/region_restricted_convolution/chanwise/kern.cuh
+57
-0
dnn/src/cuda/region_restricted_convolution/opr_impl.cpp
dnn/src/cuda/region_restricted_convolution/opr_impl.cpp
+79
-0
dnn/src/cuda/region_restricted_convolution/opr_impl.h
dnn/src/cuda/region_restricted_convolution/opr_impl.h
+55
-0
dnn/src/naive/convolution/helper.h
dnn/src/naive/convolution/helper.h
+19
-14
dnn/src/naive/region_restricted_convolution/opr_impl.cpp
dnn/src/naive/region_restricted_convolution/opr_impl.cpp
+96
-34
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+36
-36
dnn/test/cuda/region_restricted_convolution.cpp
dnn/test/cuda/region_restricted_convolution.cpp
+277
-0
dnn/test/naive/region_restricted_convolution.cpp
dnn/test/naive/region_restricted_convolution.cpp
+42
-106
未找到文件。
dnn/src/common/region_restricted_convolution.cpp
浏览文件 @
543c9b77
...
...
@@ -23,6 +23,7 @@ std::string get_errmsg(
"dilate_h="
+
std
::
to_string
(
param
.
dilate_h
)
+
", "
+
"dilate_w="
+
std
::
to_string
(
param
.
dilate_w
);
}
}
// namespace
namespace
megdnn
{
...
...
@@ -31,7 +32,12 @@ void RegionRestrictedConvolutionForward::deduce_dtype(
DType
src
,
DType
filter
,
DType
rin
,
DType
rout
,
DType
&
dst
)
{
check_or_deduce_dtype_fwd
(
src
,
filter
,
dst
);
megdnn_assert
(
rin
==
rout
&&
rin
==
dtype
::
Int32
(),
src
.
category
()
==
DTypeCategory
::
FLOAT
&&
filter
.
category
()
==
DTypeCategory
::
FLOAT
&&
dst
.
category
()
==
DTypeCategory
::
FLOAT
,
"only float type is supported for region_restricted_conv forward"
);
megdnn_assert
(
rin
==
rout
&&
(
rin
==
dtype
::
Int32
()
||
rin
==
dtype
::
Uint8
()),
"the dtype of rin/rout should be Int32, got %s."
,
rin
.
name
());
}
...
...
@@ -51,6 +57,9 @@ RegionRestrictedConvolutionForward::check_exec(
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW
,
"RegionRestrictedConv only support NCHW format mow."
);
megdnn_assert
(
param
().
stride_h
==
1
&&
param
().
stride_w
==
1
,
"RegionRestrictedConv only support stride 1."
);
#define err_msg(lhs, rhs) \
megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
...
...
dnn/src/cuda/handle_create.cpp
浏览文件 @
543c9b77
...
...
@@ -53,6 +53,7 @@
#include "src/cuda/pooling/opr_impl.h"
#include "src/cuda/powc/opr_impl.h"
#include "src/cuda/reduce/opr_impl.h"
#include "src/cuda/region_restricted_convolution/opr_impl.h"
#include "src/cuda/relayout/opr_impl.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/remap/opr_impl.h"
...
...
@@ -218,6 +219,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
SoftmaxForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
SoftmaxBackward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
NormForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
RegionRestrictedConvolutionForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
RegionRestrictedConvolutionBackwardData
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
RegionRestrictedConvolutionBackwardFilter
);
template
<
typename
Opr
>
std
::
unique_ptr
<
Opr
>
HandleImpl
::
create_operator
()
{
...
...
dnn/src/cuda/region_restricted_convolution/chanwise/bwd_large_filter.cu
0 → 100644
浏览文件 @
543c9b77
#include "./kern.cuh"
#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/fp16_help.cuh"
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
region_restricted_convolution
;
using
namespace
chanwise
;
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh"
namespace
megdnn
{
namespace
cuda
{
namespace
region_restricted_convolution
{
namespace
chanwise
{
// =====================================fwd=====================================
template
<
>
void
run_bwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
int
*
rin
,
const
int
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE_INT
(
float
,
int
,
DepthwiseConv2dDirection
::
DIRECTION_BACKWARD
)
}
template
<
>
void
run_bwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
uint8_t
*
rin
,
const
uint8_t
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE_UINT8
(
float
,
uint8_t
,
DepthwiseConv2dDirection
::
DIRECTION_BACKWARD
)
}
}
// namespace chanwise
}
// namespace region_restricted_convolution
}
// namespace cuda
}
// namespace megdnn
// vim: syntax=cuda.doxygen
dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh
0 → 100644
浏览文件 @
543c9b77
#pragma once
namespace
{
#define DIVUP(x, y) (((x) + (y)-1) / (y))
enum
DepthwiseConv2dDirection
{
DIRECTION_FORWARD
,
DIRECTION_BACKWARD
};
template
<
typename
ThreadConfig_
,
int
oh_
,
int
ow_
>
struct
OutTileConfig
{
using
ThreadConfig
=
ThreadConfig_
;
static
int
constexpr
unroll_h
=
oh_
;
static
int
constexpr
unroll_w
=
ThreadConfig
::
thread_x
*
ow_
;
static
int
constexpr
unroll_size
=
unroll_h
*
unroll_w
;
static
int
constexpr
block_h
=
unroll_h
*
ThreadConfig
::
thread_y
;
static
int
constexpr
block_w
=
unroll_w
;
};
template
<
int
fh_
,
int
fw_
>
struct
FilterTileConfig
{
static
int
constexpr
unroll_h
=
fh_
;
static
int
constexpr
unroll_w
=
fw_
;
static
int
constexpr
unroll_size
=
unroll_h
*
unroll_w
;
};
template
<
int
x_
,
int
y_
>
struct
ThreadConfig
{
static
int
constexpr
thread_x
=
x_
;
static_assert
((
thread_x
&
(
thread_x
-
1
))
==
0
,
"thread_x must be pow of 2!"
);
static
int
constexpr
thread_y
=
y_
;
static
int
constexpr
nr_threads
=
x_
*
y_
;
};
template
<
typename
ldg_dtype
,
typename
Rldg_dtype
,
typename
Rcmp_dtype
,
typename
ThreadConfig_
,
typename
OutTileConfig_
,
typename
FilterTileConfig_
,
int
stride_w
,
int
stride_h
>
struct
ConvTraitInner
{
using
ThreadConfig
=
ThreadConfig_
;
using
OutTileConfig
=
OutTileConfig_
;
using
FilterTileConfig
=
FilterTileConfig_
;
using
CompType
=
ldg_dtype
;
struct
SrcTileConfig
{
static
int
constexpr
unroll_h
=
OutTileConfig
::
unroll_h
+
FilterTileConfig
::
unroll_h
-
1
;
static
int
constexpr
unroll_w
=
(
OutTileConfig
::
unroll_w
-
1
)
*
stride_w
+
FilterTileConfig
::
unroll_w
;
static
int
constexpr
unroll_size
=
unroll_h
*
unroll_w
;
};
struct
SrcTileCount
{
static
int
constexpr
smem_src_h
=
(
OutTileConfig
::
block_h
-
1
)
*
stride_h
+
FilterTileConfig
::
unroll_h
;
static
int
constexpr
smem_delta_h
=
2
;
static
int
constexpr
smem_buff_h
=
FilterTileConfig
::
unroll_h
*
smem_delta_h
*
2
;
static
int
constexpr
smem_load_h
=
smem_src_h
+
smem_buff_h
;
static
int
constexpr
smem_h
=
smem_load_h
;
static
int
constexpr
smem_w
=
DIVUP
((
OutTileConfig
::
block_w
-
1
)
*
stride_w
+
FilterTileConfig
::
unroll_w
*
ThreadConfig
::
thread_x
,
2
)
*
2
;
static
int
constexpr
load_w
=
smem_w
>
32
?
32
:
smem_w
;
static
int
constexpr
load_h
=
ThreadConfig
::
nr_threads
/
load_w
;
static
int
constexpr
reg_h
=
DIVUP
(
smem_delta_h
,
load_h
);
static
int
constexpr
reg_w
=
DIVUP
(
smem_w
,
load_w
);
static
bool
constexpr
check_bounds_h
=
smem_delta_h
%
load_h
!=
0
;
static
bool
constexpr
check_bounds_w
=
smem_w
%
load_w
!=
0
;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static
int
constexpr
bank_w
=
smem_w
/
(
4
/
sizeof
(
CompType
));
static
int
constexpr
bank_offset_line
=
(
bank_w
%
32
==
0
||
bank_w
%
FilterTileConfig
::
unroll_w
==
0
)
?
1
:
(
bank_w
%
16
==
0
?
2
:
4
);
static
int
constexpr
smem_size
=
smem_h
*
smem_w
+
DIVUP
(
smem_h
,
bank_offset_line
)
*
(
4
/
sizeof
(
CompType
));
};
struct
FilterTileCount
{
static
int
constexpr
smem_flt_h
=
FilterTileConfig
::
unroll_h
;
static
int
constexpr
smem_buff_h
=
FilterTileConfig
::
unroll_h
;
static
int
constexpr
smem_w
=
FilterTileConfig
::
unroll_w
*
ThreadConfig
::
thread_x
;
static
int
constexpr
smem_delta_h
=
2
;
static
int
constexpr
smem_load_h
=
smem_flt_h
+
smem_buff_h
*
smem_w
;
static
int
constexpr
smem_h
=
smem_load_h
+
smem_buff_h
;
static
int
constexpr
load_w
=
smem_w
>
32
?
32
:
smem_w
;
static
int
constexpr
load_h
=
ThreadConfig
::
nr_threads
/
load_w
;
static
int
constexpr
reg_h
=
1
;
static
int
constexpr
reg_w
=
DIVUP
(
smem_w
,
load_w
);
static
bool
constexpr
check_bounds_h
=
smem_h
%
load_h
!=
0
;
static
bool
constexpr
check_bounds_w
=
smem_w
%
load_w
!=
0
;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static
int
constexpr
bank_w
=
smem_w
/
(
4
/
sizeof
(
CompType
));
static
int
constexpr
bank_offset_line
=
(
bank_w
%
32
==
0
||
bank_w
%
FilterTileConfig
::
unroll_w
==
0
)
?
1
:
(
bank_w
%
16
==
0
?
2
:
4
);
static
int
constexpr
smem_size
=
smem_h
*
smem_w
+
DIVUP
(
smem_h
,
bank_offset_line
)
*
(
4
/
sizeof
(
CompType
));
};
struct
RinTileCount
{
static
int
constexpr
smem_src_h
=
(
OutTileConfig
::
block_h
-
1
)
*
stride_h
+
FilterTileConfig
::
unroll_h
;
static
int
constexpr
smem_delta_h
=
2
;
static
int
constexpr
smem_buff_h
=
FilterTileConfig
::
unroll_h
*
smem_delta_h
*
2
;
static
int
constexpr
smem_load_h
=
smem_src_h
+
smem_buff_h
;
static
int
constexpr
smem_h
=
smem_load_h
;
static
int
constexpr
factor
=
sizeof
(
Rldg_dtype
)
/
sizeof
(
Rcmp_dtype
);
static
int
constexpr
smem_w
=
DIVUP
(
DIVUP
((
OutTileConfig
::
block_w
-
1
)
*
stride_w
+
FilterTileConfig
::
unroll_w
*
ThreadConfig
::
thread_x
,
factor
),
2
)
*
2
;
static
int
constexpr
load_w
=
smem_w
>
32
?
32
:
smem_w
;
static
int
constexpr
load_h
=
ThreadConfig
::
nr_threads
/
load_w
;
static
int
constexpr
reg_h
=
DIVUP
(
smem_delta_h
,
load_h
);
static
int
constexpr
reg_w
=
DIVUP
(
smem_w
,
load_w
);
static
bool
constexpr
check_bounds_h
=
smem_delta_h
%
load_h
!=
0
;
static
bool
constexpr
check_bounds_w
=
smem_w
%
load_w
!=
0
;
// to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
static
int
constexpr
bank_w
=
smem_w
;
static
int
constexpr
bank_offset_line
=
(
bank_w
%
32
==
0
||
bank_w
%
FilterTileConfig
::
unroll_w
==
0
)
?
1
:
(
bank_w
%
16
==
0
?
2
:
4
);
static
int
constexpr
smem_size
=
smem_h
*
smem_w
+
DIVUP
(
smem_h
,
bank_offset_line
);
};
};
}
// namespace
dnn/src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh
0 → 100644
浏览文件 @
543c9b77
#pragma once
#include "depthwise_large_filter.cuh"
#include "src/cuda/cuda_shfl_compat.cuh"
namespace
{
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
struct
Global2SharedMem
{
using
TileCount
=
TileCount_
;
using
ThreadConfig
=
ThreadConfig_
;
T
reg
[
TileCount
::
reg_h
][
TileCount
::
reg_w
];
const
int
tidx
=
threadIdx
.
x
;
const
int
tidy
=
threadIdx
.
y
;
const
int
tid
=
tidy
*
ThreadConfig
::
thread_x
+
tidx
;
const
int
gl_load_y
=
tid
/
TileCount
::
load_w
;
const
int
gl_load_x
=
tid
-
gl_load_y
*
TileCount
::
load_w
;
const
bool
is_fwd
=
(
kDirection
==
DIRECTION_FORWARD
);
int
w_offset
;
T
*
smem
;
int
stride
;
int
start_h
,
start_w
,
bound_h
,
bound_w
,
ring_smem_h
,
ring_src_h
;
// just used in backward src data
int
stride_h
,
stride_w
;
const
RT
*
g_ptr
;
__device__
__forceinline__
Global2SharedMem
(
T
*
smem_
,
int
stride_
,
int
s_h
,
int
s_w
,
int
b_h
,
int
b_w
,
int
stride_h_
,
int
stride_w_
);
__device__
__forceinline__
void
first_copy
();
__device__
__forceinline__
void
copy
();
__device__
__forceinline__
void
commit
();
__device__
__forceinline__
void
iter_forward
();
__device__
__forceinline__
T
*
sh_ptr
(
int
y
,
int
x
)
{
return
&
smem
[
y
*
TileCount
::
smem_w
+
x
];
}
__device__
__forceinline__
T
*
sh_ptr_as_copy_t
(
int
y
,
int
x
)
{
return
reinterpret_cast
<
T
*>
(
sh_ptr
(
y
,
x
));
}
};
template
<
typename
ldg_dtype
,
typename
Rldg_dtype
,
typename
Rcmp_dtype
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
OutTileConfig_
,
typename
FilterTileConfig_
,
int
stride_w
,
int
stride_h
>
struct
ConvTrait
{
using
ThreadConfig
=
ThreadConfig_
;
using
OutTileConfig
=
OutTileConfig_
;
using
FilterTileConfig
=
FilterTileConfig_
;
using
CompType
=
ldg_dtype
;
using
RLdgType
=
Rldg_dtype
;
using
RCmpType
=
Rcmp_dtype
;
using
CI
=
ConvTraitInner
<
ldg_dtype
,
Rldg_dtype
,
Rcmp_dtype
,
ThreadConfig_
,
OutTileConfig_
,
FilterTileConfig_
,
stride_w
,
stride_h
>
;
using
SrcTileConfig
=
typename
CI
::
SrcTileConfig
;
using
SrcTileCount
=
typename
CI
::
SrcTileCount
;
using
FilterTileCount
=
typename
CI
::
FilterTileCount
;
using
RinTileCount
=
typename
CI
::
RinTileCount
;
using
SrcGlobal2ShareVisitor
=
Global2SharedMem
<
CompType
,
CompType
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
,
ThreadConfig
,
SrcTileCount
>
;
using
RinGlobal2ShareVisitor
=
Global2SharedMem
<
Rldg_dtype
,
Rcmp_dtype
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
,
ThreadConfig
,
RinTileCount
>
;
using
FilterGlobal2ShareVisitor
=
Global2SharedMem
<
CompType
,
CompType
,
kDirection
,
ThreadConfig
,
FilterTileCount
>
;
};
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
Global2SharedMem
<
T
,
RT
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
Global2SharedMem
(
T
*
smem_
,
int
stride_
,
int
s_h
,
int
s_w
,
int
b_h
,
int
b_w
,
int
stride_h_
,
int
stride_w_
)
:
smem
(
smem_
),
stride
(
stride_
),
start_h
(
s_h
),
start_w
(
s_w
),
bound_h
(
b_h
),
bound_w
(
b_w
),
ring_smem_h
(
TileCount
::
smem_load_h
),
stride_h
(
stride_h_
),
stride_w
(
stride_w_
)
{
if
(
is_fwd
)
{
ring_src_h
=
s_h
+
TileCount
::
smem_load_h
;
w_offset
=
0
;
}
else
{
ring_src_h
=
s_h
-
1
;
w_offset
=
TileCount
::
smem_w
-
b_w
;
// stride_h and stride_w just used in backward src data.
stride_h
=
stride_w
=
1
;
}
}
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
RT
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
first_copy
()
{
static
int
const
load_w
=
TileCount
::
smem_w
>
32
?
32
:
TileCount
::
smem_w
;
static
int
const
load_h
=
ThreadConfig
::
nr_threads
/
load_w
;
static
int
const
h_per_thread
=
DIVUP
(
TileCount
::
smem_load_h
,
load_h
);
static
int
const
w_per_thread
=
DIVUP
(
TileCount
::
smem_w
,
load_w
);
static
bool
constexpr
check_bounds_h
=
TileCount
::
smem_load_h
%
load_h
!=
0
;
static
bool
constexpr
check_bounds_w
=
TileCount
::
smem_w
%
load_w
!=
0
;
const
int
y_base_idx
=
tid
/
load_w
;
const
int
x_base_idx
=
tid
-
y_base_idx
*
load_w
;
#pragma unroll
for
(
int
i
=
0
;
i
<
h_per_thread
;
++
i
)
{
int
smem_h_idx
=
y_base_idx
+
i
*
load_h
;
int
bank_offset
=
smem_h_idx
/
TileCount
::
bank_offset_line
;
int
src_h_idx
;
if
(
is_fwd
)
{
src_h_idx
=
start_h
+
smem_h_idx
;
}
else
{
src_h_idx
=
start_h
-
smem_h_idx
;
}
if
(
check_bounds_h
&&
smem_h_idx
>=
TileCount
::
smem_load_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
w_per_thread
;
++
j
)
{
int
smem_w_idx
=
x_base_idx
+
j
*
load_w
;
int
src_w_idx
;
if
(
is_fwd
)
{
src_w_idx
=
start_w
+
smem_w_idx
;
}
else
{
src_w_idx
=
start_w
+
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
;
}
if
(
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
T
val
=
0.0
f
;
if
(
src_h_idx
>=
0
&&
src_h_idx
<
bound_h
&&
src_w_idx
>=
0
&&
src_w_idx
<
bound_w
&&
((
is_fwd
&&
src_h_idx
%
stride_h
==
0
&&
src_w_idx
%
stride_w
==
0
)
||
(
!
is_fwd
&&
TileCount
::
smem_load_h
-
smem_h_idx
-
1
>=
0
&&
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
>=
0
)))
{
val
=
g_ptr
[
src_h_idx
/
stride_h
*
stride
+
src_w_idx
/
stride_w
];
}
*
(
sh_ptr_as_copy_t
(
smem_h_idx
,
smem_w_idx
+
bank_offset
*
(
4
/
sizeof
(
T
))))
=
val
;
}
}
}
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
RT
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
copy
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
TileCount
::
reg_h
;
++
i
)
{
int
thread_h_idx
=
gl_load_y
+
i
*
TileCount
::
load_h
;
int
smem_h_idx
=
(
ring_smem_h
+
thread_h_idx
)
%
TileCount
::
smem_h
;
int
src_h_idx
;
if
(
is_fwd
)
{
src_h_idx
=
ring_src_h
+
thread_h_idx
;
}
else
{
src_h_idx
=
start_h
-
smem_h_idx
;
}
if
(
thread_h_idx
>=
TileCount
::
smem_delta_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
TileCount
::
reg_w
;
++
j
)
{
int
smem_w_idx
=
gl_load_x
+
j
*
TileCount
::
load_w
;
int
src_w_idx
;
if
(
is_fwd
)
{
src_w_idx
=
start_w
+
smem_w_idx
;
}
else
{
src_w_idx
=
start_w
+
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
;
}
if
(
TileCount
::
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
T
val
=
0.0
f
;
if
(
src_h_idx
>=
0
&&
src_h_idx
<
bound_h
&&
src_w_idx
>=
0
&&
src_w_idx
<
bound_w
&&
((
is_fwd
&&
src_h_idx
%
stride_h
==
0
&&
src_w_idx
%
stride_w
==
0
)
||
(
!
is_fwd
&&
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
>=
0
)))
{
val
=
g_ptr
[
src_h_idx
/
stride_h
*
stride
+
src_w_idx
/
stride_w
];
}
reg
[
i
][
j
]
=
val
;
}
}
}
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
RT
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
commit
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
TileCount
::
reg_h
;
++
i
)
{
int
thread_h_idx
=
gl_load_y
+
i
*
TileCount
::
load_h
;
int
smem_h_idx
=
(
ring_smem_h
+
thread_h_idx
)
%
TileCount
::
smem_h
;
int
bank_offset
=
smem_h_idx
/
TileCount
::
bank_offset_line
;
if
(
thread_h_idx
>=
TileCount
::
smem_delta_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
TileCount
::
reg_w
;
++
j
)
{
int
smem_w_idx
=
gl_load_x
+
j
*
TileCount
::
load_w
;
if
(
TileCount
::
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
*
(
sh_ptr_as_copy_t
(
smem_h_idx
,
smem_w_idx
+
bank_offset
))
=
reg
[
i
][
j
];
}
}
}
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
RT
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
iter_forward
()
{
if
(
is_fwd
)
{
ring_src_h
+=
TileCount
::
smem_delta_h
;
}
else
{
ring_src_h
-=
TileCount
::
smem_delta_h
;
}
ring_smem_h
=
(
ring_smem_h
+
TileCount
::
smem_delta_h
)
%
TileCount
::
smem_h
;
}
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
struct
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>
{
using
TileCount
=
TileCount_
;
using
ThreadConfig
=
ThreadConfig_
;
static
const
int
InnerStep
=
sizeof
(
T
);
T
reg
[
TileCount
::
reg_h
][
TileCount
::
reg_w
];
const
int
tidx
=
threadIdx
.
x
;
const
int
tidy
=
threadIdx
.
y
;
const
int
tid
=
tidy
*
ThreadConfig
::
thread_x
+
tidx
;
const
int
gl_load_y
=
tid
/
TileCount
::
load_w
;
const
int
gl_load_x
=
tid
-
gl_load_y
*
TileCount
::
load_w
;
const
bool
is_fwd
=
(
kDirection
==
DIRECTION_FORWARD
);
int
w_offset
;
T
*
smem
;
int
stride
;
int
start_h
,
start_w
,
bound_h
,
bound_w
,
ring_smem_h
,
ring_src_h
;
// just used in backward src data
int
stride_h
,
stride_w
;
const
uint8_t
*
g_ptr
;
__device__
__forceinline__
Global2SharedMem
(
T
*
smem_
,
int
stride_
,
int
s_h
,
int
s_w
,
int
b_h
,
int
b_w
,
int
stride_h_
,
int
stride_w_
);
__device__
__forceinline__
void
first_copy
();
__device__
__forceinline__
void
copy
();
__device__
__forceinline__
void
commit
();
__device__
__forceinline__
void
iter_forward
();
__device__
__forceinline__
T
*
sh_ptr
(
int
y
,
int
x
)
{
return
&
smem
[
y
*
TileCount
::
smem_w
+
x
];
}
__device__
__forceinline__
T
*
sh_ptr_as_copy_t
(
int
y
,
int
x
)
{
return
reinterpret_cast
<
T
*>
(
sh_ptr
(
y
,
x
));
}
};
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
Global2SharedMem
(
T
*
smem_
,
int
stride_
,
int
s_h
,
int
s_w
,
int
b_h
,
int
b_w
,
int
stride_h_
,
int
stride_w_
)
:
smem
(
smem_
),
stride
(
stride_
),
start_h
(
s_h
),
start_w
(
s_w
),
bound_h
(
b_h
),
bound_w
(
b_w
),
ring_smem_h
(
TileCount
::
smem_load_h
),
stride_h
(
stride_h_
),
stride_w
(
stride_w_
)
{
if
(
is_fwd
)
{
ring_src_h
=
s_h
+
TileCount
::
smem_load_h
;
w_offset
=
0
;
}
else
{
ring_src_h
=
s_h
-
1
;
w_offset
=
TileCount
::
smem_w
-
b_w
;
// stride_h and stride_w just used in backward src data.
stride_h
=
stride_w
=
1
;
}
}
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
first_copy
()
{
static
int
const
load_w
=
TileCount
::
smem_w
>
32
?
32
:
TileCount
::
smem_w
;
static
int
const
load_h
=
ThreadConfig
::
nr_threads
/
load_w
;
static
int
const
h_per_thread
=
DIVUP
(
TileCount
::
smem_load_h
,
load_h
);
static
int
const
w_per_thread
=
DIVUP
(
TileCount
::
smem_w
,
load_w
);
static
bool
constexpr
check_bounds_h
=
TileCount
::
smem_load_h
%
load_h
!=
0
;
static
bool
constexpr
check_bounds_w
=
TileCount
::
smem_w
%
load_w
!=
0
;
const
int
y_base_idx
=
tid
/
load_w
;
const
int
x_base_idx
=
tid
-
y_base_idx
*
load_w
;
#pragma unroll
for
(
int
i
=
0
;
i
<
h_per_thread
;
++
i
)
{
int
smem_h_idx
=
y_base_idx
+
i
*
load_h
;
int
bank_offset
=
smem_h_idx
/
TileCount
::
bank_offset_line
;
int
src_h_idx
;
if
(
is_fwd
)
{
src_h_idx
=
start_h
+
smem_h_idx
;
}
else
{
src_h_idx
=
start_h
-
smem_h_idx
;
}
if
(
check_bounds_h
&&
smem_h_idx
>=
TileCount
::
smem_load_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
w_per_thread
;
++
j
)
{
int
smem_w_idx
=
x_base_idx
+
j
*
load_w
;
int
src_w_idx
;
if
(
is_fwd
)
{
src_w_idx
=
start_w
+
smem_w_idx
*
InnerStep
;
}
else
{
src_w_idx
=
start_w
+
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
;
}
if
(
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
T
val
=
0.0
f
;
for
(
int
inner
=
0
;
inner
<
InnerStep
;
inner
++
)
{
T
temp
=
0
;
if
(
src_h_idx
>=
0
&&
src_h_idx
<
bound_h
&&
src_w_idx
>=
0
&&
src_w_idx
<
bound_w
&&
((
is_fwd
&&
src_h_idx
%
stride_h
==
0
&&
src_w_idx
%
stride_w
==
0
)
||
(
!
is_fwd
&&
TileCount
::
smem_load_h
-
smem_h_idx
-
1
>=
0
&&
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
>=
0
)))
{
temp
=
g_ptr
[
src_h_idx
/
stride_h
*
stride
+
src_w_idx
/
stride_w
];
val
|=
(
temp
<<
(
inner
<<
3
));
}
src_w_idx
++
;
}
*
(
sh_ptr_as_copy_t
(
smem_h_idx
,
smem_w_idx
+
bank_offset
*
(
4
/
sizeof
(
T
))))
=
val
;
}
}
}
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
copy
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
TileCount
::
reg_h
;
++
i
)
{
int
thread_h_idx
=
gl_load_y
+
i
*
TileCount
::
load_h
;
int
smem_h_idx
=
(
ring_smem_h
+
thread_h_idx
)
%
TileCount
::
smem_h
;
int
src_h_idx
;
if
(
is_fwd
)
{
src_h_idx
=
ring_src_h
+
thread_h_idx
;
}
else
{
src_h_idx
=
start_h
-
smem_h_idx
;
}
if
(
thread_h_idx
>=
TileCount
::
smem_delta_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
TileCount
::
reg_w
;
++
j
)
{
int
smem_w_idx
=
gl_load_x
+
j
*
TileCount
::
load_w
;
int
src_w_idx
;
if
(
is_fwd
)
{
src_w_idx
=
start_w
+
smem_w_idx
*
InnerStep
;
}
else
{
src_w_idx
=
start_w
+
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
;
}
if
(
TileCount
::
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
T
val
=
0.0
f
;
#pragma unroll
for
(
int
inner
=
0
;
inner
<
InnerStep
;
inner
++
)
{
uint32_t
temp
=
0
;
if
(
src_h_idx
>=
0
&&
src_h_idx
<
bound_h
&&
src_w_idx
>=
0
&&
src_w_idx
<
bound_w
&&
((
is_fwd
&&
src_h_idx
%
stride_h
==
0
&&
src_w_idx
%
stride_w
==
0
)
||
(
!
is_fwd
&&
TileCount
::
smem_w
-
w_offset
-
smem_w_idx
-
1
>=
0
)))
{
temp
=
g_ptr
[
src_h_idx
/
stride_h
*
stride
+
src_w_idx
/
stride_w
];
val
|=
(
temp
<<
(
inner
<<
3
));
}
src_w_idx
++
;
}
reg
[
i
][
j
]
=
val
;
}
}
}
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
commit
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
TileCount
::
reg_h
;
++
i
)
{
int
thread_h_idx
=
gl_load_y
+
i
*
TileCount
::
load_h
;
int
smem_h_idx
=
(
ring_smem_h
+
thread_h_idx
)
%
TileCount
::
smem_h
;
int
bank_offset
=
smem_h_idx
/
TileCount
::
bank_offset_line
;
if
(
thread_h_idx
>=
TileCount
::
smem_delta_h
)
continue
;
#pragma unroll
for
(
int
j
=
0
;
j
<
TileCount
::
reg_w
;
++
j
)
{
int
smem_w_idx
=
gl_load_x
+
j
*
TileCount
::
load_w
;
if
(
TileCount
::
check_bounds_w
&&
smem_w_idx
>=
TileCount
::
smem_w
)
continue
;
*
(
sh_ptr_as_copy_t
(
smem_h_idx
,
smem_w_idx
+
bank_offset
))
=
reg
[
i
][
j
];
}
}
}
template
<
typename
T
,
DepthwiseConv2dDirection
kDirection
,
typename
ThreadConfig_
,
typename
TileCount_
>
__device__
__forceinline__
void
Global2SharedMem
<
T
,
uint8_t
,
kDirection
,
ThreadConfig_
,
TileCount_
>::
iter_forward
()
{
if
(
is_fwd
)
{
ring_src_h
+=
TileCount
::
smem_delta_h
;
}
else
{
ring_src_h
-=
TileCount
::
smem_delta_h
;
}
ring_smem_h
=
(
ring_smem_h
+
TileCount
::
smem_delta_h
)
%
TileCount
::
smem_h
;
}
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHW
(
const
Param
param
,
const
float
*
input
,
const
float
*
filter
,
const
int
*
rin
,
const
int
*
rout
,
float
*
output
)
{
using
T
=
float
;
using
ThreadConfig
=
typename
ConvTrait
::
ThreadConfig
;
using
SrcTileConfig
=
typename
ConvTrait
::
SrcTileConfig
;
using
FilterTileConfig
=
typename
ConvTrait
::
FilterTileConfig
;
using
OutTileConfig
=
typename
ConvTrait
::
OutTileConfig
;
using
SrcTileCount
=
typename
ConvTrait
::
SrcTileCount
;
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
RinTileCount
=
typename
ConvTrait
::
RinTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
RinGlobal2ShareVisitor
=
typename
ConvTrait
::
RinGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
constexpr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
extern
__shared__
__align__
(
8
)
unsigned
char
smem
[];
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
*
smem_rin
=
reinterpret_cast
<
int
*>
(
&
smem_flt
[
FilterTileCount
::
smem_size
]);
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
batch
=
off_ichannel
/
param
.
src_chl
,
out_start_h
=
off_obh
*
OutTileConfig
::
block_h
,
out_start_w
=
off_obw
*
OutTileConfig
::
block_w
,
src_start_h
=
out_start_h
*
stride_h
-
param
.
pad_h
,
src_start_w
=
out_start_w
*
stride_w
-
param
.
pad_w
,
out_base_h_idx
=
out_start_h
+
off_oh
*
OutTileConfig
::
unroll_h
;
T
*
smem_src_ptr
=
smem_src
+
off_ow
*
FilterTileConfig
::
unroll_w
;
int
*
smem_rin_ptr
=
smem_rin
+
off_ow
*
FilterTileConfig
::
unroll_w
;
T
*
smem_flt_ptr
=
smem_flt
+
off_ow
*
FilterTileConfig
::
unroll_w
;
T
*
out_base_ptr
=
output
+
off_ochannel
*
param
.
out_h
*
param
.
out_w
;
const
int
*
rout_base_ptr
=
rout
+
batch
*
param
.
out_h
*
param
.
out_w
;
int
reg_rout
[
OutTileConfig
::
unroll_size
]
=
{
0
};
#pragma unroll
for
(
int
i
=
0
;
i
<
OutTileConfig
::
unroll_h
;
++
i
)
{
int
out_h_idx
=
out_base_h_idx
+
i
;
if
(
out_h_idx
<
param
.
out_h
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
OutTileConfig
::
unroll_w
;
++
j
)
{
int
out_w_idx
=
out_start_w
+
j
;
if
(
out_w_idx
<
param
.
out_w
)
{
reg_rout
[
i
*
OutTileConfig
::
unroll_w
+
j
]
=
rout_base_ptr
[
out_h_idx
*
param
.
out_w
+
out_w_idx
];
}
}
}
}
SrcGlobal2ShareVisitor
gl2sh_src
=
{
smem_src
,
static_cast
<
int
>
(
param
.
src_w
),
static_cast
<
int
>
(
is_fwd
?
src_start_h
:
src_start_h
-
(
param
.
out_h
/
2
+
param
.
flt_h
/
2
-
param
.
pad_h
-
param
.
src_h
*
param
.
stride_h
/
2
)),
static_cast
<
int
>
(
is_fwd
?
src_start_w
:
src_start_w
-
(
param
.
out_w
/
2
+
param
.
flt_w
/
2
-
param
.
pad_w
-
param
.
src_w
*
param
.
stride_w
/
2
)),
static_cast
<
int
>
(
is_fwd
?
param
.
src_h
:
param
.
src_h
*
param
.
stride_h
),
static_cast
<
int
>
(
is_fwd
?
param
.
src_w
:
param
.
src_w
*
param
.
stride_w
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_h
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_w
)};
RinGlobal2ShareVisitor
gl2sh_rin
=
{
smem_rin
,
static_cast
<
int
>
(
param
.
src_w
),
static_cast
<
int
>
(
is_fwd
?
src_start_h
:
src_start_h
-
(
param
.
out_h
/
2
+
param
.
flt_h
/
2
-
param
.
pad_h
-
param
.
src_h
*
param
.
stride_h
/
2
)),
static_cast
<
int
>
(
is_fwd
?
src_start_w
:
src_start_w
-
(
param
.
out_w
/
2
+
param
.
flt_w
/
2
-
param
.
pad_w
-
param
.
src_w
*
param
.
stride_w
/
2
)),
static_cast
<
int
>
(
is_fwd
?
param
.
src_h
:
param
.
src_h
*
param
.
stride_h
),
static_cast
<
int
>
(
is_fwd
?
param
.
src_w
:
param
.
src_w
*
param
.
stride_w
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_h
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_w
)};
FilterGlobal2ShareVisitor
gl2sh_flt
=
{
smem_flt
,
static_cast
<
int
>
(
param
.
flt_w
),
is_fwd
?
0
:
static_cast
<
int
>
(
param
.
flt_h
-
1
),
0
,
static_cast
<
int
>
(
param
.
flt_h
),
static_cast
<
int
>
(
param
.
flt_w
),
1
,
1
};
gl2sh_src
.
g_ptr
=
input
+
off_ichannel
*
param
.
src_h
*
param
.
src_w
;
gl2sh_rin
.
g_ptr
=
rin
+
batch
*
param
.
src_h
*
param
.
src_w
;
gl2sh_flt
.
g_ptr
=
filter
+
off_fchannel
*
param
.
flt_h
*
param
.
flt_w
;
gl2sh_src
.
first_copy
();
gl2sh_rin
.
first_copy
();
gl2sh_flt
.
first_copy
();
__syncthreads
();
T
reg_src
[
2
][
SrcTileConfig
::
unroll_h
*
SrcTileConfig
::
unroll_w
],
reg_flt
[
2
][
FilterTileConfig
::
unroll_h
*
FilterTileConfig
::
unroll_w
];
int
reg_rin
[
2
][
SrcTileConfig
::
unroll_h
*
SrcTileConfig
::
unroll_w
];
T
sum
[
OutTileConfig
::
unroll_size
]
=
{
0.0
};
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
SrcTileConfig
::
unroll_w
;
++
s_w
)
{
reg_src
[
0
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_src_ptr
[(
off_oh
*
stride_h
+
s_h
)
%
SrcTileCount
::
smem_h
*
SrcTileCount
::
smem_w
+
s_w
+
(
off_oh
*
stride_h
+
s_h
)
/
SrcTileCount
::
bank_offset_line
];
reg_rin
[
0
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_rin_ptr
[(
off_oh
*
stride_h
+
s_h
)
%
RinTileCount
::
smem_h
*
RinTileCount
::
smem_w
+
s_w
+
(
off_oh
*
stride_h
+
s_h
)
/
RinTileCount
::
bank_offset_line
];
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
0
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
f_h
/
FilterTileCount
::
bank_offset_line
];
}
}
int
fh
=
1
;
for
(;
fh
<
param
.
flt_h
;
fh
+=
FilterTileConfig
::
unroll_h
*
2
)
{
if
(
fh
+
4
<
param
.
flt_h
+
1
)
{
gl2sh_src
.
copy
();
gl2sh_rin
.
copy
();
}
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
SrcTileConfig
::
unroll_w
;
++
s_w
)
{
int
smem_h_idx
=
(
off_oh
*
stride_h
+
fh
+
s_h
)
%
SrcTileCount
::
smem_h
;
reg_src
[
1
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_src_ptr
[
smem_h_idx
*
SrcTileCount
::
smem_w
+
s_w
+
smem_h_idx
/
SrcTileCount
::
bank_offset_line
];
reg_rin
[
1
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_rin_ptr
[
smem_h_idx
*
RinTileCount
::
smem_w
+
s_w
+
smem_h_idx
/
RinTileCount
::
bank_offset_line
];
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
1
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
fh
+
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
(
fh
+
f_h
)
/
FilterTileCount
::
bank_offset_line
];
}
}
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_idx
=
(
inner_fh
+
oh
)
*
SrcTileConfig
::
unroll_w
+
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
0
][
src_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
0
]
[
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
]
*
reg_src
[
0
][
src_idx
];
}
}
}
}
}
if
(
fh
+
SrcTileCount
::
smem_delta_h
<
param
.
flt_h
)
{
__syncthreads
();
}
if
(
fh
+
(
SrcTileCount
::
smem_delta_h
<<
1
)
<
param
.
flt_h
)
{
gl2sh_src
.
commit
();
gl2sh_rin
.
commit
();
gl2sh_src
.
iter_forward
();
gl2sh_rin
.
iter_forward
();
}
if
(
fh
+
1
<
param
.
flt_h
)
{
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
SrcTileConfig
::
unroll_w
;
++
s_w
)
{
int
smem_h_idx
=
(
off_oh
*
stride_h
+
fh
+
1
+
s_h
)
%
SrcTileCount
::
smem_h
;
reg_src
[
0
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_src_ptr
[
smem_h_idx
*
SrcTileCount
::
smem_w
+
s_w
+
smem_h_idx
/
SrcTileCount
::
bank_offset_line
];
reg_rin
[
0
][
s_h
*
SrcTileConfig
::
unroll_w
+
s_w
]
=
smem_rin_ptr
[
smem_h_idx
*
RinTileCount
::
smem_w
+
s_w
+
smem_h_idx
/
RinTileCount
::
bank_offset_line
];
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
0
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
fh
+
1
+
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
(
fh
+
1
+
f_h
)
/
FilterTileCount
::
bank_offset_line
];
}
}
}
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_idx
=
(
inner_fh
+
oh
)
*
SrcTileConfig
::
unroll_w
+
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
1
][
src_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
1
]
[
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
]
*
reg_src
[
1
][
src_idx
];
}
}
}
}
}
}
if
(
param
.
flt_h
==
fh
)
{
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_idx
=
(
inner_fh
+
oh
)
*
SrcTileConfig
::
unroll_w
+
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
0
][
src_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
0
]
[
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
]
*
reg_src
[
0
][
src_idx
];
}
}
}
}
}
}
__syncthreads
();
for
(
int
o
=
0
;
o
<
OutTileConfig
::
unroll_size
;
++
o
)
{
for
(
int
i
=
1
;
i
<
ThreadConfig
::
thread_x
;
i
=
i
<<
1
)
{
sum
[
o
]
+=
__shfl_xor
(
sum
[
o
],
i
,
32
);
}
}
if
(
threadIdx
.
x
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
OutTileConfig
::
unroll_h
;
++
i
)
{
int
out_h_idx
=
out_base_h_idx
+
i
;
if
(
out_h_idx
<
param
.
out_h
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
OutTileConfig
::
unroll_w
;
++
j
)
{
int
out_w_idx
=
out_start_w
+
j
;
if
(
out_w_idx
>=
param
.
out_w
)
return
;
out_base_ptr
[
out_h_idx
*
param
.
out_w
+
out_w_idx
]
=
sum
[
i
*
OutTileConfig
::
unroll_w
+
j
];
}
}
}
}
}
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHW
(
const
Param
param
,
const
float
*
input
,
const
float
*
filter
,
const
uint8_t
*
rin
,
const
uint8_t
*
rout
,
float
*
output
)
{
using
T
=
float
;
using
ThreadConfig
=
typename
ConvTrait
::
ThreadConfig
;
using
SrcTileConfig
=
typename
ConvTrait
::
SrcTileConfig
;
using
FilterTileConfig
=
typename
ConvTrait
::
FilterTileConfig
;
using
OutTileConfig
=
typename
ConvTrait
::
OutTileConfig
;
using
SrcTileCount
=
typename
ConvTrait
::
SrcTileCount
;
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
RinTileCount
=
typename
ConvTrait
::
RinTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
RinGlobal2ShareVisitor
=
typename
ConvTrait
::
RinGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
constexpr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
extern
__shared__
__align__
(
8
)
unsigned
char
smem
[];
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
*
smem_rin
=
reinterpret_cast
<
int
*>
(
&
smem_flt
[
FilterTileCount
::
smem_size
]);
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
batch
=
off_ichannel
/
param
.
src_chl
,
out_start_h
=
off_obh
*
OutTileConfig
::
block_h
,
out_start_w
=
off_obw
*
OutTileConfig
::
block_w
,
src_start_h
=
out_start_h
*
stride_h
-
param
.
pad_h
,
src_start_w
=
out_start_w
*
stride_w
-
param
.
pad_w
,
out_base_h_idx
=
out_start_h
+
off_oh
*
OutTileConfig
::
unroll_h
;
T
*
smem_src_ptr
=
smem_src
+
off_ow
*
FilterTileConfig
::
unroll_w
;
static_assert
((
FilterTileConfig
::
unroll_w
&
3
)
==
0
);
int
*
smem_rin_ptr
=
smem_rin
+
(
off_ow
*
FilterTileConfig
::
unroll_w
>>
2
);
T
*
smem_flt_ptr
=
smem_flt
+
off_ow
*
FilterTileConfig
::
unroll_w
;
T
*
out_base_ptr
=
output
+
off_ochannel
*
param
.
out_h
*
param
.
out_w
;
const
uint8_t
*
rout_base_ptr
=
rout
+
batch
*
param
.
out_h
*
param
.
out_w
;
static_assert
((
OutTileConfig
::
unroll_w
&
3
)
==
0
);
static_assert
((
OutTileConfig
::
block_w
&
3
)
==
0
);
int
reg_rout
[
OutTileConfig
::
unroll_size
]
=
{
0
};
#pragma unroll
for
(
int
i
=
0
;
i
<
OutTileConfig
::
unroll_h
;
++
i
)
{
int
out_h_idx
=
out_base_h_idx
+
i
;
if
(
out_h_idx
<
param
.
out_h
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
OutTileConfig
::
unroll_w
;
j
+=
4
)
{
int
out_w_idx
=
out_start_w
+
j
;
if
(
out_w_idx
<
param
.
out_w
)
{
uint32_t
val
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
rout_base_ptr
[
out_h_idx
*
param
.
out_w
+
out_w_idx
]));
reg_rout
[
i
*
OutTileConfig
::
unroll_w
+
j
]
=
val
&
0xff
;
reg_rout
[
i
*
OutTileConfig
::
unroll_w
+
j
+
1
]
=
(
val
>>
8
)
&
0xff
;
reg_rout
[
i
*
OutTileConfig
::
unroll_w
+
j
+
2
]
=
(
val
>>
16
)
&
0xff
;
reg_rout
[
i
*
OutTileConfig
::
unroll_w
+
j
+
3
]
=
(
val
>>
24
)
&
0xff
;
}
}
}
}
SrcGlobal2ShareVisitor
gl2sh_src
=
{
smem_src
,
static_cast
<
int
>
(
param
.
src_w
),
static_cast
<
int
>
(
is_fwd
?
src_start_h
:
src_start_h
-
(
param
.
out_h
/
2
+
param
.
flt_h
/
2
-
param
.
pad_h
-
param
.
src_h
*
param
.
stride_h
/
2
)),
static_cast
<
int
>
(
is_fwd
?
src_start_w
:
src_start_w
-
(
param
.
out_w
/
2
+
param
.
flt_w
/
2
-
param
.
pad_w
-
param
.
src_w
*
param
.
stride_w
/
2
)),
static_cast
<
int
>
(
is_fwd
?
param
.
src_h
:
param
.
src_h
*
param
.
stride_h
),
static_cast
<
int
>
(
is_fwd
?
param
.
src_w
:
param
.
src_w
*
param
.
stride_w
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_h
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_w
)};
RinGlobal2ShareVisitor
gl2sh_rin
=
{
smem_rin
,
static_cast
<
int
>
(
param
.
src_w
),
static_cast
<
int
>
(
is_fwd
?
src_start_h
:
src_start_h
-
(
param
.
out_h
/
2
+
param
.
flt_h
/
2
-
param
.
pad_h
-
param
.
src_h
*
param
.
stride_h
/
2
)),
static_cast
<
int
>
(
is_fwd
?
src_start_w
:
src_start_w
-
(
param
.
out_w
/
2
+
param
.
flt_w
/
2
-
param
.
pad_w
-
param
.
src_w
*
param
.
stride_w
/
2
)),
static_cast
<
int
>
(
is_fwd
?
param
.
src_h
:
param
.
src_h
*
param
.
stride_h
),
static_cast
<
int
>
(
is_fwd
?
param
.
src_w
:
param
.
src_w
*
param
.
stride_w
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_h
),
is_fwd
?
1
:
static_cast
<
int
>
(
param
.
stride_w
)};
FilterGlobal2ShareVisitor
gl2sh_flt
=
{
smem_flt
,
static_cast
<
int
>
(
param
.
flt_w
),
is_fwd
?
0
:
static_cast
<
int
>
(
param
.
flt_h
-
1
),
0
,
static_cast
<
int
>
(
param
.
flt_h
),
static_cast
<
int
>
(
param
.
flt_w
),
1
,
1
};
gl2sh_src
.
g_ptr
=
input
+
off_ichannel
*
param
.
src_h
*
param
.
src_w
;
gl2sh_rin
.
g_ptr
=
rin
+
batch
*
param
.
src_h
*
param
.
src_w
;
gl2sh_flt
.
g_ptr
=
filter
+
off_fchannel
*
param
.
flt_h
*
param
.
flt_w
;
gl2sh_src
.
first_copy
();
gl2sh_rin
.
first_copy
();
gl2sh_flt
.
first_copy
();
__syncthreads
();
const
static
int
irin_unroll_w
=
(
DIVUP
(
SrcTileConfig
::
unroll_w
,
4
))
<<
2
;
T
reg_src
[
2
][
SrcTileConfig
::
unroll_h
*
irin_unroll_w
],
reg_flt
[
2
][
FilterTileConfig
::
unroll_h
*
FilterTileConfig
::
unroll_w
];
int
reg_rin
[
2
][
SrcTileConfig
::
unroll_h
*
irin_unroll_w
];
T
sum
[
OutTileConfig
::
unroll_size
]
=
{
0.0
};
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
int
s_idx
=
(
off_oh
*
stride_h
+
s_h
)
%
SrcTileCount
::
smem_h
*
SrcTileCount
::
smem_w
+
(
off_oh
*
stride_h
+
s_h
)
/
SrcTileCount
::
bank_offset_line
;
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
irin_unroll_w
;
s_w
+=
4
)
{
uint32_t
val
=
smem_rin_ptr
[(
off_oh
*
stride_h
+
s_h
)
%
RinTileCount
::
smem_h
*
RinTileCount
::
smem_w
+
(
s_w
>>
2
)
+
(
off_oh
*
stride_h
+
s_h
)
/
RinTileCount
::
bank_offset_line
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
]
=
smem_src_ptr
[
s_idx
+
s_w
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
smem_src_ptr
[
s_idx
+
s_w
+
1
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
smem_src_ptr
[
s_idx
+
s_w
+
2
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
smem_src_ptr
[
s_idx
+
s_w
+
3
];
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
]
=
val
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
(
val
>>
8
)
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
(
val
>>
16
)
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
(
val
>>
24
)
&
0xff
;
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
0
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
f_h
/
FilterTileCount
::
bank_offset_line
];
}
}
int
fh
=
1
;
for
(;
fh
<
param
.
flt_h
;
fh
+=
FilterTileConfig
::
unroll_h
*
2
)
{
if
(
fh
+
4
<
param
.
flt_h
+
1
)
{
gl2sh_src
.
copy
();
gl2sh_rin
.
copy
();
}
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
int
src_off
=
((
off_oh
*
stride_h
+
fh
+
s_h
)
%
SrcTileCount
::
smem_h
)
*
SrcTileCount
::
smem_w
+
((
off_oh
*
stride_h
+
fh
+
s_h
)
%
SrcTileCount
::
smem_h
)
/
SrcTileCount
::
bank_offset_line
;
int
rin_h_idx
=
(
off_oh
*
stride_h
+
fh
+
s_h
)
%
RinTileCount
::
smem_h
;
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
irin_unroll_w
;
s_w
+=
4
)
{
uint32_t
val
=
smem_rin_ptr
[
rin_h_idx
*
RinTileCount
::
smem_w
+
(
s_w
>>
2
)
+
rin_h_idx
/
RinTileCount
::
bank_offset_line
];
reg_src
[
1
][
s_h
*
irin_unroll_w
+
s_w
]
=
smem_src_ptr
[
src_off
+
s_w
];
reg_src
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
smem_src_ptr
[
src_off
+
s_w
+
1
];
reg_src
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
smem_src_ptr
[
src_off
+
s_w
+
2
];
reg_src
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
smem_src_ptr
[
src_off
+
s_w
+
3
];
reg_rin
[
1
][
s_h
*
irin_unroll_w
+
s_w
]
=
val
&
0xff
;
reg_rin
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
(
val
>>
8
)
&
0xff
;
reg_rin
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
(
val
>>
16
)
&
0xff
;
reg_rin
[
1
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
(
val
>>
24
)
&
0xff
;
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
1
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
fh
+
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
(
fh
+
f_h
)
/
FilterTileCount
::
bank_offset_line
];
}
}
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
int
src_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
int
rin_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
int
flt_off
=
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
;
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_w_idx
=
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
0
][
rin_h_off
+
src_w_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
0
][
flt_off
]
*
reg_src
[
0
][
src_h_off
+
src_w_idx
];
}
}
}
}
}
if
(
fh
+
SrcTileCount
::
smem_delta_h
<
param
.
flt_h
)
{
__syncthreads
();
}
if
(
fh
+
(
SrcTileCount
::
smem_delta_h
<<
1
)
<
param
.
flt_h
)
{
gl2sh_src
.
commit
();
gl2sh_rin
.
commit
();
gl2sh_src
.
iter_forward
();
gl2sh_rin
.
iter_forward
();
}
if
(
fh
+
1
<
param
.
flt_h
)
{
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
int
src_idx
=
((
off_oh
*
stride_h
+
fh
+
1
+
s_h
)
%
SrcTileCount
::
smem_h
)
*
SrcTileCount
::
smem_w
+
((
off_oh
*
stride_h
+
fh
+
1
+
s_h
)
%
SrcTileCount
::
smem_h
)
/
SrcTileCount
::
bank_offset_line
;
int
rin_h_idx
=
(
off_oh
*
stride_h
+
fh
+
1
+
s_h
)
%
RinTileCount
::
smem_h
;
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
irin_unroll_w
;
s_w
+=
4
)
{
uint32_t
val
=
smem_rin_ptr
[
rin_h_idx
*
RinTileCount
::
smem_w
+
(
s_w
>>
2
)
+
rin_h_idx
/
RinTileCount
::
bank_offset_line
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
]
=
smem_src_ptr
[
src_idx
+
s_w
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
smem_src_ptr
[
src_idx
+
s_w
+
1
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
smem_src_ptr
[
src_idx
+
s_w
+
2
];
reg_src
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
smem_src_ptr
[
src_idx
+
s_w
+
3
];
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
]
=
val
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
1
]
=
(
val
>>
8
)
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
2
]
=
(
val
>>
16
)
&
0xff
;
reg_rin
[
0
][
s_h
*
irin_unroll_w
+
s_w
+
3
]
=
(
val
>>
24
)
&
0xff
;
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
FilterTileConfig
::
unroll_w
;
++
f_w
)
{
reg_flt
[
0
][
f_h
*
FilterTileConfig
::
unroll_w
+
f_w
]
=
smem_flt_ptr
[(
fh
+
1
+
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
+
(
fh
+
1
+
f_h
)
/
FilterTileCount
::
bank_offset_line
];
}
}
}
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
int
src_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
int
rin_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
int
flt_off
=
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
;
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_w_idx
=
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
1
][
rin_h_off
+
src_w_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
1
][
flt_off
]
*
reg_src
[
1
][
src_h_off
+
src_w_idx
];
}
}
}
}
}
}
if
(
param
.
flt_h
==
fh
)
{
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
int
src_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
int
rin_h_off
=
(
inner_fh
+
oh
)
*
irin_unroll_w
;
#pragma unroll
for
(
int
fw
=
0
;
fw
<
FilterTileConfig
::
unroll_w
;
++
fw
)
{
int
flt_off
=
inner_fh
*
FilterTileConfig
::
unroll_w
+
fw
;
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
int
src_w_idx
=
fw
+
ow
*
stride_w
;
if
(
reg_rin
[
0
][
rin_h_off
+
src_w_idx
]
==
reg_rout
[
oh
*
OutTileConfig
::
unroll_w
+
ow
])
{
sum
[
oh
*
OutTileConfig
::
unroll_w
+
ow
]
+=
reg_flt
[
0
][
flt_off
]
*
reg_src
[
0
][
src_h_off
+
src_w_idx
];
}
}
}
}
}
}
__syncthreads
();
for
(
int
o
=
0
;
o
<
OutTileConfig
::
unroll_size
;
++
o
)
{
for
(
int
i
=
1
;
i
<
ThreadConfig
::
thread_x
;
i
=
i
<<
1
)
{
sum
[
o
]
+=
__shfl_xor
(
sum
[
o
],
i
,
32
);
}
}
if
(
threadIdx
.
x
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
OutTileConfig
::
unroll_h
;
++
i
)
{
int
out_h_idx
=
out_base_h_idx
+
i
;
if
(
out_h_idx
<
param
.
out_h
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
OutTileConfig
::
unroll_w
;
++
j
)
{
int
out_w_idx
=
out_start_w
+
j
;
if
(
out_w_idx
>=
param
.
out_w
)
return
;
out_base_ptr
[
out_h_idx
*
param
.
out_w
+
out_w_idx
]
=
sum
[
i
*
OutTileConfig
::
unroll_w
+
j
];
}
}
}
}
}
template
<
typename
T
,
typename
RT
,
DepthwiseConv2dDirection
kDirection
,
int
unroll_fw
,
int
unroll_ow
,
int
stride
>
void
LaunchDepthwiseConv2dGPU
(
const
Param
&
param
,
const
T
*
input
,
const
T
*
filter
,
const
RT
*
rin
,
const
RT
*
rout
,
T
*
output
,
cudaStream_t
stream
)
{
static
int
const
unroll_oh
=
1
,
unroll_fh
=
1
;
using
FilterTileConfig
=
FilterTileConfig
<
unroll_fh
,
unroll_fw
>
;
using
ThreadConfig
=
ThreadConfig
<
4
,
32
>
;
using
OutTileConfig
=
OutTileConfig
<
ThreadConfig
,
unroll_oh
,
unroll_ow
>
;
using
IConvTrait
=
ConvTrait
<
T
,
int
,
RT
,
kDirection
,
ThreadConfig
,
OutTileConfig
,
FilterTileConfig
,
stride
,
stride
>
;
using
SrcTileCount
=
typename
IConvTrait
::
SrcTileCount
;
using
FilterTileCount
=
typename
IConvTrait
::
FilterTileCount
;
using
RinTileCount
=
typename
IConvTrait
::
RinTileCount
;
dim3
block
(
ThreadConfig
::
thread_x
,
ThreadConfig
::
thread_y
);
dim3
grid
;
grid
.
x
=
param
.
batch
*
param
.
src_chl
*
param
.
chl_mul
;
grid
.
y
=
DIVUP
(
param
.
out_w
,
OutTileConfig
::
block_w
);
grid
.
z
=
DIVUP
(
param
.
out_h
,
OutTileConfig
::
block_h
);
const
int
shared_storage
=
(
SrcTileCount
::
smem_size
+
FilterTileCount
::
smem_size
)
*
sizeof
(
T
)
+
RinTileCount
::
smem_size
*
sizeof
(
int
);
void
(
*
kernel
)(
const
Param
,
const
T
*
,
const
T
*
,
const
RT
*
,
const
RT
*
,
T
*
);
if
(
param
.
is_compute_deafult
)
{
kernel
=
DepthwiseConv2dGPUKernelNCHW
<
IConvTrait
,
kDirection
,
stride
>
;
}
else
{
megdnn_assert_internal
(
0
);
}
kernel
<<<
grid
,
block
,
shared_storage
,
stream
>>>
(
param
,
input
,
filter
,
rin
,
rout
,
output
);
after_kernel_launch
();
}
#define INSTANCE_AB(type1, type2, a, direction) \
if (param.out_w > 28) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 8, 1>( \
param, src, flt, rin, rout, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 8, 2>( \
param, src, flt, rin, rout, dst, stream); \
} \
} else { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 4, 1>( \
param, src, flt, rin, rout, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a, 4, 2>( \
param, src, flt, rin, rout, dst, stream); \
} \
}
#define INSTANCE_INT(type1, type2, direction) \
if (param.flt_w > 24) { \
INSTANCE_AB(type1, type2, 8, direction) \
} else if (param.flt_w > 16) { \
INSTANCE_AB(type1, type2, 6, direction) \
} else if (param.flt_w > 8) { \
INSTANCE_AB(type1, type2, 4, direction) \
} else { \
INSTANCE_AB(type1, type2, 2, direction) \
}
#define INSTANCE_UINT8(type1, type2, direction) \
if (param.flt_w > 16) { \
INSTANCE_AB(type1, type2, 8, direction) \
} else { \
INSTANCE_AB(type1, type2, 4, direction) \
}
}
// anonymous namespace
dnn/src/cuda/region_restricted_convolution/chanwise/fwd_large_filter.cu
0 → 100644
浏览文件 @
543c9b77
#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/fp16_help.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
region_restricted_convolution
;
using
namespace
chanwise
;
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter_algo.cuh"
namespace
megdnn
{
namespace
cuda
{
namespace
region_restricted_convolution
{
namespace
chanwise
{
// =====================================fwd=====================================
#define check
template
<
>
void
run_fwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
int
*
rin
,
const
int
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE_INT
(
float
,
int
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
)
}
template
<
>
void
run_fwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
uint8_t
*
rin
,
const
uint8_t
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE_UINT8
(
float
,
uint8_t
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
)
}
}
// namespace chanwise
}
// namespace region_restricted_convolution
}
// namespace cuda
}
// namespace megdnn
// vim: syntax=cuda.doxygen
dnn/src/cuda/region_restricted_convolution/chanwise/kern.cuh
0 → 100644
浏览文件 @
543c9b77
#pragma once
#include <cuda_runtime.h>
#include <stdint.h>
#include "src/cuda/utils.cuh"
namespace
megdnn
{
namespace
cuda
{
namespace
region_restricted_convolution
{
namespace
chanwise
{
struct
Param
{
int
batch
,
src_chl
,
src_h
,
src_w
,
chl_mul
,
flt_h
,
flt_w
,
out_h
,
out_w
,
pad_h
,
pad_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
;
bool
is_compute_deafult
;
#if MEGDNN_CC_HOST
static
Param
load
(
const
TensorShape
&
src
,
const
TensorShape
&
dst
,
const
RegionRestrictedConvolutionForward
::
CanonizedFilterMeta
&
fm
,
bool
is_compute_deafult_
=
true
)
{
#define U(v) static_cast<int>(v)
size_t
c_pos
,
hw_pos
;
if
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW
)
{
c_pos
=
1
;
hw_pos
=
2
;
}
else
{
megdnn_assert_internal
(
0
);
}
return
{
U
(
src
[
0
]),
U
(
src
[
c_pos
]),
U
(
src
[
hw_pos
]),
U
(
src
[
hw_pos
+
1
]),
U
(
fm
.
ocpg
),
U
(
fm
.
spatial
[
0
]),
U
(
fm
.
spatial
[
1
]),
U
(
dst
[
hw_pos
]),
U
(
dst
[
hw_pos
+
1
]),
U
(
fm
.
padding
[
0
]),
U
(
fm
.
padding
[
1
]),
U
(
fm
.
stride
[
0
]),
U
(
fm
.
stride
[
1
]),
U
(
fm
.
dilation
[
0
]),
U
(
fm
.
dilation
[
1
]),
is_compute_deafult_
,
};
#undef U
}
#endif
};
template
<
typename
T
,
typename
RT
>
void
run_fwd_depthwise_large_filter
(
T
*
dst
,
const
T
*
src
,
const
T
*
flt
,
const
RT
*
rin
,
const
RT
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
);
template
<
typename
T
,
typename
RT
>
void
run_bwd_depthwise_large_filter
(
T
*
dst
,
const
T
*
src
,
const
T
*
flt
,
const
RT
*
rin
,
const
RT
*
rout
,
const
Param
&
param
,
cudaStream_t
stream
);
}
// namespace chanwise
}
// namespace region_restricted_convolution
}
// namespace cuda
}
// namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
dnn/src/cuda/region_restricted_convolution/opr_impl.cpp
0 → 100644
浏览文件 @
543c9b77
#include "src/cuda/region_restricted_convolution/opr_impl.h"
#include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh"
#include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
region_restricted_convolution
;
/* ============== RegionRestrictedConvolutionForwardImpl ============== */
void
RegionRestrictedConvolutionForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
auto
fm
=
check_exec
(
src
.
layout
,
filter
.
layout
,
rin
.
layout
,
rout
.
layout
,
dst
.
layout
,
workspace
.
size
);
auto
kparam
=
chanwise
::
Param
::
load
(
src
.
layout
,
dst
.
layout
,
fm
,
param
().
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
);
megdnn_assert
(
fm
.
group
>
1
&&
src
.
layout
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
param
().
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
fm
.
spatial_ndim
==
2
&&
fm
.
icpg
==
1
&&
fm
.
ocpg
==
1
&&
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
&&
!
fm
.
should_flip
&&
param
().
stride_h
==
1
&&
param
().
stride_w
==
1
);
if
(
rin
.
layout
.
dtype
==
dtype
::
Uint8
())
{
megdnn_assert
((
src
.
layout
.
shape
[
3
]
&
3
)
==
0
&&
(
dst
.
layout
.
shape
[
3
]
&
3
)
==
0
);
}
auto
stream
=
cuda_stream
(
handle
());
if
(
filter
.
layout
.
dtype
==
dtype
::
Float32
()
&&
rin
.
layout
.
dtype
==
dtype
::
Int32
()
&&
rout
.
layout
.
dtype
==
dtype
::
Int32
())
{
chanwise
::
run_fwd_depthwise_large_filter
(
dst
.
ptr
<
float
>
(),
src
.
ptr
<
float
>
(),
filter
.
ptr
<
float
>
(),
rin
.
ptr
<
int
>
(),
rout
.
ptr
<
int
>
(),
kparam
,
stream
);
}
else
if
(
filter
.
layout
.
dtype
==
dtype
::
Float32
()
&&
rin
.
layout
.
dtype
==
dtype
::
Uint8
()
&&
rout
.
layout
.
dtype
==
dtype
::
Uint8
())
{
chanwise
::
run_fwd_depthwise_large_filter
(
dst
.
ptr
<
float
>
(),
src
.
ptr
<
float
>
(),
filter
.
ptr
<
float
>
(),
rin
.
ptr
<
uint8_t
>
(),
rout
.
ptr
<
uint8_t
>
(),
kparam
,
stream
);
}
else
{
megdnn_assert_internal
(
0
);
}
}
size_t
RegionRestrictedConvolutionBackwardDataImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
filter
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
rin
,
const
TensorLayout
&
rout
,
const
TensorLayout
&
grad
)
{
return
0
;
}
/* ============== RegionRestrictedConvolutionBackwardDataImpl ============== */
void
RegionRestrictedConvolutionBackwardDataImpl
::
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
megdnn_throw
(
ssprintf
(
"unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s"
,
filter
.
layout
.
dtype
.
name
(),
diff
.
layout
.
dtype
.
name
(),
rin
.
layout
.
dtype
.
name
(),
rout
.
layout
.
dtype
.
name
(),
grad
.
layout
.
dtype
.
name
()));
}
size_t
RegionRestrictedConvolutionBackwardFilterImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
grad
)
{
size_t
workspace_size
=
0
;
return
workspace_size
;
}
/* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */
void
RegionRestrictedConvolutionBackwardFilterImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
{
megdnn_assert_internal
(
0
);
}
// vim: syntax=cpp.doxygen
dnn/src/cuda/region_restricted_convolution/opr_impl.h
0 → 100644
浏览文件 @
543c9b77
#pragma once
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
namespace
megdnn
{
namespace
cuda
{
class
RegionRestrictedConvolutionForwardImpl
:
public
RegionRestrictedConvolutionForward
{
public:
using
RegionRestrictedConvolutionForward
::
RegionRestrictedConvolutionForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
};
class
RegionRestrictedConvolutionBackwardDataImpl
:
public
RegionRestrictedConvolutionBackwardData
{
public:
using
RegionRestrictedConvolutionBackwardData
::
RegionRestrictedConvolutionBackwardData
;
void
exec
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
;
};
class
RegionRestrictedConvolutionBackwardFilterImpl
:
public
RegionRestrictedConvolutionBackwardFilter
{
public:
using
RegionRestrictedConvolutionBackwardFilter
::
RegionRestrictedConvolutionBackwardFilter
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
;
};
}
// namespace cuda
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/naive/convolution/helper.h
浏览文件 @
543c9b77
...
...
@@ -878,8 +878,9 @@ void forward_bias(
}
template
<
typename
stype
,
typename
ftype
,
typename
dtype
,
typename
comp_type
,
class
Strategy
,
typename
FilterMeta
,
typename
FilterVisitor
=
ConvFilterVisitor
>
typename
stype
,
typename
ftype
,
typename
rtype
,
typename
dtype
,
typename
comp_type
,
class
Strategy
,
typename
FilterMeta
,
typename
FilterVisitor
=
ConvFilterVisitor
>
void
region_restricted_compute
(
_megdnn_tensor_in
src
,
ftype
*
__restrict
fptr
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
dst
,
const
FilterMeta
&
filter_meta
)
{
...
...
@@ -897,8 +898,8 @@ void region_restricted_compute(
int
dh
=
filter_meta
.
dilation
[
0
],
dw
=
filter_meta
.
dilation
[
1
];
stype
*
__restrict
sptr
=
src
.
compatible_ptr
<
stype
>
();
dtype
*
__restrict
dptr
=
dst
.
compatible_ptr
<
dtype
>
();
int32_t
*
__restrict
rinptr
=
rin
.
ptr
<
int32_t
>
();
int32_t
*
__restrict
routptr
=
rout
.
ptr
<
int32_t
>
();
rtype
*
__restrict
rinptr
=
rin
.
compatible_ptr
<
rtype
>
();
rtype
*
__restrict
routptr
=
rout
.
compatible_ptr
<
rtype
>
();
int
h_offset
=
-
ph
,
w_offset
=
-
pw
;
if
(
filter_meta
.
should_flip
)
{
...
...
@@ -934,7 +935,7 @@ void region_restricted_compute(
ftype
*
fptr_cur
=
FilterVisitor
::
template
get_current_ptr
(
fptr
,
n
,
oc
,
oh
,
ow
,
filter_sizes
);
Strategy
::
init_dval
(
dval
);
int32_t
routval
=
routptr
[
get_region_addr
(
n
,
oh
,
ow
,
rout
.
layout
)];
rtype
&
routval
=
routptr
[
get_region_addr
(
n
,
oh
,
ow
,
rout
.
layout
)];
for
(
size_t
fh
=
0
;
fh
<
FH
;
++
fh
)
for
(
size_t
fw
=
0
;
fw
<
FW
;
++
fw
)
{
...
...
@@ -950,7 +951,7 @@ void region_restricted_compute(
n
,
ic
,
ih
,
iw
,
src
.
layout
)];
ftype
&
fval
=
fptr_cur
[
get_filter_addr
(
gc_out
,
ic
,
ic0
,
fh
,
fw
)];
int32_t
rinval
=
rinptr
[
get_region_addr
(
rtype
&
rinval
=
rinptr
[
get_region_addr
(
n
,
ih
,
iw
,
rin
.
layout
)];
if
(
routval
==
rinval
)
{
Strategy
::
on
(
...
...
@@ -967,28 +968,32 @@ void region_restricted_compute(
}
//! forward with only filter ptr
template
<
typename
stype
,
typename
ftype
,
typename
dtype
,
typename
comp_type
>
template
<
typename
stype
,
typename
ftype
,
typename
rtype
,
typename
dtype
,
typename
comp_type
>
void
region_restricted_forward
(
_megdnn_tensor_in
src
,
const
ftype
*
fptr
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
dst
,
const
RegionRestrictedConvolution
::
CanonizedFilterMeta
&
filter_meta
)
{
megdnn_assert
(
filter_meta
.
spatial_ndim
==
2
);
megdnn_assert
(
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW
);
region_restricted_compute
<
stype
,
ftype
,
dtype
,
comp_type
,
StrategyFwd
>
(
region_restricted_compute
<
stype
,
ftype
,
rtype
,
dtype
,
comp_type
,
StrategyFwd
>
(
src
,
const_cast
<
ftype
*>
(
fptr
),
rin
,
rout
,
dst
,
filter_meta
);
}
//! forward with full filter (for API compatibility)
template
<
typename
stype
,
typename
ftype
,
typename
dtype
,
typename
comp_type
>
template
<
typename
stype
,
typename
ftype
,
typename
rtype
,
typename
dtype
,
typename
comp_type
>
void
region_restricted_forward
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
dst
,
const
RegionRestrictedConvolution
::
CanonizedFilterMeta
&
filter_meta
)
{
return
region_restricted_forward
<
stype
,
ftype
,
dtype
,
comp_type
>
(
return
region_restricted_forward
<
stype
,
ftype
,
rtype
,
dtype
,
comp_type
>
(
src
,
filter
.
compatible_ptr
<
ftype
>
(),
rin
,
rout
,
dst
,
filter_meta
);
}
template
<
typename
ftype
,
typename
dtype
,
typename
gtype
>
template
<
typename
ftype
,
typename
dtype
,
typename
rtype
,
typename
gtype
>
void
region_restricted_backward_data
(
_megdnn_tensor_in
filter
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
...
...
@@ -996,11 +1001,11 @@ void region_restricted_backward_data(
megdnn_assert
(
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW
);
memset
(
grad
.
raw_ptr
(),
0
,
grad
.
layout
.
span
().
dist_byte
());
megdnn_assert
(
filter_meta
.
spatial_ndim
==
2
);
region_restricted_compute
<
gtype
,
ftype
,
dtype
,
dtype
,
StrategyBwdData
>
(
region_restricted_compute
<
gtype
,
ftype
,
rtype
,
dtype
,
dtype
,
StrategyBwdData
>
(
grad
,
filter
.
compatible_ptr
<
ftype
>
(),
rin
,
rout
,
diff
,
filter_meta
);
}
template
<
typename
stype
,
typename
dtype
,
typename
gtype
>
template
<
typename
stype
,
typename
dtype
,
typename
rtype
,
typename
gtype
>
void
region_restricted_backward_filter
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
diff
,
_megdnn_tensor_in
rin
,
_megdnn_tensor_in
rout
,
_megdnn_tensor_out
grad
,
...
...
@@ -1008,7 +1013,7 @@ void region_restricted_backward_filter(
megdnn_assert
(
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW
);
memset
(
grad
.
raw_ptr
(),
0
,
grad
.
layout
.
span
().
dist_byte
());
megdnn_assert
(
filter_meta
.
spatial_ndim
==
2
);
region_restricted_compute
<
stype
,
gtype
,
dtype
,
dtype
,
StrategyBwdFlt
>
(
region_restricted_compute
<
stype
,
gtype
,
rtype
,
dtype
,
dtype
,
StrategyBwdFlt
>
(
src
,
grad
.
compatible_ptr
<
gtype
>
(),
rin
,
rout
,
diff
,
filter_meta
);
}
...
...
dnn/src/naive/region_restricted_convolution/opr_impl.cpp
浏览文件 @
543c9b77
...
...
@@ -22,28 +22,37 @@ void RegionRestrictedConvolutionForwardImpl::exec(
src
.
layout
,
filter
.
layout
,
rin
.
layout
,
rout
.
layout
,
dst
.
layout
,
workspace
.
size
);
using
ComputeMode
=
Param
::
ComputeMode
;
#define DISPATCH_CMODE(in_dt,
out_dt, in_ct, out_ct, comp_ct, cmode)
\
#define DISPATCH_CMODE(in_dt,
r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct, cmode)
\
do { \
using namespace dtype; \
if (src.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv && \
dst.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv && \
rin.layout.dtype.enumv() == DTypeTrait<r_dt>::enumv && \
rout.layout.dtype.enumv() == DTypeTrait<r_dt>::enumv && \
param().compute_mode == cmode) { \
MEGDNN_DISPATCH_CPU_KERN_OPR((convolution::region_restricted_forward< \
in_ct, in_ct,
out_ct, comp_ct>(
\
in_ct, in_ct,
r_ct, out_ct, comp_ct>(
\
src, filter, rin, rout, dst, filter_meta));); \
return; \
} \
} while (0);
#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \
DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
#define DISPATCH(in_dt, r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct) \
DISPATCH_CMODE( \
in_dt, r_dt, out_dt, in_ct, r_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
#define cb(dt) \
DISPATCH( \
dt, dt, DTypeTrait<dt>::ctype, DTypeTrait<dt>::ctype, \
dt, Int32, dt, DTypeTrait<dt>::ctype, dt_int32, DTypeTrait<dt>::ctype, \
DTypeTrait<dt>::ctype) \
DISPATCH( \
dt, Uint8, dt, DTypeTrait<dt>::ctype, dt_uint8, DTypeTrait<dt>::ctype, \
DTypeTrait<dt>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
DNN_INC_FLOAT16
(
DISPATCH_CMODE
(
Float16
,
Float16
,
dt_float16
,
dt_float16
,
dt_float32
,
Float16
,
Int32
,
Float16
,
dt_float16
,
dt_int32
,
dt_float16
,
dt_float32
,
ComputeMode
::
FLOAT32
));
DNN_INC_FLOAT16
(
DISPATCH_CMODE
(
Float16
,
Uint8
,
Float16
,
dt_float16
,
dt_uint8
,
dt_float16
,
dt_float32
,
ComputeMode
::
FLOAT32
));
#undef DISPATCH
megdnn_throw
(
ssprintf
(
...
...
@@ -89,26 +98,51 @@ void RegionRestrictedConvolutionBackwardDataImpl::exec(
auto
cmode
=
param
().
compute_mode
;
#define cb(dt) \
do { \
if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \
if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Int32() && \
rout.layout.dtype == dtype::Int32()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
(convolution::region_restricted_backward_data< \
ctype, ctype, dt_int32, ctype>( \
filter, diff, rin, rout, grad, filter_meta))); \
return; \
} else if ( \
filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Uint8() && \
rout.layout.dtype == dtype::Uint8()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
(convolution::region_restricted_backward_data< \
ctype, ctype,
ctype>(
\
filter, diff, rin, rout, grad, filter_meta))
;);
\
ctype, ctype,
dt_uint8, ctype>(
\
filter, diff, rin, rout, grad, filter_meta))
);
\
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
if
(
filter
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
)
{
if
(
filter
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
&&
rin
.
layout
.
dtype
==
dtype
::
Int32
()
&&
rout
.
layout
.
dtype
==
dtype
::
Int32
())
{
TensorND
grad_fp32
{
workspace
.
raw_ptr
,
TensorLayout
{
grad
.
layout
,
dtype
::
Float32
()}};
auto
&&
type_cvt
=
handle
()
->
create_operator
<
TypeCvt
>
();
type_cvt
->
exec
(
grad
,
grad_fp32
);
MEGDNN_DISPATCH_CPU_KERN_OPR
((
convolution
::
region_restricted_backward_data
<
dt_float16
,
dt_float16
,
dt_int32
,
dt_float32
>
(
filter
,
diff
,
rin
,
rout
,
grad_fp32
,
filter_meta
)));
type_cvt
->
exec
(
grad_fp32
,
grad
);
return
;
}
else
if
(
filter
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
&&
rin
.
layout
.
dtype
==
dtype
::
Uint8
()
&&
rout
.
layout
.
dtype
==
dtype
::
Uint8
())
{
TensorND
grad_fp32
{
workspace
.
raw_ptr
,
TensorLayout
{
grad
.
layout
,
dtype
::
Float32
()}};
auto
&&
type_cvt
=
handle
()
->
create_operator
<
TypeCvt
>
();
type_cvt
->
exec
(
grad
,
grad_fp32
);
MEGDNN_DISPATCH_CPU_KERN_OPR
((
convolution
::
region_restricted_backward_data
<
dt_float16
,
dt_float16
,
dt_float32
>
(
filter
,
diff
,
rin
,
rout
,
grad_fp32
,
filter_meta
))
;
);
dt_float16
,
dt_float16
,
dt_
uint8
,
dt_
float32
>
(
filter
,
diff
,
rin
,
rout
,
grad_fp32
,
filter_meta
)));
type_cvt
->
exec
(
grad_fp32
,
grad
);
return
;
}
...
...
@@ -148,12 +182,27 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec(
auto
cmode
=
param
().
compute_mode
;
#define cb(dt) \
do { \
if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \
if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Int32() && \
rout.layout.dtype == dtype::Int32()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution::region_restricted_backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA dt_int32 \
MEGDNN_COMMA ctype>( \
src, diff, rin, rout, grad, filter_meta);); \
return; \
} else if ( \
src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT && \
rin.layout.dtype == dtype::Uint8() && \
rout.layout.dtype == dtype::Uint8()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution::region_restricted_backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA dt_uint8 \
MEGDNN_COMMA ctype>( \
src, diff, rin, rout, grad, filter_meta);); \
return; \
} \
...
...
@@ -161,13 +210,26 @@ void RegionRestrictedConvolutionBackwardFilterImpl::exec(
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
cb
);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
if
(
src
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
)
{
if
(
src
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
&&
rin
.
layout
.
dtype
==
dtype
::
Int32
()
&&
rout
.
layout
.
dtype
==
dtype
::
Int32
())
{
TensorND
grad_fp32
{
workspace
.
raw_ptr
,
TensorLayout
{
grad
.
layout
,
dtype
::
Float32
()}};
auto
&&
type_cvt
=
handle
()
->
create_operator
<
TypeCvt
>
();
type_cvt
->
exec
(
grad
,
grad_fp32
);
MEGDNN_DISPATCH_CPU_KERN_OPR
((
convolution
::
region_restricted_backward_filter
<
dt_float16
,
dt_float16
,
dt_int32
,
dt_float32
>
(
src
,
diff
,
rin
,
rout
,
grad_fp32
,
filter_meta
)););
type_cvt
->
exec
(
grad_fp32
,
grad
);
return
;
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Float16
()
&&
cmode
==
ComputeMode
::
FLOAT32
&&
rin
.
layout
.
dtype
==
dtype
::
Uint8
()
&&
rout
.
layout
.
dtype
==
dtype
::
Uint8
())
{
TensorND
grad_fp32
{
workspace
.
raw_ptr
,
TensorLayout
{
grad
.
layout
,
dtype
::
Float32
()}};
auto
&&
type_cvt
=
handle
()
->
create_operator
<
TypeCvt
>
();
type_cvt
->
exec
(
grad
,
grad_fp32
);
MEGDNN_DISPATCH_CPU_KERN_OPR
((
convolution
::
region_restricted_backward_filter
<
dt_float16
,
dt_float16
,
dt_float32
>
(
dt_float16
,
dt_float16
,
dt_
uint8
,
dt_
float32
>
(
src
,
diff
,
rin
,
rout
,
grad_fp32
,
filter_meta
)););
type_cvt
->
exec
(
grad_fp32
,
grad
);
return
;
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
543c9b77
...
...
@@ -717,11 +717,11 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward
::
algo_name
<
ConvBias
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
#if CUDA_VERSION >= 9000
dtype
::
Float16
()
#endif
//
#if CUDA_VERSION >= 9000
//
dtype::Float16()
//
#endif
})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
...
...
@@ -750,36 +750,36 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
checker
.
set_param
(
cur_param
).
execs
(
{{
n
,
g
,
h
,
h
},
{
g
,
1
,
1
,
fh
,
fh
},
{},
{},
{}});
};
run
(
4
,
8
,
32
,
5
,
5
/
2
,
1
);
run
(
4
,
8
,
32
,
7
,
7
/
2
,
1
);
run
(
4
,
8
,
32
,
9
,
9
/
2
,
1
);
run
(
4
,
8
,
32
,
11
,
11
/
2
,
1
);
run
(
4
,
8
,
32
,
13
,
13
/
2
,
1
);
run
(
4
,
8
,
32
,
15
,
15
/
2
,
1
);
run
(
4
,
8
,
32
,
17
,
17
/
2
,
1
);
run
(
4
,
8
,
32
,
19
,
19
/
2
,
1
);
run
(
4
,
8
,
32
,
21
,
21
/
2
,
1
);
run
(
4
,
8
,
32
,
23
,
23
/
2
,
1
);
run
(
4
,
8
,
32
,
25
,
25
/
2
,
1
);
run
(
4
,
8
,
32
,
27
,
27
/
2
,
1
);
run
(
4
,
8
,
32
,
29
,
29
/
2
,
1
);
run
(
4
,
8
,
32
,
31
,
31
/
2
,
1
);
run
(
4
,
8
,
64
,
5
,
5
/
3
,
2
);
run
(
4
,
8
,
64
,
7
,
7
/
3
,
2
);
run
(
4
,
8
,
64
,
9
,
9
/
3
,
2
);
run
(
4
,
8
,
64
,
11
,
11
/
3
,
2
);
run
(
4
,
8
,
64
,
13
,
13
/
3
,
2
);
run
(
4
,
8
,
64
,
15
,
15
/
3
,
2
);
run
(
4
,
8
,
64
,
17
,
17
/
3
,
2
);
run
(
4
,
8
,
64
,
19
,
19
/
3
,
2
);
run
(
4
,
8
,
64
,
21
,
21
/
3
,
2
);
run
(
4
,
8
,
64
,
23
,
23
/
3
,
2
);
run
(
4
,
8
,
64
,
25
,
25
/
3
,
2
);
run
(
4
,
8
,
64
,
27
,
27
/
3
,
2
);
run
(
4
,
8
,
64
,
29
,
29
/
3
,
2
);
run
(
4
,
8
,
64
,
31
,
31
/
3
,
2
);
run
(
1
,
2
,
128
,
31
,
10
,
2
);
run
(
1
,
2
,
256
,
31
,
10
,
2
);
//
run(4, 8, 32, 5, 5 / 2, 1);
//
run(4, 8, 32, 7, 7 / 2, 1);
//
run(4, 8, 32, 9, 9 / 2, 1);
//
run(4, 8, 32, 11, 11 / 2, 1);
//
run(4, 8, 32, 13, 13 / 2, 1);
//
run(4, 8, 32, 15, 15 / 2, 1);
//
run(4, 8, 32, 17, 17 / 2, 1);
//
run(4, 8, 32, 19, 19 / 2, 1);
//
run(4, 8, 32, 21, 21 / 2, 1);
//
run(4, 8, 32, 23, 23 / 2, 1);
//
run(4, 8, 32, 25, 25 / 2, 1);
//
run(4, 8, 32, 27, 27 / 2, 1);
//
run(4, 8, 32, 29, 29 / 2, 1);
run
(
64
,
384
,
32
,
31
,
31
/
2
,
1
);
//
run(4, 8, 64, 5, 5 / 3, 2);
//
run(4, 8, 64, 7, 7 / 3, 2);
//
run(4, 8, 64, 9, 9 / 3, 2);
//
run(4, 8, 64, 11, 11 / 3, 2);
//
run(4, 8, 64, 13, 13 / 3, 2);
//
run(4, 8, 64, 15, 15 / 3, 2);
//
run(4, 8, 64, 17, 17 / 3, 2);
//
run(4, 8, 64, 19, 19 / 3, 2);
//
run(4, 8, 64, 21, 21 / 3, 2);
//
run(4, 8, 64, 23, 23 / 3, 2);
//
run(4, 8, 64, 25, 25 / 3, 2);
//
run(4, 8, 64, 27, 27 / 3, 2);
//
run(4, 8, 64, 29, 29 / 3, 2);
//
run(4, 8, 64, 31, 31 / 3, 2);
//
run(1, 2, 128, 31, 10, 2);
//
run(1, 2, 256, 31, 10, 2);
}
}
...
...
@@ -1638,10 +1638,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) {
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
using
NonlineMode
=
ConvBias
::
Param
::
NonlineMode
;
param
.
nonlineMode
=
NonlineMode
::
IDENTITY
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
auto
run_bench
=
[
&
](
size_t
batch
,
size_t
g
,
size_t
hi
,
size_t
wi
,
size_t
fh
,
size_t
fw
,
size_t
sh
,
size_t
sw
,
size_t
nr_times
)
{
param
.
pad_h
=
fh
/
2
;
...
...
dnn/test/cuda/region_restricted_convolution.cpp
0 → 100644
浏览文件 @
543c9b77
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/conv_bias.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"
#include <cudnn.h>
#define V1(x) #x
#define V(x) V1(x)
#define CUDNN_VERSION_STRING \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
namespace
megdnn
{
namespace
test
{
TEST_F
(
CUDA
,
REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER
)
{
Checker
<
RegionRestrictedConvolutionForward
>
checker
(
handle_cuda
());
auto
opr
=
handle_cuda
()
->
create_operator
<
ConvolutionForward
>
();
for
(
auto
dt
:
std
::
vector
<
DType
>
{
dtype
::
Int32
(),
dtype
::
Uint8
()})
{
auto
run
=
[
&
checker
,
&
dt
,
&
opr
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
size_t
stride
)
{
RegionRestrictedConvolution
::
Param
cur_param
;
cur_param
.
mode
=
RegionRestrictedConvolution
::
Param
::
Mode
::
CROSS_CORRELATION
;
cur_param
.
sparse
=
RegionRestrictedConvolution
::
Param
::
Sparse
::
GROUP
;
checker
.
set_dtype
(
2
,
dt
).
set_dtype
(
3
,
dt
);
float
scale
=
64.
f
/
sqrt
(
fh
*
fh
);
UniformFloatRNG
rng
(
scale
,
2
*
scale
);
UniformIntRNG
r_rng
{
0
,
2
};
checker
.
set_rng
(
0
,
&
rng
).
set_rng
(
1
,
&
rng
).
set_rng
(
2
,
&
r_rng
).
set_rng
(
3
,
&
r_rng
);
if
(
dt
.
enumv
()
==
DTypeEnum
::
Float16
)
{
checker
.
set_epsilon
(
1e-1
);
}
cur_param
.
pad_h
=
cur_param
.
pad_w
=
padding
;
cur_param
.
stride_h
=
cur_param
.
stride_w
=
stride
;
size_t
ho
=
infer_conv_shape
(
h
,
fh
,
stride
,
padding
);
checker
.
set_param
(
cur_param
).
execs
(
{{
n
,
g
,
h
,
h
},
{
g
,
1
,
1
,
fh
,
fh
},
{
n
,
h
,
h
},
{
n
,
ho
,
ho
},
{}});
};
run
(
4
,
8
,
32
,
3
,
3
/
2
,
1
);
run
(
4
,
8
,
32
,
5
,
5
/
2
,
1
);
run
(
4
,
8
,
32
,
7
,
7
/
2
,
1
);
run
(
1
,
2
,
32
,
9
,
9
/
2
,
1
);
run
(
4
,
8
,
32
,
11
,
11
/
2
,
1
);
run
(
4
,
8
,
32
,
13
,
13
/
2
,
1
);
run
(
4
,
8
,
32
,
15
,
15
/
2
,
1
);
run
(
4
,
8
,
32
,
17
,
17
/
2
,
1
);
run
(
4
,
8
,
32
,
19
,
19
/
2
,
1
);
run
(
4
,
8
,
32
,
21
,
21
/
2
,
1
);
run
(
4
,
8
,
32
,
23
,
23
/
2
,
1
);
run
(
4
,
8
,
32
,
25
,
25
/
2
,
1
);
run
(
4
,
8
,
32
,
27
,
27
/
2
,
1
);
run
(
4
,
8
,
32
,
29
,
29
/
2
,
1
);
run
(
4
,
8
,
32
,
31
,
31
/
2
,
1
);
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
CUDA
,
BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32
)
{
require_compute_capability
(
7
,
5
);
Benchmarker
<
ConvBiasForward
>
bencher
(
handle_cuda
());
bencher
.
set_display
(
false
);
bencher
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
ConvBiasForward
::
algo_name
<
ConvBiasForward
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
Benchmarker
<
RegionRestrictedConvolutionForward
>
rr_bencher
(
handle_cuda
());
rr_bencher
.
set_display
(
false
);
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
using
NonlineMode
=
ConvBias
::
Param
::
NonlineMode
;
param
.
nonlineMode
=
NonlineMode
::
IDENTITY
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
RegionRestrictedConvolutionForward
::
Param
rr_param
;
rr_param
.
format
=
RegionRestrictedConvolutionForward
::
Param
::
Format
::
NCHW
;
rr_param
.
sparse
=
RegionRestrictedConvolutionForward
::
Param
::
Sparse
::
GROUP
;
UniformIntRNG
r_rng
{
0
,
2
};
auto
run_bench
=
[
&
](
size_t
batch
,
size_t
g
,
size_t
hi
,
size_t
wi
,
size_t
fh
,
size_t
fw
,
size_t
sh
,
size_t
sw
,
size_t
nr_times
)
{
param
.
pad_h
=
fh
/
2
;
param
.
pad_w
=
fw
/
2
;
param
.
stride_h
=
sh
;
param
.
stride_w
=
sw
;
rr_param
.
pad_h
=
fh
/
2
;
rr_param
.
pad_w
=
fw
/
2
;
rr_param
.
stride_h
=
sh
;
rr_param
.
stride_w
=
sw
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
());
bencher
.
set_times
(
nr_times
);
rr_bencher
.
set_param
(
rr_param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Int32
())
.
set_dtype
(
3
,
dtype
::
Int32
());
rr_bencher
.
set_rng
(
2
,
&
r_rng
).
set_rng
(
3
,
&
r_rng
).
set_rng
(
0
,
&
r_rng
);
rr_bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
TensorShape
inp
{
batch
,
g
,
hi
,
wi
},
kern
{
g
,
1
,
1
,
fh
,
fw
},
rin
{
batch
,
hi
,
wi
},
rout
{
batch
,
ho
,
wo
},
out
{
batch
,
g
,
ho
,
wo
};
float
bandwith
=
static_cast
<
float
>
(
inp
.
total_nr_elems
()
+
kern
.
total_nr_elems
()
+
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
float
rr_bandwith
=
static_cast
<
float
>
(
inp
.
total_nr_elems
()
+
kern
.
total_nr_elems
()
+
rin
.
total_nr_elems
()
+
rout
.
total_nr_elems
()
+
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
auto
time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
auto
ops
=
2.0
*
batch
*
g
*
ho
*
wo
*
fh
*
fw
/
(
time_in_ms
*
1e-3
)
*
1e-12
;
auto
rr_time_in_ms
=
rr_bencher
.
execs
({
inp
,
kern
,
rin
,
rout
,
out
})
/
nr_times
;
auto
rr_ops
=
2.0
*
batch
*
g
*
ho
*
wo
*
fh
*
fw
/
(
rr_time_in_ms
*
1e-3
)
*
1e-12
;
printf
(
"RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: inp=%s, "
"kern=%s, out=%s
\n
"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops
\n
"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.
\n
"
,
inp
.
to_string
().
c_str
(),
kern
.
to_string
().
c_str
(),
out
.
to_string
().
c_str
(),
time_in_ms
,
rr_time_in_ms
,
ops
,
rr_ops
,
bandwith
*
4
/
time_in_ms
,
rr_bandwith
*
4
/
rr_time_in_ms
,
time_in_ms
/
rr_time_in_ms
);
};
run_bench
(
64
,
384
,
32
,
32
,
3
,
3
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
5
,
5
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
7
,
7
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
9
,
9
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
11
,
11
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
13
,
13
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
15
,
15
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
17
,
17
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
19
,
19
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
21
,
21
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
23
,
23
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
25
,
25
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
27
,
27
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
29
,
29
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
31
,
31
,
1
,
1
,
10
);
}
TEST_F
(
CUDA
,
BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8
)
{
require_compute_capability
(
7
,
5
);
Benchmarker
<
ConvBiasForward
>
bencher
(
handle_cuda
());
bencher
.
set_display
(
false
);
bencher
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
ConvBiasForward
::
algo_name
<
ConvBiasForward
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
Benchmarker
<
RegionRestrictedConvolutionForward
>
rr_bencher
(
handle_cuda
());
rr_bencher
.
set_display
(
false
);
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
using
NonlineMode
=
ConvBias
::
Param
::
NonlineMode
;
param
.
nonlineMode
=
NonlineMode
::
IDENTITY
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
RegionRestrictedConvolutionForward
::
Param
rr_param
;
rr_param
.
format
=
RegionRestrictedConvolutionForward
::
Param
::
Format
::
NCHW
;
rr_param
.
sparse
=
RegionRestrictedConvolutionForward
::
Param
::
Sparse
::
GROUP
;
UniformIntRNG
r_rng
{
0
,
2
};
auto
run_bench
=
[
&
](
size_t
batch
,
size_t
g
,
size_t
hi
,
size_t
wi
,
size_t
fh
,
size_t
fw
,
size_t
sh
,
size_t
sw
,
size_t
nr_times
)
{
param
.
pad_h
=
fh
/
2
;
param
.
pad_w
=
fw
/
2
;
param
.
stride_h
=
sh
;
param
.
stride_w
=
sw
;
rr_param
.
pad_h
=
fh
/
2
;
rr_param
.
pad_w
=
fw
/
2
;
rr_param
.
stride_h
=
sh
;
rr_param
.
stride_w
=
sw
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
());
bencher
.
set_times
(
nr_times
);
rr_bencher
.
set_param
(
rr_param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Uint8
())
.
set_dtype
(
3
,
dtype
::
Uint8
());
rr_bencher
.
set_rng
(
2
,
&
r_rng
).
set_rng
(
3
,
&
r_rng
).
set_rng
(
0
,
&
r_rng
);
rr_bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
TensorShape
inp
{
batch
,
g
,
hi
,
wi
},
kern
{
g
,
1
,
1
,
fh
,
fw
},
rin
{
batch
,
hi
,
wi
},
rout
{
batch
,
ho
,
wo
},
out
{
batch
,
g
,
ho
,
wo
};
float
bandwith
=
static_cast
<
float
>
(
inp
.
total_nr_elems
()
+
kern
.
total_nr_elems
()
+
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
float
rr_bandwith
=
static_cast
<
float
>
(
inp
.
total_nr_elems
()
+
kern
.
total_nr_elems
()
+
rin
.
total_nr_elems
()
+
rout
.
total_nr_elems
()
+
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
auto
time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
auto
ops
=
2.0
*
batch
*
g
*
ho
*
wo
*
fh
*
fw
/
(
time_in_ms
*
1e-3
)
*
1e-12
;
auto
rr_time_in_ms
=
rr_bencher
.
execs
({
inp
,
kern
,
rin
,
rout
,
out
})
/
nr_times
;
auto
rr_ops
=
2.0
*
batch
*
g
*
ho
*
wo
*
fh
*
fw
/
(
rr_time_in_ms
*
1e-3
)
*
1e-12
;
printf
(
"RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: inp=%s, "
"kern=%s, out=%s
\n
"
"time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops
\n
"
"bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.
\n
"
,
inp
.
to_string
().
c_str
(),
kern
.
to_string
().
c_str
(),
out
.
to_string
().
c_str
(),
time_in_ms
,
rr_time_in_ms
,
ops
,
rr_ops
,
bandwith
*
4
/
time_in_ms
,
rr_bandwith
*
4
/
rr_time_in_ms
,
time_in_ms
/
rr_time_in_ms
);
};
run_bench
(
64
,
384
,
32
,
32
,
3
,
3
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
5
,
5
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
7
,
7
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
9
,
9
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
11
,
11
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
13
,
13
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
15
,
15
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
17
,
17
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
19
,
19
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
21
,
21
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
23
,
23
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
25
,
25
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
27
,
27
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
29
,
29
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
31
,
31
,
1
,
1
,
10
);
}
#endif
}
// namespace test
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/test/naive/region_restricted_convolution.cpp
浏览文件 @
543c9b77
...
...
@@ -11,8 +11,8 @@ using namespace megdnn;
using
namespace
test
;
namespace
{
void
mask_tensor
(
template
<
typename
rtype
>
void
mask_tensor
_kernel
(
const
TensorND
&
in
,
TensorND
&
out
,
const
TensorND
&
mask
,
const
int32_t
mask_val
)
{
megdnn_assert
(
...
...
@@ -23,7 +23,7 @@ void mask_tensor(
mask
.
layout
[
0
]
==
in
.
layout
[
0
]
&&
mask
.
layout
[
1
]
==
in
.
layout
[
2
]
&&
mask
.
layout
[
2
]
==
in
.
layout
[
3
]);
int32_t
*
mask_ptr
=
mask
.
ptr
<
int32_t
>
();
rtype
*
mask_ptr
=
mask
.
compatible_ptr
<
rtype
>
();
float
*
src_ptr
=
in
.
compatible_ptr
<
float
>
();
float
*
dst_ptr
=
out
.
compatible_ptr
<
float
>
();
...
...
@@ -47,6 +47,16 @@ void mask_tensor(
}
}
}
void
mask_tensor
(
const
TensorND
&
in
,
TensorND
&
out
,
const
TensorND
&
mask
,
const
int32_t
mask_val
)
{
if
(
mask
.
layout
.
dtype
==
dtype
::
Int32
())
{
mask_tensor_kernel
<
dt_int32
>
(
in
,
out
,
mask
,
mask_val
);
}
else
if
(
mask
.
layout
.
dtype
==
dtype
::
Uint8
())
{
mask_tensor_kernel
<
dt_uint8
>
(
in
,
out
,
mask
,
mask_val
);
}
}
}
// namespace
TEST_F
(
NAIVE
,
REGIONRESTRICTEDCONVOLUTION_FORWARD
)
{
...
...
@@ -54,7 +64,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
RegionRestrictedConvolution
::
Param
param
;
constexpr
int
N
=
3
;
UniformIntRNG
rng
{
0
,
N
-
1
};
UniformIntRNG
rng
{
0
,
N
-
1
};
auto
extra_impl
=
[
&
,
this
](
const
TensorNDArray
&
tensors
)
{
auto
conv
=
handle
()
->
create_operator
<
Convolution
>
();
...
...
@@ -64,16 +74,17 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
dt_byte
*
workspace_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
workspace_size
));
Workspace
workspace
{
workspace_ptr
,
workspace_size
};
TensorND
masked_src
(
malloc
(
tensors
[
0
].
layout
.
span
().
dist_byte
()),
tensors
[
0
].
layout
);
TensorND
masked_src
(
malloc
(
tensors
[
0
].
layout
.
span
().
dist_byte
()),
tensors
[
0
].
layout
);
TensorNDArray
dst_tensors
;
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst_tensors
.
emplace_back
(
malloc
(
tensors
[
4
].
layout
.
span
().
dist_byte
()),
tensors
[
4
].
layout
);
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
dst_tensors
.
emplace_back
(
malloc
(
tensors
[
4
].
layout
.
span
().
dist_byte
()),
tensors
[
4
].
layout
);
}
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
mask_tensor
(
tensors
[
0
],
masked_src
,
tensors
[
2
],
i
);
conv
->
exec
(
masked_src
,
tensors
[
1
],
dst_tensors
[
i
],
nullptr
,
workspace
);
mask_tensor
(
dst_tensors
[
i
],
dst_tensors
[
i
],
tensors
[
3
],
i
);
}
free
(
workspace_ptr
);
...
...
@@ -81,7 +92,7 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
auto
add
=
handle
()
->
create_operator
<
ElemwiseForward
>
();
add
->
param
().
mode
=
Mode
::
ADD
;
add
->
exec
({
dst_tensors
[
0
],
dst_tensors
[
1
]},
tensors
[
4
]);
for
(
int
i
=
2
;
i
<
N
;
++
i
)
{
for
(
int
i
=
2
;
i
<
N
;
++
i
)
{
add
->
exec
({
dst_tensors
[
i
],
tensors
[
4
]},
tensors
[
4
]);
}
};
...
...
@@ -92,6 +103,12 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
.
set_dtype
(
2
,
dtype
::
Int32
())
.
set_dtype
(
3
,
dtype
::
Int32
());
checker
.
execs
({{
1
,
8
,
2
,
2
},
{
4
,
8
,
1
,
1
},
{
1
,
2
,
2
},
{
1
,
2
,
2
},
{}})
.
execs
({{
20
,
12
,
30
,
30
},
{
4
,
12
,
1
,
1
},
{
20
,
30
,
30
},
{
20
,
30
,
30
},
{}})
.
execs
({{
20
,
8
,
30
,
30
},
{
4
,
8
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}});
checker
.
set_dtype
(
2
,
dtype
::
Uint8
()).
set_dtype
(
3
,
dtype
::
Uint8
());
checker
.
execs
({{
1
,
8
,
2
,
2
},
{
4
,
8
,
1
,
1
},
{
1
,
2
,
2
},
{
1
,
2
,
2
},
{}})
.
execs
({{
20
,
12
,
30
,
30
},
{
4
,
12
,
1
,
1
},
{
20
,
30
,
30
},
{
20
,
30
,
30
},
{}})
.
execs
({{
20
,
8
,
30
,
30
},
{
4
,
8
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}});
...
...
@@ -99,100 +116,19 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) {
param
.
sparse
=
Convolution
::
Param
::
Sparse
::
GROUP
;
checker
.
set_param
(
param
)
.
execs
({{
20
,
15
,
30
,
30
},
{
5
,
4
,
3
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}})
.
execs
({{
20
,
25
,
30
,
30
},
{
25
,
1
,
1
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}});
}
#if 0
TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BACKWARD_DATA) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
using Param = RegionRestrictedConvolutionBackwardData::Param;
Param param;
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
size_t fw, size_t stride, size_t padding, size_t dilate = 1,
size_t group = 1) {
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
TensorLayout grad;
TensorLayout filter;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter = {{oc, ic, fh, fw}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
}
// TensorLayout grad;
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
}
checker.set_param(param);
checker.exec(TensorLayoutArray{filter, diff, grad});
};
for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
param.mode = mode;
run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2);
run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3);
run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2);
run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2);
run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3);
run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2);
}
}
TEST_F(NAIVE, CONVOLUTION_BACKWARD_DATA) {
Checker<RegionRestrictedConvolutionBackwardData> checker(handle());
using Param = RegionRestrictedConvolutionBackwardData::Param;
Param param;
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
size_t fw, size_t stride, size_t padding, size_t dilate = 1,
size_t group = 1) {
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
TensorLayout grad;
TensorLayout filter;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter = {{oc, ic, fh, fw}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
}
// TensorLayout grad;
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
}
checker.set_param(param);
checker.exec(TensorLayoutArray{filter, diff, grad});
};
for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
param.mode = mode;
run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2);
run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3);
run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2);
run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3);
run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2);
run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3);
run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2);
}
.
execs
({{
20
,
25
,
30
,
30
},
{
25
,
1
,
1
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}});
checker
.
set_dtype
(
2
,
dtype
::
Int32
()).
set_dtype
(
3
,
dtype
::
Int32
());
checker
.
execs
({{
20
,
15
,
30
,
30
},
{
5
,
4
,
3
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}})
.
execs
({{
20
,
25
,
30
,
30
},
{
25
,
1
,
1
,
3
,
3
},
{
20
,
30
,
30
},
{
20
,
28
,
28
},
{}});
}
#endif
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录