Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
16811a7c
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
16811a7c
编写于
7月 31, 2019
作者:
Y
Yanzhan Yang
提交者:
GitHub
7月 31, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support group conv in opencl (#1776)
* support group conv in opencl * fix style
上级
526c446d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
333 addition
and
222 deletion
+333
-222
src/operators/kernel/cl/cl-kernel-func/conv_func.cpp
src/operators/kernel/cl/cl-kernel-func/conv_func.cpp
+21
-0
src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
+312
-222
未找到文件。
src/operators/kernel/cl/cl-kernel-func/conv_func.cpp
浏览文件 @
16811a7c
...
...
@@ -55,6 +55,8 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
int
input_height
=
param
.
Input
()
->
dims
()[
2
];
int
output_width
=
param
.
Output
()
->
dims
()[
3
];
int
output_height
=
param
.
Output
()
->
dims
()[
2
];
int
filter_channel
=
param
.
Filter
()
->
dims
()[
1
];
int
input_channel
=
param
.
Input
()
->
dims
()[
1
];
// DLOG << " c block " << c_block;
// DLOG << " w " << w;
...
...
@@ -205,6 +207,25 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
status
=
clSetKernelArg
(
kernel
,
index
++
,
sizeof
(
int
),
&
output_height
);
CL_CHECK_ERRORS
(
status
);
if
(
param
.
Filter
()
->
dims
()[
2
]
==
3
&&
param
.
Filter
()
->
dims
()[
3
]
==
3
)
{
if
(
filter_channel
!=
input_channel
)
{
if
(
filter_channel
!=
1
)
{
status
=
clSetKernelArg
(
kernel
,
index
++
,
sizeof
(
int
),
&
filter_channel
);
CL_CHECK_ERRORS
(
status
);
int
has_group
=
1
;
status
=
clSetKernelArg
(
kernel
,
index
++
,
sizeof
(
int
),
&
has_group
);
CL_CHECK_ERRORS
(
status
);
}
}
else
{
status
=
clSetKernelArg
(
kernel
,
index
++
,
sizeof
(
int
),
&
filter_channel
);
CL_CHECK_ERRORS
(
status
);
int
has_group
=
0
;
status
=
clSetKernelArg
(
kernel
,
index
++
,
sizeof
(
int
),
&
has_group
);
CL_CHECK_ERRORS
(
status
);
}
}
status
=
clEnqueueNDRangeKernel
(
cl_helper
->
CLCommandQueue
(),
kernel
,
default_work_size
.
size
(),
NULL
,
default_work_size
.
data
(),
NULL
,
0
,
NULL
,
NULL
);
...
...
src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl
浏览文件 @
16811a7c
...
...
@@ -47,7 +47,9 @@ __kernel void conv_3x3(__private const int global_size_dim0,
__private const int input_width,/* of one block */
__private const int input_height,/* of one block */
__private const int output_width,
__private const int output_height) {
__private const int output_height,
__private const int filter_channel,
__private const int has_group) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
...
...
@@ -88,7 +90,7 @@ __kernel void conv_3x3(__private const int global_size_dim0,
#endif
half4 input[9];
if (has_group == 0) {
for (int i = 0; i < input_c; ++i) {
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
input[0] = select(read_imageh(input_image, sampler,
...
...
@@ -322,6 +324,94 @@ __kernel void conv_3x3(__private const int global_size_dim0,
output.w += dot(input[j], weight_w);
}
} else {
for (int i = 0; i < 4; i++) {
int used_input_channel_num = (out_c * 4 + i) * filter_channel;
for (int f_c = 0; f_c < filter_channel; ++f_c) {
int input_c = used_input_channel_num + f_c;
int input_block = input_c / 4;
int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y);
input[0] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[1] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[2] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y - dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15));
input[3] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[4] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[5] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15));
input[6] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x - dilation, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
input[7] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
input[8] = select(read_imageh(input_image, sampler,
(int2)(pos_in.x + dilation, pos_in.y + dilation)),
(half4)(0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15));
half tmp_out = 0;
for (int j = 0; j < 9; j++) {
int2 pos_of_weight;
pos_of_weight.x = (f_c / 4) * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3;
half4 weight = read_imageh(filter, sampler, pos_of_weight);
int f_c_offset = f_c % 4;
half f_value;
if (f_c_offset == 0) {
f_value = weight.x;
} else if (f_c_offset == 1) {
f_value = weight.y;
} else if (f_c_offset == 2) {
f_value = weight.z;
} else if (f_c_offset == 3) {
f_value = weight.w;
}
int input_c_offset = input_c % 4;
half input_value;
if (input_c_offset == 0) {
input_value = input[j].x;
} else if (input_c_offset == 1) {
input_value = input[j].y;
} else if (input_c_offset == 2) {
input_value = input[j].z;
} else if (input_c_offset == 3) {
input_value = input[j].w;
}
tmp_out += f_value * input_value;
}
if (i == 0) {
output.x += tmp_out;
} else if (i == 1) {
output.y += tmp_out;
} else if (i == 2) {
output.z += tmp_out;
} else if (i == 3) {
output.w += tmp_out;
}
}
}
}
#ifdef BATCH_NORM
output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录