Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
73ca2e00
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73ca2e00
编写于
3月 24, 2020
作者:
X
xiebaiyuan
提交者:
GitHub
3月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[LITE][OPENCL][Image]support multi batch conv2d 3x3 5x5 7x7 ,open lws,test=develop (#3258)
上级
745e99f3
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
876 addition
and
76 deletion
+876
-76
.gitignore
.gitignore
+2
-0
lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl
.../backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl
+347
-59
lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl
.../backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl
+252
-0
lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl
.../backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl
+252
-0
lite/kernels/opencl/conv_image_compute.cc
lite/kernels/opencl/conv_image_compute.cc
+8
-4
lite/kernels/opencl/conv_image_compute.h
lite/kernels/opencl/conv_image_compute.h
+1
-1
lite/kernels/opencl/conv_image_compute_test.cc
lite/kernels/opencl/conv_image_compute_test.cc
+14
-12
未找到文件。
.gitignore
浏览文件 @
73ca2e00
...
...
@@ -105,3 +105,5 @@ metal/paddle-mobile-demo/paddle-mobile-demo/Resources
metal/paddle-mobile-demo/paddle-mobile-demo/Resources/images
metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models
metal/MobileNetDemo/MobileNetDemo/Resources
build*
lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl
浏览文件 @
73ca2e00
...
...
@@ -61,7 +61,8 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
...
...
@@ -70,22 +71,32 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
#elif defined(BIASE_ELE)
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_w_base_id + out_w_id0, item_h_id));
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
#else
...
...
@@ -109,54 +120,76 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
int filter_w_val = ch * 3;
for (int h = 0; h < 3; h++) {
int in_h_val = select(out_batch_id * in_h + in_h_id + h, -1,
int in_h_val = select(out_batch_id * in_h + in_h_id + h,
-1,
(out_batch_id * in_h + in_h_id + h < 0 ||
out_batch_id * in_h + in_h_id + h >= in_h));
for (int w = 0; w < 3; w++) {
int in_w_val0 = select(in_w_base_id + in_w_id0 + w, -1,
int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
-1,
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
int in_w_val1 = select(in_w_base_id + in_w_id1 + w, -1,
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
-1,
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w, -1,
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
-1,
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w, -1,
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
-1,
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w, -1,
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
-1,
(in_w_id4 + w < 0 || in_w_id4 + w >= in_w));
filter[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
filter[0] = READ_IMG_TYPE(
CL_DTYPE_CHAR,
filter_image,
sampler,
(int2)(filter_w_val + w, filter_h_val0 + h)); // in_ch:0-3,out_ch:0
filter[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
filter[1] = READ_IMG_TYPE(
CL_DTYPE_CHAR,
filter_image,
sampler,
(int2)(filter_w_val + w, filter_h_val1 + h)); // in_ch:0-3,out_ch:1
filter[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
filter[2] = READ_IMG_TYPE(
CL_DTYPE_CHAR,
filter_image,
sampler,
(int2)(filter_w_val + w, filter_h_val2 + h)); // in_ch:0-3,out_ch:2
filter[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
filter[3] = READ_IMG_TYPE(
CL_DTYPE_CHAR,
filter_image,
sampler,
(int2)(filter_w_val + w, filter_h_val3 + h)); // in_ch:0-3,out_ch:3
filter_trans[0]
=
(
CL_DTYPE4
)(
filter[0].x,
filter[1].x,
filter[2].x,
filter_trans[0] = (CL_DTYPE4)(filter[0].x,
filter[1].x,
filter[2].x,
filter[3].x); // in_ch:0,out_ch:0-3
filter_trans[1]
=
(
CL_DTYPE4
)(
filter[0].y,
filter[1].y,
filter[2].y,
filter_trans[1] = (CL_DTYPE4)(filter[0].y,
filter[1].y,
filter[2].y,
filter[3].y); // in_ch:1,out_ch:0-3
filter_trans[2]
=
(
CL_DTYPE4
)(
filter[0].z,
filter[1].z,
filter[2].z,
filter_trans[2] = (CL_DTYPE4)(filter[0].z,
filter[1].z,
filter[2].z,
filter[3].z); // in_ch:2,out_ch:0-3
filter_trans[3]
=
(
CL_DTYPE4
)(
filter[0].w,
filter[1].w,
filter[2].w,
filter_trans[3] = (CL_DTYPE4)(filter[0].w,
filter[1].w,
filter[2].w,
filter[3].w); // in_ch:3,out_ch:0-3
input[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val0,
in_h_val
))
;
input[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val1,
in_h_val
))
;
input[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val2,
in_h_val
))
;
input[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val3,
in_h_val
))
;
input[4]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val4,
in_h_val
))
;
input[0] =
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val));
input[1] =
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val));
input[2] =
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val));
input[3] =
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val));
input[4] =
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val));
output[0] = mad(input[0].x, filter_trans[0], output[0]);
output[1] = mad(input[1].x, filter_trans[0], output[1]);
...
...
@@ -195,23 +228,278 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
output[3] = activation_type4(output[3]);
output[4] = activation_type4(output[4]);
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id0,
item_h_id
)
,
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id0, item_h_id),
output[0]);
if (out_w_id1 < out_w) {
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id1,
item_h_id
)
,
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id1, item_h_id),
output[1]);
}
if (out_w_id2 < out_w) {
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id2,
item_h_id
)
,
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id2, item_h_id),
output[2]);
}
if (out_w_id3 < out_w) {
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id3,
item_h_id
)
,
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id3, item_h_id),
output[3]);
}
if (out_w_id4 < out_w) {
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id4,
item_h_id
)
,
WRITE_IMG_TYPE(CL_DTYPE_CHAR,
output_image,
(int2)(out_w_base_id + out_w_id4, item_h_id),
output[4]);
}
}
// support batch > 1
__kernel void conv2d_3x3_multi_batch(__private const int item_ch,
__private const int item_w,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
__private const int dilation,
__private const int batch,
__private const int in_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP
| CLK_FILTER_NEAREST;
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
const int item_h_id = get_global_id(2);
// out_width_id_per_blk and out_batch_id
int out_batch_id = item_h_id / in_h;
int out_w_base_id = item_ch_id * out_w;
int out_w_id0 = item_w_id;
int out_w_id1 = out_w_id0 + item_w;
int out_w_id2 = out_w_id1 + item_w;
int out_w_id3 = out_w_id2 + item_w;
int out_w_id4 = out_w_id3 + item_w;
// in_width_id_per_blk and in_height_id_per_batch
int in_h_id = (item_h_id % out_h) * stride - pad;
int in_w_id0 = item_w_id * stride - pad;
int in_w_id1 = in_w_id0 + item_w * stride;
int in_w_id2 = in_w_id1 + item_w * stride;
int in_w_id3 = in_w_id2 + item_w * stride;
int in_w_id4 = in_w_id3 + item_w * stride;
#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
output[4] = output[0];
#elif defined(BIASE_ELE)
CL_DTYPE4 output[5];
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
#else
CL_DTYPE4 output[5] = {0.0f};
#endif
CL_DTYPE4 filter[4] = {0.0f};
CL_DTYPE4 filter_trans[4] = {0.0f};
CL_DTYPE4 input[5] = {0.0f};
int filter_h_val0 = item_ch_id * 4 * 3;
int filter_h_val1 = filter_h_val0 + 3;
int filter_h_val2 = filter_h_val1 + 3;
int filter_h_val3 = filter_h_val2 + 3;
for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;
const int in_w_base_id = mul24(ch, in_w);
int filter_w_val = ch * 3;
for (int h = 0; h < 3; h++) {
int in_h_val = select(
out_batch_id * in_h + in_h_id + h,
-1,
(out_batch_id * in_h + in_h_id + h < out_batch_id * in_h ||
out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h));
for (int w = 0; w < 3; w++) {
int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
-1,
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
-1,
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
-1,
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
-1,
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
-1,
(in_w_id4 + w < 0 |
|
in_w_id4
+
w
>=
in_w
)
)
;
filter[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val0
+
h
))
; // in_ch:0-3,out_ch:0
filter[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val1
+
h
))
; // in_ch:0-3,out_ch:1
filter[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val2
+
h
))
; // in_ch:0-3,out_ch:2
filter[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val3
+
h
))
; // in_ch:0-3,out_ch:3
filter_trans[0]
=
(
CL_DTYPE4
)(
filter[0].x,
filter[1].x,
filter[2].x,
filter[3].x
)
; // in_ch:0,out_ch:0-3
filter_trans[1]
=
(
CL_DTYPE4
)(
filter[0].y,
filter[1].y,
filter[2].y,
filter[3].y
)
; // in_ch:1,out_ch:0-3
filter_trans[2]
=
(
CL_DTYPE4
)(
filter[0].z,
filter[1].z,
filter[2].z,
filter[3].z
)
; // in_ch:2,out_ch:0-3
filter_trans[3]
=
(
CL_DTYPE4
)(
filter[0].w,
filter[1].w,
filter[2].w,
filter[3].w
)
; // in_ch:3,out_ch:0-3
input[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val0,
in_h_val
))
;
input[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val1,
in_h_val
))
;
input[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val2,
in_h_val
))
;
input[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val3,
in_h_val
))
;
input[4]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val4,
in_h_val
))
;
output[0]
=
mad
(
input[0].x,
filter_trans[0],
output[0]
)
;
output[1]
=
mad
(
input[1].x,
filter_trans[0],
output[1]
)
;
output[2]
=
mad
(
input[2].x,
filter_trans[0],
output[2]
)
;
output[3]
=
mad
(
input[3].x,
filter_trans[0],
output[3]
)
;
output[4]
=
mad
(
input[4].x,
filter_trans[0],
output[4]
)
;
if
(
ch_surplus
<
3
)
{
output[0]
=
mad
(
input[0].y,
filter_trans[1],
output[0]
)
;
output[1]
=
mad
(
input[1].y,
filter_trans[1],
output[1]
)
;
output[2]
=
mad
(
input[2].y,
filter_trans[1],
output[2]
)
;
output[3]
=
mad
(
input[3].y,
filter_trans[1],
output[3]
)
;
output[4]
=
mad
(
input[4].y,
filter_trans[1],
output[4]
)
;
}
if
(
ch_surplus
<
2
)
{
output[0]
=
mad
(
input[0].z,
filter_trans[2],
output[0]
)
;
output[1]
=
mad
(
input[1].z,
filter_trans[2],
output[1]
)
;
output[2]
=
mad
(
input[2].z,
filter_trans[2],
output[2]
)
;
output[3]
=
mad
(
input[3].z,
filter_trans[2],
output[3]
)
;
output[4]
=
mad
(
input[4].z,
filter_trans[2],
output[4]
)
;
}
if
(
ch_surplus
<
1
)
{
output[0]
=
mad
(
input[0].w,
filter_trans[3],
output[0]
)
;
output[1]
=
mad
(
input[1].w,
filter_trans[3],
output[1]
)
;
output[2]
=
mad
(
input[2].w,
filter_trans[3],
output[2]
)
;
output[3]
=
mad
(
input[3].w,
filter_trans[3],
output[3]
)
;
output[4]
=
mad
(
input[4].w,
filter_trans[3],
output[4]
)
;
}
}
}
}
output[0]
=
activation_type4
(
output[0]
)
;
output[1]
=
activation_type4
(
output[1]
)
;
output[2]
=
activation_type4
(
output[2]
)
;
output[3]
=
activation_type4
(
output[3]
)
;
output[4]
=
activation_type4
(
output[4]
)
;
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id0,
item_h_id
)
,
output[0]
)
;
if
(
out_w_id1
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id1,
item_h_id
)
,
output[1]
)
;
}
if
(
out_w_id2
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id2,
item_h_id
)
,
output[2]
)
;
}
if
(
out_w_id3
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id3,
item_h_id
)
,
output[3]
)
;
}
if
(
out_w_id4
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id4,
item_h_id
)
,
output[4]
)
;
}
}
lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl
浏览文件 @
73ca2e00
...
...
@@ -262,3 +262,255 @@ __kernel void conv2d_5x5_opt(__private const int item_ch,
output[4]);
}
}
// support batch > 1
__kernel void conv2d_5x5_multi_batch(__private const int item_ch,
__private const int item_w,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
__private const int dilation,
__private const int batch,
__private const int in_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP
| CLK_FILTER_NEAREST;
// filter
const int filter_w = 5;
const int filter_h = 5;
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
const int item_h_id = get_global_id(2);
// out_width_id_per_blk and out_batch_id
int out_batch_id = item_h_id / in_h;
int out_w_base_id = item_ch_id * out_w;
int out_w_id0 = item_w_id;
int out_w_id1 = out_w_id0 + item_w;
int out_w_id2 = out_w_id1 + item_w;
int out_w_id3 = out_w_id2 + item_w;
int out_w_id4 = out_w_id3 + item_w;
// in_width_id_per_blk and in_height_id_per_batch
int in_h_id = (item_h_id % out_h) * stride - pad;
int in_w_id0 = item_w_id * stride - pad;
int in_w_id1 = in_w_id0 + item_w * stride;
int in_w_id2 = in_w_id1 + item_w * stride;
int in_w_id3 = in_w_id2 + item_w * stride;
int in_w_id4 = in_w_id3 + item_w * stride;
#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
output[4] = output[0];
#elif defined(BIASE_ELE)
CL_DTYPE4 output[5];
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
#else
CL_DTYPE4 output[5] = {0.0f};
#endif
CL_DTYPE4 filter[4] = {0.0f};
CL_DTYPE4 filter_trans[4] = {0.0f};
CL_DTYPE4 input[5] = {0.0f};
int filter_h_val0 = item_ch_id * 4 * filter_h;
int filter_h_val1 = filter_h_val0 + filter_h;
int filter_h_val2 = filter_h_val1 + filter_h;
int filter_h_val3 = filter_h_val2 + filter_h;
for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;
const int in_w_base_id = mul24(ch, in_w);
int filter_w_val = ch * filter_w;
for (int h = 0; h < filter_h; h++) {
int in_h_val = select(
out_batch_id * in_h + in_h_id + h,
-1,
(out_batch_id * in_h + in_h_id + h < out_batch_id * in_h ||
out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h));
for (int w = 0; w < filter_w; w++) {
int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
-1,
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
-1,
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
-1,
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
-1,
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
-1,
(in_w_id4 + w < 0 |
|
in_w_id4
+
w
>=
in_w
)
)
;
filter[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val0
+
h
))
; // in_ch:0-3,out_ch:0
filter[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val1
+
h
))
; // in_ch:0-3,out_ch:1
filter[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val2
+
h
))
; // in_ch:0-3,out_ch:2
filter[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val3
+
h
))
; // in_ch:0-3,out_ch:3
filter_trans[0]
=
(
CL_DTYPE4
)(
filter[0].x,
filter[1].x,
filter[2].x,
filter[3].x
)
; // in_ch:0,out_ch:0-3
filter_trans[1]
=
(
CL_DTYPE4
)(
filter[0].y,
filter[1].y,
filter[2].y,
filter[3].y
)
; // in_ch:1,out_ch:0-3
filter_trans[2]
=
(
CL_DTYPE4
)(
filter[0].z,
filter[1].z,
filter[2].z,
filter[3].z
)
; // in_ch:2,out_ch:0-3
filter_trans[3]
=
(
CL_DTYPE4
)(
filter[0].w,
filter[1].w,
filter[2].w,
filter[3].w
)
; // in_ch:3,out_ch:0-3
input[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val0,
in_h_val
))
;
input[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val1,
in_h_val
))
;
input[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val2,
in_h_val
))
;
input[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val3,
in_h_val
))
;
input[4]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val4,
in_h_val
))
;
output[0]
=
mad
(
input[0].x,
filter_trans[0],
output[0]
)
;
output[1]
=
mad
(
input[1].x,
filter_trans[0],
output[1]
)
;
output[2]
=
mad
(
input[2].x,
filter_trans[0],
output[2]
)
;
output[3]
=
mad
(
input[3].x,
filter_trans[0],
output[3]
)
;
output[4]
=
mad
(
input[4].x,
filter_trans[0],
output[4]
)
;
if
(
ch_surplus
<
3
)
{
output[0]
=
mad
(
input[0].y,
filter_trans[1],
output[0]
)
;
output[1]
=
mad
(
input[1].y,
filter_trans[1],
output[1]
)
;
output[2]
=
mad
(
input[2].y,
filter_trans[1],
output[2]
)
;
output[3]
=
mad
(
input[3].y,
filter_trans[1],
output[3]
)
;
output[4]
=
mad
(
input[4].y,
filter_trans[1],
output[4]
)
;
}
if
(
ch_surplus
<
2
)
{
output[0]
=
mad
(
input[0].z,
filter_trans[2],
output[0]
)
;
output[1]
=
mad
(
input[1].z,
filter_trans[2],
output[1]
)
;
output[2]
=
mad
(
input[2].z,
filter_trans[2],
output[2]
)
;
output[3]
=
mad
(
input[3].z,
filter_trans[2],
output[3]
)
;
output[4]
=
mad
(
input[4].z,
filter_trans[2],
output[4]
)
;
}
if
(
ch_surplus
<
1
)
{
output[0]
=
mad
(
input[0].w,
filter_trans[3],
output[0]
)
;
output[1]
=
mad
(
input[1].w,
filter_trans[3],
output[1]
)
;
output[2]
=
mad
(
input[2].w,
filter_trans[3],
output[2]
)
;
output[3]
=
mad
(
input[3].w,
filter_trans[3],
output[3]
)
;
output[4]
=
mad
(
input[4].w,
filter_trans[3],
output[4]
)
;
}
}
}
}
output[0]
=
activation_type4
(
output[0]
)
;
output[1]
=
activation_type4
(
output[1]
)
;
output[2]
=
activation_type4
(
output[2]
)
;
output[3]
=
activation_type4
(
output[3]
)
;
output[4]
=
activation_type4
(
output[4]
)
;
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id0,
item_h_id
)
,
output[0]
)
;
if
(
out_w_id1
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id1,
item_h_id
)
,
output[1]
)
;
}
if
(
out_w_id2
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id2,
item_h_id
)
,
output[2]
)
;
}
if
(
out_w_id3
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id3,
item_h_id
)
,
output[3]
)
;
}
if
(
out_w_id4
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id4,
item_h_id
)
,
output[4]
)
;
}
}
\ No newline at end of file
lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl
浏览文件 @
73ca2e00
...
...
@@ -262,3 +262,255 @@ __kernel void conv2d_7x7_opt(__private const int item_ch,
output[4]);
}
}
// support batch > 1
__kernel void conv2d_7x7_multi_batch(__private const int item_ch,
__private const int item_w,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
__private const int dilation,
__private const int batch,
__private const int in_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP
| CLK_FILTER_NEAREST;
// filter
const int filter_w = 7;
const int filter_h = 7;
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
const int item_h_id = get_global_id(2);
// out_width_id_per_blk and out_batch_id
int out_batch_id = item_h_id / in_h;
int out_w_base_id = item_ch_id * out_w;
int out_w_id0 = item_w_id;
int out_w_id1 = out_w_id0 + item_w;
int out_w_id2 = out_w_id1 + item_w;
int out_w_id3 = out_w_id2 + item_w;
int out_w_id4 = out_w_id3 + item_w;
// in_width_id_per_blk and in_height_id_per_batch
int in_h_id = (item_h_id % out_h) * stride - pad;
int in_w_id0 = item_w_id * stride - pad;
int in_w_id1 = in_w_id0 + item_w * stride;
int in_w_id2 = in_w_id1 + item_w * stride;
int in_w_id3 = in_w_id2 + item_w * stride;
int in_w_id4 = in_w_id3 + item_w * stride;
#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
output[4] = output[0];
#elif defined(BIASE_ELE)
CL_DTYPE4 output[5];
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR,
bias,
sampler,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
#else
CL_DTYPE4 output[5] = {0.0f};
#endif
CL_DTYPE4 filter[4] = {0.0f};
CL_DTYPE4 filter_trans[4] = {0.0f};
CL_DTYPE4 input[5] = {0.0f};
int filter_h_val0 = item_ch_id * 4 * filter_h;
int filter_h_val1 = filter_h_val0 + filter_h;
int filter_h_val2 = filter_h_val1 + filter_h;
int filter_h_val3 = filter_h_val2 + filter_h;
for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;
const int in_w_base_id = mul24(ch, in_w);
int filter_w_val = ch * filter_w;
for (int h = 0; h < filter_h; h++) {
int in_h_val = select(
out_batch_id * in_h + in_h_id + h,
-1,
(out_batch_id * in_h + in_h_id + h < out_batch_id * in_h ||
out_batch_id * in_h + in_h_id + h >= (out_batch_id + 1) * in_h));
for (int w = 0; w < filter_w; w++) {
int in_w_val0 = select(in_w_base_id + in_w_id0 + w,
-1,
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
int in_w_val1 = select(in_w_base_id + in_w_id1 + w,
-1,
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w,
-1,
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w,
-1,
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w,
-1,
(in_w_id4 + w < 0 |
|
in_w_id4
+
w
>=
in_w
)
)
;
filter[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val0
+
h
))
; // in_ch:0-3,out_ch:0
filter[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val1
+
h
))
; // in_ch:0-3,out_ch:1
filter[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val2
+
h
))
; // in_ch:0-3,out_ch:2
filter[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
filter_image,
sampler,
(
int2
)(
filter_w_val
+
w,
filter_h_val3
+
h
))
; // in_ch:0-3,out_ch:3
filter_trans[0]
=
(
CL_DTYPE4
)(
filter[0].x,
filter[1].x,
filter[2].x,
filter[3].x
)
; // in_ch:0,out_ch:0-3
filter_trans[1]
=
(
CL_DTYPE4
)(
filter[0].y,
filter[1].y,
filter[2].y,
filter[3].y
)
; // in_ch:1,out_ch:0-3
filter_trans[2]
=
(
CL_DTYPE4
)(
filter[0].z,
filter[1].z,
filter[2].z,
filter[3].z
)
; // in_ch:2,out_ch:0-3
filter_trans[3]
=
(
CL_DTYPE4
)(
filter[0].w,
filter[1].w,
filter[2].w,
filter[3].w
)
; // in_ch:3,out_ch:0-3
input[0]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val0,
in_h_val
))
;
input[1]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val1,
in_h_val
))
;
input[2]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val2,
in_h_val
))
;
input[3]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val3,
in_h_val
))
;
input[4]
=
READ_IMG_TYPE
(
CL_DTYPE_CHAR,
input_image,
sampler,
(
int2
)(
in_w_val4,
in_h_val
))
;
output[0]
=
mad
(
input[0].x,
filter_trans[0],
output[0]
)
;
output[1]
=
mad
(
input[1].x,
filter_trans[0],
output[1]
)
;
output[2]
=
mad
(
input[2].x,
filter_trans[0],
output[2]
)
;
output[3]
=
mad
(
input[3].x,
filter_trans[0],
output[3]
)
;
output[4]
=
mad
(
input[4].x,
filter_trans[0],
output[4]
)
;
if
(
ch_surplus
<
3
)
{
output[0]
=
mad
(
input[0].y,
filter_trans[1],
output[0]
)
;
output[1]
=
mad
(
input[1].y,
filter_trans[1],
output[1]
)
;
output[2]
=
mad
(
input[2].y,
filter_trans[1],
output[2]
)
;
output[3]
=
mad
(
input[3].y,
filter_trans[1],
output[3]
)
;
output[4]
=
mad
(
input[4].y,
filter_trans[1],
output[4]
)
;
}
if
(
ch_surplus
<
2
)
{
output[0]
=
mad
(
input[0].z,
filter_trans[2],
output[0]
)
;
output[1]
=
mad
(
input[1].z,
filter_trans[2],
output[1]
)
;
output[2]
=
mad
(
input[2].z,
filter_trans[2],
output[2]
)
;
output[3]
=
mad
(
input[3].z,
filter_trans[2],
output[3]
)
;
output[4]
=
mad
(
input[4].z,
filter_trans[2],
output[4]
)
;
}
if
(
ch_surplus
<
1
)
{
output[0]
=
mad
(
input[0].w,
filter_trans[3],
output[0]
)
;
output[1]
=
mad
(
input[1].w,
filter_trans[3],
output[1]
)
;
output[2]
=
mad
(
input[2].w,
filter_trans[3],
output[2]
)
;
output[3]
=
mad
(
input[3].w,
filter_trans[3],
output[3]
)
;
output[4]
=
mad
(
input[4].w,
filter_trans[3],
output[4]
)
;
}
}
}
}
output[0]
=
activation_type4
(
output[0]
)
;
output[1]
=
activation_type4
(
output[1]
)
;
output[2]
=
activation_type4
(
output[2]
)
;
output[3]
=
activation_type4
(
output[3]
)
;
output[4]
=
activation_type4
(
output[4]
)
;
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id0,
item_h_id
)
,
output[0]
)
;
if
(
out_w_id1
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id1,
item_h_id
)
,
output[1]
)
;
}
if
(
out_w_id2
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id2,
item_h_id
)
,
output[2]
)
;
}
if
(
out_w_id3
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id3,
item_h_id
)
,
output[3]
)
;
}
if
(
out_w_id4
<
out_w
)
{
WRITE_IMG_TYPE
(
CL_DTYPE_CHAR,
output_image,
(
int2
)(
out_w_base_id
+
out_w_id4,
item_h_id
)
,
output[4]
)
;
}
}
\ No newline at end of file
lite/kernels/opencl/conv_image_compute.cc
浏览文件 @
73ca2e00
...
...
@@ -142,9 +142,10 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims
[
0
],
filter_image_dims
[
1
],
filter_image_v
.
data
());
impl_
=
&
ConvImageCompute
::
DepthwiseConv2d
;
}
else
if
(
kernel_
h
==
3
&&
kernel_h
==
3
)
{
}
else
if
(
kernel_
w
==
3
&&
kernel_h
==
3
)
{
// conv2d_3x3
kernel_func_names_
.
push_back
(
"conv2d_3x3_opt"
);
kernel_func_names_
.
push_back
(
bs
>
1
?
"conv2d_3x3_multi_batch"
:
"conv2d_3x3_opt"
);
kernel_func_paths_
.
push_back
(
"image/conv2d_3x3_opt_kernel.cl"
);
CLImageConverterFolder
converter
;
...
...
@@ -174,7 +175,9 @@ void ConvImageCompute::PrepareForRun() {
impl_
=
&
ConvImageCompute
::
Conv2d5x5
;
#else
// conv2d_5x5_opt
kernel_func_names_
.
push_back
(
"conv2d_5x5_opt"
);
kernel_func_names_
.
push_back
(
bs
>
1
?
"conv2d_5x5_multi_batch"
:
"conv2d_5x5_opt"
);
kernel_func_paths_
.
push_back
(
"image/conv2d_5x5_opt_kernel.cl"
);
CLImageConverterFolder
converter
;
...
...
@@ -207,7 +210,8 @@ void ConvImageCompute::PrepareForRun() {
#else
// conv2d_7x7
kernel_func_names_
.
push_back
(
"conv2d_7x7_opt"
);
kernel_func_names_
.
push_back
(
bs
>
1
?
"conv2d_7x7_multi_batch"
:
"conv2d_7x7_opt"
);
kernel_func_paths_
.
push_back
(
"image/conv2d_7x7_opt_kernel.cl"
);
CLImageConverterFolder
converter
;
...
...
lite/kernels/opencl/conv_image_compute.h
浏览文件 @
73ca2e00
...
...
@@ -59,7 +59,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
std
::
shared_ptr
<
cl
::
Event
>
event_
{
new
cl
::
Event
};
Tensor
filter_gpu_image_
;
Tensor
bias_gpu_image_
;
bool
use_lws
{
fals
e
};
bool
use_lws
{
tru
e
};
};
}
// namespace opencl
...
...
lite/kernels/opencl/conv_image_compute_test.cc
浏览文件 @
73ca2e00
...
...
@@ -510,7 +510,7 @@ TEST(conv2d, compute_image2d_3x3) {
const
int
dilation
=
1
;
const
int
stride
=
2
;
const
int
group
=
1
;
for
(
int
batch_size
=
1
;
batch_size
<
2
;
++
batch_size
)
{
for
(
int
batch_size
=
1
;
batch_size
<
4
;
++
batch_size
)
{
for
(
int
oc
=
1
;
oc
<
10
;
oc
+=
1
)
{
// oc
for
(
int
ih
=
5
;
ih
<
9
;
ih
+=
1
)
{
// ih
int
iw
=
ih
;
...
...
@@ -532,7 +532,7 @@ const int stride = 2;
#else
// big scale with group
const
int
stride
=
1
;
const
int
group
=
32
/
1
;
const
int
batch_size
=
1
;
const
int
batch_size
=
2
;
const
int
ic
=
32
/
1
;
const
int
ih
=
112
/
1
;
const
int
iw
=
112
/
1
;
...
...
@@ -558,7 +558,8 @@ const int stride = 2;
PRECISION
(
kFP16
),
DATALAYOUT
(
kImageDefault
));
ASSERT_FALSE
(
kernels
.
empty
());
CHECK
(
batch_size
==
1
)
<<
"conv3x3 only supprt batch_size == 1"
;
// CHECK(batch_size == 1) << "conv3x3 only supprt
// batch_size == 1";
auto
kernel
=
std
::
move
(
kernels
.
front
());
SHADOW_LOG
<<
"created conv2d kernel"
;
...
...
@@ -886,15 +887,16 @@ TEST(conv2d, compute_image2d_5x5) {
// int loop_cnt = 0;
#ifdef LOOP_TEST
for
(
int
batch_size
=
1
;
batch_size
<
2
;
++
batch_size
)
{
for
(
int
oc
=
1
;
oc
<
10
;
oc
+=
1
)
{
// oc
for
(
int
ih
=
5
;
ih
<
9
;
ih
+=
1
)
{
// ih
for
(
int
batch_size
=
1
;
batch_size
<
4
;
++
batch_size
)
{
for
(
int
oc
=
1
;
oc
<
5
;
oc
+=
1
)
{
// oc
for
(
int
ih
=
5
;
ih
<
8
;
ih
+=
1
)
{
// ih
int
iw
=
ih
;
for
(
int
ic
=
2
;
ic
<
10
;
ic
+=
1
)
{
// ic
for
(
int
ic
=
2
;
ic
<
6
;
ic
+=
1
)
{
// ic
for
(
bool
bias_flag
:
{
true
,
false
})
{
for
(
std
::
string
relu_flag
:
{
/*true,*/
"relu"
})
{
for
(
std
::
string
relu_flag
:
{
""
"relu"
})
{
#else
const
int
batch_size
=
1
;
const
int
batch_size
=
2
;
const
int
oc
=
1
;
const
int
ih
=
5
;
const
int
iw
=
5
;
...
...
@@ -1231,15 +1233,15 @@ TEST(conv2d, compute_image2d_7x7) {
// int loop_cnt = 0;
#ifdef LOOP_TEST
for
(
int
batch_size
=
1
;
batch_size
<
2
;
++
batch_size
)
{
for
(
int
oc
=
1
;
oc
<
10
;
oc
+=
1
)
{
// oc
for
(
int
batch_size
=
1
;
batch_size
<
4
;
++
batch_size
)
{
for
(
int
oc
=
1
;
oc
<
6
;
oc
+=
1
)
{
// oc
for
(
int
ih
=
7
;
ih
<
8
;
ih
+=
1
)
{
// ih
int
iw
=
ih
;
for
(
int
ic
=
2
;
ic
<
4
;
ic
+=
1
)
{
// ic
for
(
bool
bias_flag
:
{
false
,
true
})
{
for
(
std
::
string
relu_flag
:
{
""
,
"relu"
})
{
#else
const
int
batch_size
=
1
;
const
int
batch_size
=
2
;
const
int
oc
=
1
;
const
int
ih
=
7
;
const
int
iw
=
7
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录