From 2667a153101c7ebff30e718563563580d2c584da Mon Sep 17 00:00:00 2001 From: StarryRain <36948762+StarryRain@users.noreply.github.com> Date: Thu, 17 Oct 2019 18:12:52 +0800 Subject: [PATCH] =?UTF-8?q?fix=20=E2=80=9CCL=5FINVALID=5FKERNEL=5FARGS=20?= =?UTF-8?q?=E2=80=9D=20error=EF=BC=8C=20test=3Ddevelop=20(#2213)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernel/cl/cl-kernel-func/conv_func.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp index 3489f44d91..a61952f12f 100644 --- a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp +++ b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp @@ -399,20 +399,24 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, status = clSetKernelArg(kernel, index++, sizeof(int), &output_height); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); - CL_CHECK_ERRORS(status); - if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) { - if (filter_channel != input_channel) { + // normal conv + if (param.Filter()->dims()[0] == param.Output()->dims()[1] && + param.Filter()->dims()[1] == param.Input()->dims()[1]) { + status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); + CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); CL_CHECK_ERRORS(status); - int group = input_channel / filter_channel; + int group = 1; status = clSetKernelArg(kernel, index++, sizeof(int), &group); CL_CHECK_ERRORS(status); - } else { + } else if (!(param.Filter()->dims()[0] == param.Input()->dims()[1] && + param.Filter()->dims()[1] == 1)) { // not depwise + status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); + CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); CL_CHECK_ERRORS(status); - int group = 1; + int group = input_channel / filter_channel; status = clSetKernelArg(kernel, index++, sizeof(int), &group); CL_CHECK_ERRORS(status); } -- GitLab