diff --git a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp index 3489f44d91ed6c4ecb0af9837f4dfd4a4d8c0d6a..a61952f12f95ea304f9fc98d67bc330a1fb3f631 100644 --- a/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp +++ b/mobile/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp @@ -399,20 +399,24 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, status = clSetKernelArg(kernel, index++, sizeof(int), &output_height); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); - CL_CHECK_ERRORS(status); - if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) { - if (filter_channel != input_channel) { + // normal conv + if (param.Filter()->dims()[0] == param.Output()->dims()[1] && + param.Filter()->dims()[1] == param.Input()->dims()[1]) { + status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); + CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); CL_CHECK_ERRORS(status); - int group = input_channel / filter_channel; + int group = 1; status = clSetKernelArg(kernel, index++, sizeof(int), &group); CL_CHECK_ERRORS(status); - } else { + } else if (!(param.Filter()->dims()[0] == param.Input()->dims()[1] && + param.Filter()->dims()[1] == 1)) { // not depwise + status = clSetKernelArg(kernel, index++, sizeof(int), &output_c); + CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); CL_CHECK_ERRORS(status); - int group = 1; + int group = input_channel / filter_channel; status = clSetKernelArg(kernel, index++, sizeof(int), &group); CL_CHECK_ERRORS(status); }