diff --git a/src/operators/kernel/cl/conv_kernel.cpp b/src/operators/kernel/cl/conv_kernel.cpp index 05cefadce052fb65664cc797c800ec67e43f3a2c..d3c9fa34539c0a4b8568fba298d544bdf3be5300 100644 --- a/src/operators/kernel/cl/conv_kernel.cpp +++ b/src/operators/kernel/cl/conv_kernel.cpp @@ -26,18 +26,6 @@ bool ConvKernel::Init(ConvParam *param) { param->Paddings()[0] == param->Paddings()[1], "need equal"); - auto filter_ddim = param->Filter()->dims(); - - std::vector filter_shape( - {filter_ddim[1], filter_ddim[0], filter_ddim[2], filter_ddim[3]}); - framework::DDim ddim = framework::make_ddim(filter_shape); - if (filter_ddim[1] == 1) { - param->Filter()->Resize(ddim); - } - - param->Filter()->InitCLImage(cl_helper_.CLContext(), - this->cl_helper_.CLCommandQueue()); - int offset = static_cast(param->Filter()->dims()[2]) / 2 - static_cast(param->Paddings()[1]); param->SetOffset(offset); @@ -49,19 +37,25 @@ bool ConvKernel::Init(ConvParam *param) { DLOG << " filter dims: " << param->Filter()->dims(); if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) { - DLOG << " here1 "; + param->Filter()->InitNImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); this->cl_helper_.AddKernel("conv_1x1", "conv_kernel.cl"); + DLOG << "conv 1x1"; - } else if (param->Filter()->dims()[0] == 1 && + } else if (param->Filter()->dims()[1] == 1 && param->Input()->dims()[1] == param->Output()->dims()[1] && param->Filter()->dims()[2] == 3) { - DLOG << " here2 "; + param->Filter()->InitDWImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); this->cl_helper_.AddKernel("depth_conv_3x3", "depthwise_conv_kernel.cl"); + DLOG << "depth_conv 3x3"; } else if (param->Filter()->dims()[2] == 3 && param->Filter()->dims()[3] == 3) { - DLOG << " here3 "; - this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl"); + param->Filter()->InitCLImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); + this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl"); + DLOG << "conv 3x3"; } else { PADDLE_MOBILE_THROW_EXCEPTION(" not support ");