未验证 提交 16811a7c 编写于 作者: Y Yanzhan Yang 提交者: GitHub

Support group convolution in OpenCL (#1776)

* Support group convolution in OpenCL

* Fix style
上级 526c446d
......@@ -55,6 +55,8 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
int filter_channel = param.Filter()->dims()[1];
int input_channel = param.Input()->dims()[1];
// DLOG << " c block " << c_block;
// DLOG << " w " << w;
......@@ -205,6 +207,25 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper,
status = clSetKernelArg(kernel, index++, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) {
if (filter_channel != input_channel) {
if (filter_channel != 1) {
status =
clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
CL_CHECK_ERRORS(status);
int has_group = 1;
status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
CL_CHECK_ERRORS(status);
}
} else {
status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel);
CL_CHECK_ERRORS(status);
int has_group = 0;
status = clSetKernelArg(kernel, index++, sizeof(int), &has_group);
CL_CHECK_ERRORS(status);
}
}
status = clEnqueueNDRangeKernel(
cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册